12 Commits e91ffdad44 ... 3629518758

Autor SHA1 Mensagem Data
  Tharmetharan Balendran 3629518758 Merge branch 'end-to-end' of https://gogs.infcof.com/4ycp/simulations into modulation_schemes 4 anos atrás
  Tharmetharan Balendran 2d60347040 resolved conflict 4 anos atrás
  Tharmetharan Balendran ab42d5f6a3 Refactoring/Clean-up 4 anos atrás
  Tharmetharan Balendran b8ff3d2e85 minor additions 4 anos atrás
  Tharmetharan Balendran 0b990b8b42 refactoring 4 anos atrás
  Tharmetharan Balendran 0101829faf cleaned up code 4 anos atrás
  Tharmetharan Balendran 8224f88367 variable length training added 4 anos atrás
  Tharmetharan Balendran 16045ea3e0 ber vs length plot 4 anos atrás
  Tharmetharan Balendran 9f39857933 iterative model training 4 anos atrás
  Tharmetharan Balendran a75063d665 iterative learning 4 anos atrás
  Tharmetharan Balendran 7e2c83ee4c bit-symbol mapping attemts 4 anos atrás
  Tharmetharan Balendran 80a5361d42 refactoring and lstm layers added 5 anos atrás
6 arquivos alterados com 981 adições e 197 exclusões
  1. 4 1
      .gitignore
  2. 190 0
      models/custom_layers.py
  3. 393 194
      models/end_to_end.py
  4. 237 0
      models/new_model.py
  5. 90 0
      models/plots.py
  6. 67 2
      tests/misc_test.py

+ 4 - 1
.gitignore

@@ -5,4 +5,7 @@ __pycache__
 
 # Environments
 venv/
-tests/local_test.py
+tests/local_test.py
+
+# Model export files
+models/exports

+ 190 - 0
models/custom_layers.py

@@ -0,0 +1,190 @@
+from tensorflow.keras import layers
+import tensorflow as tf
+import math
+import numpy as np
+import itertools
+
+
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality, messages_per_block):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
+
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.all_syms)
+
+
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A keras layer that extracts the central message(symbol) in a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.samples_per_symbol = samples_per_symbol
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.w)
+
+
+class AwgnChannel(layers.Layer):
+    def __init__(self, rx_stddev=0.1):
+        """
+        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
+        being applied every time the call function is called.
+
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        """
+        super(AwgnChannel, self).__init__()
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
+        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        stddev = 3*(sig_avg**2)*(10**((-6.02*enob + 1.76)/10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
+
+
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 fiber_length_stddev=0,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
+        ADC/DAC as well as additive white gaussian noise in optical communication channels.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param sig_avg: Average signal amplitude
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
+        self.lpf_cutoff = lpf_cutoff
+        self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
+
+        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
+        self.flatten_layer = layers.Flatten()
+
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
+
+    def call(self, inputs, **kwargs):
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_val)
+
+        len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer.call(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out

+ 393 - 194
models/end_to_end.py

@@ -1,143 +1,14 @@
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        A keras layer that extracts the central message(symbol) in a block.
-
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class AwgnChannel(layers.Layer):
-    def __init__(self, rx_stddev=0.1):
-        """
-        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
-        being applied every time the call function is called.
-
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        """
-        super(AwgnChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-
-    def call(self, inputs, **kwargs):
-        return self.noise_layer.call(inputs, training=True)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
-        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
-        ADC/DAC as well as additive white gaussian noise in optical communication channels.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, SymbolsToBits
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -145,7 +16,8 @@ class EndToEndAutoencoder(tf.keras.Model):
                  cardinality,
                  samples_per_symbol,
                  messages_per_block,
-                 channel):
+                 channel,
+                 custom_loss_fn=False):
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -160,38 +32,158 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
-        # Labelled N in paper
+
+        # Labelled N in paper - conditional +=1 to ensure odd value
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
                 layers.Flatten(),
                 channel,
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
-            ])
+            ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.layer\"!")
+
+        # Boolean identifying if bit mapping is to be learnt
+        self.custom_loss_fn = custom_loss_fn
+
+        # other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
+
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
             layers.Input(shape=(self.messages_per_block, self.cardinality)),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
-        ])
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+        ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
             layers.Dense(self.cardinality, activation='softmax')
