
Merge branch 'end-to-end' into Standalone_NN_devel

# Conflicts:
#	.gitignore
#	models/end_to_end.py
#	models/layers.py
#	models/new_model.py
Min committed 4 years ago
parent commit eb84764f2f
4 changed files with 452 additions and 160 deletions
  1. models/end_to_end.py   +254 -126
  2. models/layers.py       +27 -16
  3. models/new_model.py    +80 -18
  4. models/plots.py        +91 -0

models/end_to_end.py  +254 -126

@@ -1,14 +1,14 @@
-import itertools
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-from layers import ExtractCentralMessage, BitsToSymbols, SymbolsToBits, OpticalChannel, DigitizationLayer
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -17,7 +17,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  messages_per_block,
                  channel,
-                 bit_mapping=False):
+                 custom_loss_fn=False):
         """
         The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -37,7 +37,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
 
-        # Labelled N in paper
+        # Labelled N in paper - conditional +=1 to ensure odd value
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
@@ -50,71 +50,128 @@ class EndToEndAutoencoder(tf.keras.Model):
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
             ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
 
         # Boolean identifying whether the custom symbol + bit loss function is to be used
-        self.bit_mapping = bit_mapping
+        self.custom_loss_fn = custom_loss_fn
 
         # other parameters/metrics
         self.symbol_error_rate = None
         self.bit_error_rate = None
-        self.snr = 20 * math.log(0.5/channel.rx_stddev, 10)
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
 
         # Model Hyper-parameters
         leaky_relu_alpha = 0
         relu_clip_val = 1.0
 
-        # Layer configuration for the case when bit mapping is to be learnt
-        if self.bit_mapping:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
-                BitsToSymbols(self.cardinality),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                # layers.Dense(2 * self.cardinality),
-                # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid')
-            ]
-
-        # layer configuration for the case when only symbol mapping is to be learnt
-        else:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.cardinality)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(self.cardinality, activation='softmax')
-            ]
-
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            *encoding_layers
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
         ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            *decoding_layers
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(self.cardinality, activation='softmax')
         ], name="decoding_model")
 
+    def save_end_to_end(self):
+        # Extract all model/channel parameters and save them to a timestamped export directory
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the encoder and decoder weights as plain Python variable assignments
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        # NumPy forward pass through the encoder: ReLU on the hidden layers (LeakyReLU with
+        # alpha=0) and sigmoid on the output layer, mirroring the Keras encoding_model
+        enc_weights, dec_weights = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1  # relative weight of the bit-level term in the combined loss
+
+        return symbol_cost + a * bit_cost
+
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
         A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
@@ -129,26 +186,15 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        if self.bit_mapping:
-            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
-
-            out = rand_int
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+        out = enc.fit_transform(rand_int)
 
-            if return_vals:
-                return out_arr, out_arr, out_arr[:, mid_idx, :]
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
-        else:
-            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
-
-            out = enc.fit_transform(rand_int)
-
-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
-
-            if return_vals:
-                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-                return out_val, out_arr, out_arr[:, mid_idx, :]
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]
 
         return out_arr, out_arr[:, mid_idx, :]
 
@@ -166,18 +212,8 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
-        # TODO: Investigate different optimizers (with different learning rates and other parameters)
-        # SGD
-        # RMSprop
-        # Adam
-        # Adadelta
-        # Adagrad
-        # Adamax
-        # Nadam
-        # Ftrl
-
-        if self.bit_mapping:
-            loss_fn = losses.BinaryCrossentropy()
+        if self.custom_loss_fn:
+            loss_fn = self.cost
         else:
             loss_fn = losses.CategoricalCrossentropy()
 
@@ -189,7 +225,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )
 
-        history = self.fit(x=X_train,
+        self.fit(x=X_train,
                  y=y_train,
                  batch_size=batch_size,
                  epochs=epochs,
@@ -197,7 +233,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  validation_data=(X_test, y_test)
                  )
 
-    def test(self, num_of_blocks=1e4):
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -207,13 +243,51 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
 
-        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
-
         bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
         bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
 
         self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
 
+        if length_plot:
+            # Sweep the fibre length over 0-70 (same units as fiber_length) and measure the BER at each length
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test channel (variable length)")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
         print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
         print("BIT ERROR RATE: {}".format(self.bit_error_rate))
 
@@ -227,23 +301,13 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        if self.bit_mapping:
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
-            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
-            idx = 0
-            for msg in messages:
-                msg[mid_idx] = lst[idx]
-                idx += 1
-
-        else:
-            # Generate inputs for encoder
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
-            idx = 0
-            for msg in messages:
-                msg[mid_idx, idx] = 1
-                idx += 1
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1
 
         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
@@ -301,6 +365,12 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
+        # Debug plot: magnitude spectrum of the LPF output
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, np.abs(a))
+        plt.show()
+
         # Time axis
         t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
@@ -319,7 +389,6 @@ class EndToEndAutoencoder(tf.keras.Model):
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -328,31 +397,90 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
-DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 0
+def load_model(model_name=None):
+    if model_name is None:
+        models = sorted(os.listdir("exports"))  # timestamped directory names sort chronologically
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
 
 if __name__ == '__main__':
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
-
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = ""
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
                                    channel=optical_channel,
-                                   bit_mapping=False)
-
-    ae_model.train(num_of_blocks=1e5, epochs=5)
-    ae_model.test()
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    # ae_model.summary()
-    ae_model.encoder.summary()
-    ae_model.channel.summary()
-    ae_model.decoder.summary()
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.save_end_to_end()
+
     pass
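
Together, save_end_to_end() and load_model() give end_to_end.py a complete export/import round trip. A minimal usage sketch (not part of the commit), assuming the repository root is on the Python path and at least one model has already been saved under exports/:

    from models.end_to_end import load_model

    # With no argument, load_model() looks under exports/ and rebuilds a previously
    # saved autoencoder; a specific export can be selected by passing its directory name.
    ae_model, params = load_model()

    # Re-evaluate the symbol/bit error rates and plot BER against fibre length.
    ae_model.test(num_of_blocks=1e3, length_plot=True)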

models/layers.py  +27 -16

@@ -28,18 +28,19 @@ class AwgnChannel(layers.Layer):
 
 
 class BitsToSymbols(layers.Layer):
-    def __init__(self, cardinality):
+    def __init__(self, cardinality, messages_per_block):
         super(BitsToSymbols, self).__init__()
 
         self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
 
         n = int(math.log(self.cardinality, 2))
-        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
 
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
         out = tf.one_hot(idx, self.cardinality)
-        return layers.Reshape((9, 32))(out)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
 
 
 class SymbolsToBits(layers.Layer):
@@ -139,6 +140,7 @@ class OpticalChannel(layers.Layer):
                  num_of_samples,
                  dispersion_factor,
                  fiber_length,
+                 fiber_length_stddev=0,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
                  sig_avg=0.5,
@@ -157,23 +159,27 @@ class OpticalChannel(layers.Layer):
         """
         super(OpticalChannel, self).__init__()
 
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
+        self.lpf_cutoff = lpf_cutoff
         self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
 
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
-        self.digitization_layer = DigitizationLayer(
-            fs=fs,
-            num_of_samples=num_of_samples,
-            lpf_cutoff=lpf_cutoff,
-            sig_avg=sig_avg,
-            enob=enob
-        )
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
         self.flatten_layer = layers.Flatten()
 
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(
-            np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
-        self.multiplier = tf.math.exp(
-            0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
 
     def call(self, inputs, **kwargs):
         # DAC LPF and noise
@@ -182,7 +188,12 @@ class OpticalChannel(layers.Layer):
         # Chromatic Dispersion
         complex_val = tf.cast(dac_out, dtype=tf.complex64)
         val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
+
+        # Perturb the fibre length with Gaussian noise, then rebuild the dispersion multiplier
+        length = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*length*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
         disp_t = tf.signal.ifft(disp_f)
 
         # Squared-Law Detection
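
OpticalChannel now keeps every constructor argument as an attribute and rebuilds the dispersion multiplier on each call, so the fibre length can be jittered by Gaussian noise before the exponent is formed. A minimal NumPy sketch of the same frequency-domain operation (the helper name is illustrative, not part of the commit):

    import numpy as np

    def apply_chromatic_dispersion(signal, fs, dispersion_factor, fiber_length):
        # All-pass response exp(0.5j * D * L * (2*pi*f)^2), as built in OpticalChannel.call
        freq = np.fft.fftfreq(signal.shape[-1], d=1 / fs)
        multiplier = np.exp(0.5j * dispersion_factor * fiber_length * (2 * np.pi * freq) ** 2)
        return np.fft.ifft(np.fft.fft(signal) * multiplier)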

models/new_model.py  +80 -18

@@ -1,12 +1,14 @@
 import tensorflow as tf
-from tensorflow.keras import losses
-from layers import OpticalChannel, BitsToSymbols, SymbolsToBits
-from end_to_end import EndToEndAutoencoder
+from tensorflow.keras import layers, losses
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, BitsToSymbols, SymbolsToBits
+from models.end_to_end import EndToEndAutoencoder
 import numpy as np
 import math
 
 from matplotlib import pyplot as plt
 
+
 class BitMappingModel(tf.keras.Model):
     def __init__(self,
                  cardinality,
@@ -31,13 +33,13 @@ class BitMappingModel(tf.keras.Model):
                                              samples_per_symbol=self.samples_per_symbol,
                                              messages_per_block=self.messages_per_block,
                                              channel=channel,
-                                             bit_mapping=False)
+                                             custom_loss_fn=False)
 
         self.bit_error_rate = []
         self.symbol_error_rate = []
 
     def call(self, inputs, training=None, mask=None):
-        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
         x2 = self.e2e_model(x1)
         out = SymbolsToBits(self.cardinality)(x2)
         return out
@@ -70,7 +72,7 @@ class BitMappingModel(tf.keras.Model):
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=losses.MeanSquaredError(),
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
@@ -86,7 +88,9 @@ class BitMappingModel(tf.keras.Model):
                  )
 
     def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
-        for _ in range(iters):
+        for i in range(int(iters)):
+            print("Loop {}/{}".format(i + 1, int(iters)))
+
             self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
 
             self.e2e_model.test()
@@ -121,16 +125,33 @@ class BitMappingModel(tf.keras.Model):
             self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
             self.bit_error_rate.append(self.e2e_model.bit_error_rate)
 
+
 SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
 DISPERSION_FACTOR = -21.7 * 1e-24
 FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'asd':  # intentionally never true: this single-run block is disabled in favour of the distance sweep below
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
 
 if __name__ == '__main__':
 
-    distances = [0, 10, 20, 30, 40, 50, 60]
+    distances = [50]
     ser = []
     ber = []
 
@@ -143,7 +164,9 @@ if __name__ == '__main__':
         optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                          num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                          dispersion_factor=DISPERSION_FACTOR,
-                                         fiber_length=d)
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
 
         model = BitMappingModel(cardinality=CARDINALITY,
                                 samples_per_symbol=SAMPLES_PER_SYMBOL,
@@ -155,18 +178,57 @@ if __name__ == '__main__':
         elif snr != model.e2e_model.snr:
             print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
 
-        # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
 
         model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
 
+        model.e2e_model.test(length_plot=True)
+
         ber.append(model.bit_error_rate[-1])
         ser.append(model.symbol_error_rate[-1])
 
-        # plt.plot(model.bit_error_rate, label='BER')
-        # plt.plot(model.symbol_error_rate, label='SER')
-        # plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
-        # plt.legend()
-        # plt.show()
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        custom_loss_fn=False)
+
+        ber1 = []
+        ser1 = []
+
+        for i in range(int(len(model.bit_error_rate))):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
+        plt.legend()
+        plt.show()
         # model.summary()
 
     # plt.plot(ber, label='BER')
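
BitMappingModel now forwards messages_per_block into BitsToSymbols, so the reshape in layers.py no longer hard-codes (9, 32). For reference, a small NumPy sketch of the bit-to-symbol index mapping that layer implements (illustrative helper, big-endian bit order as in its pows vector):

    import numpy as np

    def bits_to_symbol_indices(bits, cardinality):
        # bits: 0/1 array with log2(cardinality) bits per message; returns integer symbol indices
        n = int(np.log2(cardinality))
        pows = 2 ** np.arange(n - 1, -1, -1)   # e.g. [16, 8, 4, 2, 1] for cardinality 32
        return bits.reshape(-1, n) @ pows

    print(bits_to_symbol_indices(np.array([0, 0, 1, 0, 1]), 32))  # -> [5]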

models/plots.py  +91 -0

@@ -0,0 +1,91 @@
+from sklearn.preprocessing import OneHotEncoder
+import numpy as np
+from tensorflow.keras import layers
+
+from models.end_to_end import load_model
+from models.custom_layers import DigitizationLayer, OpticalChannel
+from matplotlib import pyplot as plt
+import math
+
+# plot frequency spectrum of e2e model
+def plot_e2e_spectrum(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=10000 * params["samples_per_symbol"],  # 10000 symbols, samples_per_symbol samples each
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+    mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
+
+    # Spectrum of the squared (square-law detected) dispersion impulse response
+    h = np.fft.ifft(mul)
+    h_sq = np.power(h, 2)
+    b = np.abs(np.fft.fft(h_sq))
+
+    plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+    # plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.plot(freq, b)
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+def plot_e2e_encoded_output(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.axhline(y=0, color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.ylim((-0.1, 0.1))
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+if __name__ == '__main__':
+    # plot_e2e_spectrum()
+    plot_e2e_encoded_output()