12 Commits e91ffdad44 ... 3629518758

Author SHA1 Message Date
  Tharmetharan Balendran 3629518758 Merge branch 'end-to-end' of https://gogs.infcof.com/4ycp/simulations into modulation_schemes 4 years ago
  Tharmetharan Balendran 2d60347040 resolved conflict 4 years ago
  Tharmetharan Balendran ab42d5f6a3 Refactoring/Clean-up 4 years ago
  Tharmetharan Balendran b8ff3d2e85 minor additions 4 years ago
  Tharmetharan Balendran 0b990b8b42 refactoring 4 years ago
  Tharmetharan Balendran 0101829faf cleaned up code 4 years ago
  Tharmetharan Balendran 8224f88367 variable length training added 4 years ago
  Tharmetharan Balendran 16045ea3e0 ber vs length plot 4 years ago
  Tharmetharan Balendran 9f39857933 iterative model training 4 years ago
  Tharmetharan Balendran a75063d665 iterative learning 4 years ago
  Tharmetharan Balendran 7e2c83ee4c bit-symbol mapping attempts 4 years ago
  Tharmetharan Balendran 80a5361d42 refactoring and lstm layers added 5 years ago
6 changed files with 981 additions and 197 deletions
  1. + 4 - 1  .gitignore
  2. + 190 - 0  models/custom_layers.py
  3. + 393 - 194  models/end_to_end.py
  4. + 237 - 0  models/new_model.py
  5. + 90 - 0  models/plots.py
  6. + 67 - 2  tests/misc_test.py

+ 4 - 1
.gitignore

@@ -5,4 +5,7 @@ __pycache__
 
 
 # Environments
 venv/
-tests/local_test.py
+tests/local_test.py
+
+# Model export files
+models/exports

+ 190 - 0
models/custom_layers.py

@@ -0,0 +1,190 @@
+from tensorflow.keras import layers
+import tensorflow as tf
+import math
+import numpy as np
+import itertools
+
+
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality, messages_per_block):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
+
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.all_syms)
+
+
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A keras layer that extracts the central message(symbol) in a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.samples_per_symbol = samples_per_symbol
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.w)
+
+
+class AwgnChannel(layers.Layer):
+    def __init__(self, rx_stddev=0.1):
+        """
+        An additive white Gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
+        being applied every time the call function is called.
+
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        """
+        super(AwgnChannel, self).__init__()
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        This layer simulates the finite bandwidth of the hardware by means of a low-pass filter. In addition to this,
+        artefacts caused by quantization are modelled by the addition of white Gaussian noise of a given stddev.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param sig_avg: Average signal amplitude, used to scale the modelled quantization noise
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        stddev = 3*(sig_avg**2)*(10**((-6.02*enob + 1.76)/10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
+
+
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 fiber_length_stddev=0,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
+        ADC/DAC as well as additive white gaussian noise in optical communication channels.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param fiber_length_stddev: Standard deviation of the fiber length in km, adding random length variation per call
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param sig_avg: Average signal amplitude
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
+        self.lpf_cutoff = lpf_cutoff
+        self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
+
+        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
+        self.flatten_layer = layers.Flatten()
+
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
+
+    def call(self, inputs, **kwargs):
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_val)
+
+        fiber_len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*fiber_len*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer.call(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out
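
The bit/symbol mapping layers above are easiest to follow with a small round trip. The snippet below is a minimal usage sketch, not part of the commit: it assumes TensorFlow 2.x, that the repository root is on PYTHONPATH, and illustrative values for the cardinality and block size.

# Round-trip sketch for BitsToSymbols / SymbolsToBits (illustrative only).
import numpy as np
import tensorflow as tf
from models.custom_layers import BitsToSymbols, SymbolsToBits

cardinality = 8            # M = 8 -> 3 bits per symbol
messages_per_block = 3

# Random bits shaped (batch, messages_per_block, bits_per_symbol)
bits = tf.constant(np.random.randint(2, size=(1, messages_per_block, 3)), dtype=tf.float32)

# Pack each 3-bit group into a one-hot symbol, then unpack it again
symbols = BitsToSymbols(cardinality, messages_per_block)(bits)   # shape (1, 3, 8)
recovered = SymbolsToBits(cardinality)(symbols)                  # shape (1, 3, 3)

print(np.array_equal(bits.numpy(), recovered.numpy()))           # True for hard one-hot symbols

Because SymbolsToBits is a plain matrix multiplication, it also maps soft (softmax) symbol probabilities to soft bit estimates, which is what the custom loss in end_to_end.py relies on.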

+ 393 - 194
models/end_to_end.py

@@ -1,143 +1,14 @@
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        A keras layer that extracts the central message(symbol) in a block.
-
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class AwgnChannel(layers.Layer):
-    def __init__(self, rx_stddev=0.1):
-        """
-        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
-        being applied every time the call function is called.
-
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        """
-        super(AwgnChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-
-    def call(self, inputs, **kwargs):
-        return self.noise_layer.call(inputs, training=True)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
-        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
-        ADC/DAC as well as additive white gaussian noise in optical communication channels.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, SymbolsToBits
 
 
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -145,7 +16,8 @@ class EndToEndAutoencoder(tf.keras.Model):
                  cardinality,
                  samples_per_symbol,
                  messages_per_block,
-                 channel):
+                 channel,
+                 custom_loss_fn=False):
         """
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -160,38 +32,158 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
-        # Labelled N in paper
+
+        # Labelled N in paper - conditional +=1 to ensure odd value
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
                 layers.Flatten(),
                 channel,
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
-            ])
+            ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
+
+        # Boolean identifying if bit mapping is to be learnt
+        self.custom_loss_fn = custom_loss_fn
+
+        # other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
+
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
 
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
             layers.Input(shape=(self.messages_per_block, self.cardinality)),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
-        ])
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+        ], name="encoding_model")
 
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
             layers.Dense(self.cardinality, activation='softmax')
-        ])
+        ], name="decoding_model")
+
+    def save_end_to_end(self):
+        # extract all params and save
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder formatted using python variable instantiation syntax
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, _ = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1
+
+        return symbol_cost + a * bit_cost
 
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
         """
@@ -201,15 +193,17 @@ class EndToEndAutoencoder(tf.keras.Model):
         consecutively to model ISI. The central message in a block is returned as the label for training.
         :param return_vals: If true, the raw decimal values of the input sequence will be returned
         """
         """
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
 
         cat = [np.arange(self.cardinality)]
         enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
 
 
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
         out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
 
-        mid_idx = int((self.messages_per_block-1)/2)
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
 
         if return_vals:
             out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
@@ -217,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         return out_arr, out_arr[:, mid_idx, :]
 
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
 
@@ -226,37 +220,110 @@ class EndToEndAutoencoder(tf.keras.Model):
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
 
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
 
+        if self.custom_loss_fn:
+            loss_fn = self.cost
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
+
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
                      run_eagerly=False
                      )
                      )
 
 
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=1,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended")
+
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        if (length_plot):
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test channel (variable length)")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
+        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
+
+        pass
 
 
     def view_encoder(self):
         '''
-        A method that views the learnt encoder for each distint message. This is displayed as a plot with  asubplot for
-        each image.
+        A method that views the learnt encoder for each distinct message. This is displayed as a plot with a subplot for
+        each message/symbol.
         '''
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
         # Generate inputs for encoder
         messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
 
-        mid_idx = int((self.messages_per_block-1)/2)
-
         idx = 0
         for msg in messages:
             msg[mid_idx, idx] = 1
@@ -268,23 +335,23 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1
 
 
-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
 
 
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
 
 
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs
 
 
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
 
 
@@ -308,33 +375,40 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
 
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
-                                q_stddev=0)
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
+                                sig_avg=0)
 
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
 
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, np.abs(a))
+        plt.show()
+
         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs
 
 
         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
 
 
         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
 
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -343,27 +417,152 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
 
 
-if __name__ == '__main__':
-
-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+def load_model(model_name=None):
+    if model_name is None:
+        models = os.listdir("exports")
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = sorted(models)[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
+
+if __name__ == 'asd':
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
 
 
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
+if __name__ == '__main__':
 
 
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel)
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = "20210317-124015"
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e4)
+        ae_model.save_end_to_end()
 
 
-    ae_model.train(num_of_blocks=1e6, batch_size=100)
     ae_model.view_encoder()
-    ae_model.view_sample_block()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
 
 
     pass
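
As a quick orientation for the new training and export API added above, here is a usage sketch. It is illustrative only, not part of the commit: it assumes the repository root is on PYTHONPATH, that an exports/ directory can be created in the working directory, and the parameter values simply mirror the __main__ block.

from models.custom_layers import OpticalChannel
from models.end_to_end import EndToEndAutoencoder, load_model

channel = OpticalChannel(fs=336e9,
                         num_of_samples=9 * 32,
                         dispersion_factor=-21.7 * 1e-24,
                         fiber_length=50)

ae = EndToEndAutoencoder(cardinality=32,
                         samples_per_symbol=32,
                         messages_per_block=9,
                         channel=channel,
                         custom_loss_fn=True)

ae.train(num_of_blocks=1e4, epochs=10)   # early stopping may end training sooner
ae.test()                                # prints symbol and bit error rates
ae.save_end_to_end()                     # writes params.json, weight dumps and SavedModels under exports/

restored_model, params = load_model()    # restores a saved export (latest by timestamp)
restored_model.test()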

+ 237 - 0
models/new_model.py

@@ -0,0 +1,237 @@
+import tensorflow as tf
+from tensorflow.keras import losses
+from models.custom_layers import OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.custom_layers import BitsToSymbols, SymbolsToBits
+import numpy as np
+import math
+
+from matplotlib import pyplot as plt
+
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             custom_loss_fn=False)
+
+        self.bit_error_rate = []
+        self.symbol_error_rate = []
+
+    def call(self, inputs, training=None, mask=None):
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+
+        """
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out = rand_int
+
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.MeanSquaredError(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        for i in range(int(iters)):
+            print("Loop {}/{}".format(i, iters))
+
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'asd':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
+
+if __name__ == '__main__':
+
+    distances = [50]
+    ser = []
+    ber = []
+
+    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
+    bit_rate = math.log(CARDINALITY, 2) * baud_rate
+    snr = None
+
+    for d in distances:
+
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
+
+        model = BitMappingModel(cardinality=CARDINALITY,
+                                samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                messages_per_block=MESSAGES_PER_BLOCK,
+                                channel=optical_channel)
+
+        if snr is None:
+            snr = model.e2e_model.snr
+        elif snr != model.e2e_model.snr:
+            print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
+
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+
+        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
+
+        model.e2e_model.test(length_plot=True)
+
+        ber.append(model.bit_error_rate[-1])
+        ser.append(model.symbol_error_rate[-1])
+
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        custom_loss_fn=False)
+
+        ber1 = []
+        ser1 = []
+
+        for i in range(int(len(model.bit_error_rate))):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
+        plt.legend()
+        plt.show()
+        # model.summary()
+
+    # plt.plot(ber, label='BER')
+    # plt.plot(ser, label='SER')
+    # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    # plt.legend(ber)
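
For reference, a condensed usage sketch of the alternating training loop implemented above. It is illustrative and not part of the commit; the constants mirror the module-level values and the import assumes the repository root is on PYTHONPATH.

from models.custom_layers import OpticalChannel
from models.new_model import BitMappingModel

channel = OpticalChannel(fs=336e9,
                         num_of_samples=11 * 48,
                         dispersion_factor=-21.7 * 1e-24,
                         fiber_length=50,
                         sig_avg=0.5,
                         enob=6)

model = BitMappingModel(cardinality=64,
                        samples_per_symbol=48,
                        messages_per_block=11,
                        channel=channel)

# Each iteration first refits the inner symbol autoencoder on one-hot symbols and then refits the
# bits-in/bits-out wrapper, appending BER/SER to model.bit_error_rate / model.symbol_error_rate.
model.trainIterative(iters=5, num_of_blocks=1e3, epochs=2)
print(model.bit_error_rate)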

+ 90 - 0
models/plots.py

@@ -0,0 +1,90 @@
+from sklearn.preprocessing import OneHotEncoder
+import numpy as np
+from tensorflow.keras import layers
+
+from models.end_to_end import load_model
+from models.custom_layers import DigitizationLayer, OpticalChannel
+from matplotlib import pyplot as plt
+import math
+
+
+def plot_e2e_spectrum(model_name=None, num_samples=10000):
+    '''
+    Plot frequency spectrum of the output signal at the encoder
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    @param num_samples: The number of symbols to simulate when computing the spectrum.
+    '''
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["samples_per_symbol"] * num_samples,
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+
+    plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+
+def plot_e2e_encoded_output(model_name=None):
+    '''
+    Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
+    The distorted DD received signal is also plotted.
+
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    '''
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+
+if __name__ == '__main__':
+    plot_e2e_spectrum()
+    plot_e2e_encoded_output()
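
Usage sketch for the two helpers above (illustrative, not part of the commit): both can be pointed at a specific export directory instead of the most recent one. The timestamp below is the example export name referenced in end_to_end.py, not a file added by this commit, and the import assumes the repository root is on PYTHONPATH.

from models.plots import plot_e2e_spectrum, plot_e2e_encoded_output

# Inspect a specific export produced by EndToEndAutoencoder.save_end_to_end()
plot_e2e_spectrum(model_name="20210317-124015", num_samples=5000)
plot_e2e_encoded_output(model_name="20210317-124015")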

+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+import math
+import itertools
+import tensorflow as tf
+from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
+from matplotlib import pyplot as plt
 
 
 
 
 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@ def test_bit_matrix_one_hot():
 
 
 
 
 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    # cardinality = 8
+    # messages_per_block = 3
+    # num_of_blocks = 10
+    # bits_per_symbol = 3
+    #
+    # #-----------------------------------
+    #
+    # mid_idx = int((messages_per_block - 1) / 2)
+    #
+    # ################################################################################################################
+    #
+    # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+    # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
+    #
+    # # out = enc.fit_transform(rand_int)
+    # out = rand_int
+    #
+    # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+    # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
+    #
+    # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
+    #
+    #
+    # n = int(math.log(cardinality, 2))
+    # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+    #
+    # pows_np = pows.numpy()
+    #
+    # a = np.asarray([0, 1, 1]).reshape(1, -1)
+    #
+    # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     sig_avg=0)
+
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()
+
+
+    pass
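
A small self-contained check in the same style as test_bit_matrix_one_hot() could also cover the new custom layers. The sketch below is illustrative and not part of the commit: it asserts that ExtractCentralMessage returns exactly the samples of the middle message in a block.

def test_extract_central_message():
    # Hypothetical test: build a block of 3 messages with 4 samples each and check that
    # only the central message's samples are returned.
    from models.custom_layers import ExtractCentralMessage

    messages_per_block = 3
    samples_per_symbol = 4

    block = np.arange(messages_per_block * samples_per_symbol, dtype=np.float32).reshape(1, -1)
    out = ExtractCentralMessage(messages_per_block, samples_per_symbol)(block).numpy()

    expected = block.flatten()[samples_per_symbol:2 * samples_per_symbol]
    assert np.array_equal(out.flatten(), expected)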