-        ])
+        ], name="decoding_model")
+
+    def save_end_to_end(self):
+        # extract all params and save
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder formatted using python variable instantiation syntax
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.encoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, _ = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1
+
+        return symbol_cost + a * bit_cost
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
@@ -201,15 +193,17 @@ class EndToEndAutoencoder(tf.keras.Model):
         consecutively to model ISI. The central message in a block is returned as the label for training.
         :param return_vals: If true, the raw decimal values of the input sequence will be returned
         """
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
         cat = [np.arange(self.cardinality)]
         enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
 
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
         out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
-        mid_idx = int((self.messages_per_block-1)/2)
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
         if return_vals:
             out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
@@ -217,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
@@ -226,37 +220,110 @@ class EndToEndAutoencoder(tf.keras.Model):
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
+        if self.custom_loss_fn:
+            loss_fn = self.cost
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
+
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
                      run_eagerly=False
                      )
 
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=1,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended")
+
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        if (length_plot):
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test channel (variable length)")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
+        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
+
+        pass
 
     def view_encoder(self):
         '''
-        A method that views the learnt encoder for each distint message. This is displayed as a plot with  asubplot for
-        each image.
+        A method that views the learnt encoder for each distint message. This is displayed as a plot with a subplot for
+        each message/symbol.
         '''
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
         # Generate inputs for encoder
         messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
-        mid_idx = int((self.messages_per_block-1)/2)
-
         idx = 0
         for msg in messages:
             msg[mid_idx, idx] = 1
@@ -268,23 +335,23 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1
 
-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
 
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
 
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs
 
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
 
@@ -308,33 +375,40 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
-                                q_stddev=0)
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
+                                sig_avg=0)
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, a)
+        plt.show()
+
         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs
 
         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
 
         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -343,27 +417,152 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-if __name__ == '__main__':
-
-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+def load_model(model_name=None):
+    if model_name is None:
+        models = os.listdir("exports")
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
+
+if __name__ == 'asd':
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
 
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
+if __name__ == '__main__':
 
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel)
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = "20210317-124015"
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e4)
+        ae_model.save_end_to_end()
 
-    ae_model.train(num_of_blocks=1e6, batch_size=100)
     ae_model.view_encoder()
-    ae_model.view_sample_block()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
 
     pass

+ 237 - 0
models/new_model.py

@@ -0,0 +1,237 @@
+import tensorflow as tf
+from tensorflow.keras import losses
+from models.custom_layers import OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.custom_layers import BitsToSymbols, SymbolsToBits
+import numpy as np
+import math
+
+from matplotlib import pyplot as plt
+
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             custom_loss_fn=False)
+
+        self.bit_error_rate = []
+        self.symbol_error_rate = []
+
+    def call(self, inputs, training=None, mask=None):
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+
+        """
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out = rand_int
+
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.MeanSquaredError(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        for i in range(int(iters)):
+            print("Loop {}/{}".format(i, iters))
+
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'asd':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
+
+if __name__ == '__main__':
+
+    distances = [50]
+    ser = []
+    ber = []
+
+    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
+    bit_rate = math.log(CARDINALITY, 2) * baud_rate
+    snr = None
+
+    for d in distances:
+
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
+
+        model = BitMappingModel(cardinality=CARDINALITY,
+                                samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                messages_per_block=MESSAGES_PER_BLOCK,
+                                channel=optical_channel)
+
+        if snr is None:
+            snr = model.e2e_model.snr
+        elif snr != model.e2e_model.snr:
+            print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
+
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+
+        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
+
+        model.e2e_model.test(length_plot=True)
+
+        ber.append(model.bit_error_rate[-1])
+        ser.append(model.symbol_error_rate[-1])
+
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        custom_loss_fn=False)
+
+        ber1 = []
+        ser1 = []
+
+        for i in range(int(len(model.bit_error_rate))):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
+        plt.legend()
+        plt.show()
+        # model.summary()
+
+    # plt.plot(ber, label='BER')
+    # plt.plot(ser, label='SER')
+    # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    # plt.legend(ber)

+ 90 - 0
models/plots.py

@@ -0,0 +1,90 @@
+from sklearn.preprocessing import OneHotEncoder
+import numpy as np
+from tensorflow.keras import layers
+
+from end_to_end import load_model
+from models.custom_layers import DigitizationLayer, OpticalChannel
+from matplotlib import pyplot as plt
+import math
+
+
+def plot_e2e_spectrum(model_name=None, num_samples=10000):
+    '''
+    Plot frequency spectrum of the output signal at the encoder
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    @param num_samples: The number of symbols to simulate when computing the spectrum.
+    '''
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["cardinality"] * num_samples,
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+
+    plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+
+def plot_e2e_encoded_output(model_name=None):
+    '''
+    Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
+    The distorted DD received signal is also plotted.
+
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    '''
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+
+if __name__ == '__main__':
+    plot_e2e_spectrum()
+    plot_e2e_encoded_output()

+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+import math
+import itertools
+import tensorflow as tf
+from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
+from matplotlib import pyplot as plt
 
 
 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@ def test_bit_matrix_one_hot():
 
 
 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    # cardinality = 8
+    # messages_per_block = 3
+    # num_of_blocks = 10
+    # bits_per_symbol = 3
+    #
+    # #-----------------------------------
+    #
+    # mid_idx = int((messages_per_block - 1) / 2)
+    #
+    # ################################################################################################################
+    #
+    # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+    # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
+    #
+    # # out = enc.fit_transform(rand_int)
+    # out = rand_int
+    #
+    # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+    # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
+    #
+    # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
+    #
+    #
+    # n = int(math.log(cardinality, 2))
+    # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+    #
+    # pows_np = pows.numpy()
+    #
+    # a = np.asarray([0, 1, 1]).reshape(1, -1)
+    #
+    # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     q_stddev=0)
+
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()
+
+
+    pass