Browse source code

Merge branch 'end-to-end' into Standalone_NN_devel

# Conflicts:
#	.gitignore
#	models/end_to_end.py
#	models/layers.py
#	models/new_model.py
Min, 4 years ago
parent commit eb84764f2f
4 changed files with 452 additions and 160 deletions
  1. models/end_to_end.py (+254 -126)
  2. models/layers.py (+27 -16)
  3. models/new_model.py (+80 -18)
  4. models/plots.py (+91 -0)

+ 254 - 126
models/end_to_end.py

@@ -1,14 +1,14 @@
-import itertools
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-from layers import ExtractCentralMessage, BitsToSymbols, SymbolsToBits, OpticalChannel, DigitizationLayer
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
 
 
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -17,7 +17,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  messages_per_block,
                  channel,
-                 bit_mapping=False):
+                 custom_loss_fn=False):
         """
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -37,7 +37,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
 
-        # Labelled N in paper
+        # Labelled N in paper - conditional +=1 to ensure odd value
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
@@ -50,71 +50,128 @@ class EndToEndAutoencoder(tf.keras.Model):
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
             ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
 
 
         # Boolean identifying if bit mapping is to be learnt
-        self.bit_mapping = bit_mapping
+        self.custom_loss_fn = custom_loss_fn
 
 
         # other parameters/metrics
         self.symbol_error_rate = None
         self.bit_error_rate = None
-        self.snr = 20 * math.log(0.5/channel.rx_stddev, 10)
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
 
 
         # Model Hyper-parameters
         leaky_relu_alpha = 0
         relu_clip_val = 1.0
 
-        # Layer configuration for the case when bit mapping is to be learnt
-        if self.bit_mapping:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
-                BitsToSymbols(self.cardinality),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                # layers.Dense(2 * self.cardinality),
-                # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid')
-            ]
-
-        # layer configuration for the case when only symbol mapping is to be learnt
-        else:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.cardinality)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(self.cardinality, activation='softmax')
-            ]
-
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            *encoding_layers
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
         ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            *decoding_layers
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(self.cardinality, activation='softmax')
         ], name="decoding_model")
 
+    def save_end_to_end(self):
+        # extract all params and save
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder formatted using python variable instantiation syntax
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, dec_weights = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1
+
+        return symbol_cost + a * bit_cost
+
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
         A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
@@ -129,26 +186,15 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        if self.bit_mapping:
-            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
-
-            out = rand_int
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
 
-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+        out = enc.fit_transform(rand_int)
 
 
-            if return_vals:
-                return out_arr, out_arr, out_arr[:, mid_idx, :]
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
 
-        else:
-            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
-
-            out = enc.fit_transform(rand_int)
-
-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
-
-            if return_vals:
-                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-                return out_val, out_arr, out_arr[:, mid_idx, :]
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]
 
 
         return out_arr, out_arr[:, mid_idx, :]
 
 
@@ -166,18 +212,8 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
-        # TODO: Investigate different optimizers (with different learning rates and other parameters)
-        # SGD
-        # RMSprop
-        # Adam
-        # Adadelta
-        # Adagrad
-        # Adamax
-        # Nadam
-        # Ftrl
-
-        if self.bit_mapping:
-            loss_fn = losses.BinaryCrossentropy()
+        if self.custom_loss_fn:
+            loss_fn = self.cost
         else:
             loss_fn = losses.CategoricalCrossentropy()
 
@@ -189,7 +225,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )
 
-        history = self.fit(x=X_train,
+        self.fit(x=X_train,
                  y=y_train,
                  batch_size=batch_size,
                  epochs=epochs,
@@ -197,7 +233,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  validation_data=(X_test, y_test)
                  )
 
-    def test(self, num_of_blocks=1e4):
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -207,13 +243,51 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
 
-        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
-
         bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
         bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
 
         self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
 
+        if (length_plot):
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test channel (variable length)")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
         print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
         print("BIT ERROR RATE: {}".format(self.bit_error_rate))
 
@@ -227,23 +301,13 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        if self.bit_mapping:
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
-            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
 
-            idx = 0
-            for msg in messages:
-                msg[mid_idx] = lst[idx]
-                idx += 1
-
-        else:
-            # Generate inputs for encoder
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
-            idx = 0
-            for msg in messages:
-                msg[mid_idx, idx] = 1
-                idx += 1
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1
 
 
         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
@@ -301,6 +365,12 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, a)
+        plt.show()
+
         # Time axis
         t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
@@ -319,7 +389,6 @@ class EndToEndAutoencoder(tf.keras.Model):
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -328,31 +397,90 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
-DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 0
+def load_model(model_name=None):
+    if model_name is None:
+        models = os.listdir("exports")
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
-
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = ""
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
                                    channel=optical_channel,
-                                   bit_mapping=False)
-
-    ae_model.train(num_of_blocks=1e5, epochs=5)
-    ae_model.test()
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    # ae_model.summary()
-    ae_model.encoder.summary()
-    ae_model.channel.summary()
-    ae_model.decoder.summary()
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.save_end_to_end()
+
     pass

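A quick sanity check on the snr expression introduced above, 20 * math.log(0.5 / channel.rx_stddev, 10): the hard-coded 0.5 matches the default sig_avg, so with the script defaults sig_avg = 0.5 and rx_stddev = 0.01 it works out to roughly 34 dB. A minimal sketch, assuming only those defaults:

    import math

    sig_avg = 0.5      # script default; the 0.5 hard-coded in the snr expression
    rx_stddev = 0.01   # script default

    snr_db = 20 * math.log10(sig_avg / rx_stddev)  # 20 * log10(50)
    print(round(snr_db, 2))                        # 33.98
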
+ 27 - 16
models/layers.py

@@ -28,18 +28,19 @@ class AwgnChannel(layers.Layer):
 
 
 
 
 class BitsToSymbols(layers.Layer):
-    def __init__(self, cardinality):
+    def __init__(self, cardinality, messages_per_block):
         super(BitsToSymbols, self).__init__()
 
         self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
 
 
         n = int(math.log(self.cardinality, 2))
-        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
 
 
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
         out = tf.one_hot(idx, self.cardinality)
-        return layers.Reshape((9, 32))(out)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
 
 
 
 
 class SymbolsToBits(layers.Layer):
@@ -139,6 +140,7 @@ class OpticalChannel(layers.Layer):
                  num_of_samples,
                  dispersion_factor,
                  fiber_length,
+                 fiber_length_stddev=0,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
                  sig_avg=0.5,
@@ -157,23 +159,27 @@ class OpticalChannel(layers.Layer):
         """
         """
         super(OpticalChannel, self).__init__()
         super(OpticalChannel, self).__init__()
 
 
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
+        self.lpf_cutoff = lpf_cutoff
         self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
 
 
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
-        self.digitization_layer = DigitizationLayer(
-            fs=fs,
-            num_of_samples=num_of_samples,
-            lpf_cutoff=lpf_cutoff,
-            sig_avg=sig_avg,
-            enob=enob
-        )
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
         self.flatten_layer = layers.Flatten()
 
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(
-            np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
-        self.multiplier = tf.math.exp(
-            0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
 
 
     def call(self, inputs, **kwargs):
         # DAC LPF and noise
@@ -182,7 +188,12 @@ class OpticalChannel(layers.Layer):
         # Chromatic Dispersion
         complex_val = tf.cast(dac_out, dtype=tf.complex64)
         val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
+
+        len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
         disp_t = tf.signal.ifft(disp_f)
 
         # Squared-Law Detection

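The OpticalChannel changes above move the chromatic-dispersion multiplier out of __init__ and into call, so a freshly noise-perturbed fiber length is drawn on every forward pass. Below is a standalone NumPy sketch of that dispersion step, an all-pass phase filter applied in the frequency domain; the constants are the script defaults, the random vector merely stands in for the encoder output, and the final line is the textbook |.|^2 operation behind the "# Squared-Law Detection" step that follows in the unchanged code:

    import numpy as np

    fs = 336e9                       # sampling frequency (script default)
    n = 9 * 32                       # messages_per_block * samples_per_symbol
    beta2 = -21.7e-24                # dispersion_factor (script default)
    length = 50                      # fiber_length, in km (script default)

    x = np.random.rand(n)            # stand-in for the encoder output, values in [0, 1]

    f = np.fft.fftfreq(n, d=1 / fs)
    H = np.exp(0.5j * beta2 * length * (2 * np.pi * f) ** 2)  # all-pass dispersion response

    y = np.fft.ifft(np.fft.fft(x) * H)   # dispersion applied in the frequency domain
    rx = np.abs(y) ** 2                  # square-law (direct) detection
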
+ 80 - 18
models/new_model.py

@@ -1,12 +1,14 @@
 import tensorflow as tf
-from tensorflow.keras import losses
-from layers import OpticalChannel, BitsToSymbols, SymbolsToBits
-from end_to_end import EndToEndAutoencoder
+from tensorflow.keras import layers, losses
+from models.custom_layers import ExtractCentralMessage, OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.custom_layers import BitsToSymbols, SymbolsToBits
 import numpy as np
 import math
 
 from matplotlib import pyplot as plt
 
+
 class BitMappingModel(tf.keras.Model):
     def __init__(self,
                  cardinality,
@@ -31,13 +33,13 @@ class BitMappingModel(tf.keras.Model):
                                              samples_per_symbol=self.samples_per_symbol,
                                              messages_per_block=self.messages_per_block,
                                              channel=channel,
-                                             bit_mapping=False)
+                                             custom_loss_fn=False)
 
 
         self.bit_error_rate = []
         self.symbol_error_rate = []
 
     def call(self, inputs, training=None, mask=None):
-        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
         x2 = self.e2e_model(x1)
         out = SymbolsToBits(self.cardinality)(x2)
         return out
@@ -70,7 +72,7 @@ class BitMappingModel(tf.keras.Model):
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=losses.MeanSquaredError(),
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
@@ -86,7 +88,9 @@ class BitMappingModel(tf.keras.Model):
                  )
 
     def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
-        for _ in range(iters):
+        for i in range(int(iters)):
+            print("Loop {}/{}".format(i, iters))
+
             self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
 
             self.e2e_model.test()
@@ -121,16 +125,33 @@ class BitMappingModel(tf.keras.Model):
             self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
             self.bit_error_rate.append(self.e2e_model.bit_error_rate)
 
+
 SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
 DISPERSION_FACTOR = -21.7 * 1e-24
 FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'asd':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
 
 
 if __name__ == '__main__':
 
-    distances = [0, 10, 20, 30, 40, 50, 60]
+    distances = [50]
     ser = []
     ber = []
 
@@ -143,7 +164,9 @@ if __name__ == '__main__':
         optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                          num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                          dispersion_factor=DISPERSION_FACTOR,
-                                         fiber_length=d)
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
 
 
         model = BitMappingModel(cardinality=CARDINALITY,
                                 samples_per_symbol=SAMPLES_PER_SYMBOL,
@@ -155,18 +178,57 @@ if __name__ == '__main__':
         elif snr != model.e2e_model.snr:
             print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
 
-        # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
 
 
         model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
 
+        model.e2e_model.test(length_plot=True)
+
         ber.append(model.bit_error_rate[-1])
         ser.append(model.symbol_error_rate[-1])
 
-        # plt.plot(model.bit_error_rate, label='BER')
-        # plt.plot(model.symbol_error_rate, label='SER')
-        # plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
-        # plt.legend()
-        # plt.show()
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        custom_loss_fn=False)
+
+        ber1 = []
+        ser1 = []
+
+        for i in range(int(len(model.bit_error_rate))):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
+        plt.legend()
+        plt.show()
         # model.summary()
 
     # plt.plot(ber, label='BER')

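BitMappingModel.call above chains BitsToSymbols -> EndToEndAutoencoder -> SymbolsToBits, so the wrapper trains directly against bit labels. As a small worked example of the BitsToSymbols mapping (powers-of-two weighting of the bit vector, then one-hot encoding, mirroring self.pows and tf.one_hot in layers.py), with illustrative values rather than anything taken from the commit:

    import numpy as np

    cardinality = 32                           # end_to_end.py default (new_model.py now uses 64)
    n_bits = int(np.log2(cardinality))         # 5 bits per symbol

    pows = 2 ** np.arange(n_bits - 1, -1, -1)  # [16, 8, 4, 2, 1]
    bits = np.array([1, 0, 1, 1, 0])           # one message worth of bits (example)

    idx = int(bits @ pows)                     # 16 + 4 + 2 = 22
    one_hot = np.eye(cardinality)[idx]         # the symbol fed to the end-to-end model
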
+ 91 - 0
models/plots.py

@@ -0,0 +1,91 @@
+from sklearn.preprocessing import OneHotEncoder
+import numpy as np
+from tensorflow.keras import layers
+
+from models.end_to_end import load_model
+from models.custom_layers import DigitizationLayer, OpticalChannel
+from matplotlib import pyplot as plt
+import math
+
+# plot frequency spectrum of e2e model
+def plot_e2e_spectrum(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=320000,
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+    mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
+
+    a = np.fft.ifft(mul)
+    a2 = np.power(a, 2)
+    b = np.abs(np.fft.fft(a2))
+
+
+    plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+    # plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.plot(freq, b)
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+def plot_e2e_encoded_output(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.axhline(y=0, color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.ylim((-0.1, 0.1))
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+if __name__ == '__main__':
+    # plot_e2e_spectrum()
+    plot_e2e_encoded_output()
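
Both plotting helpers rely on a model that has already been exported by EndToEndAutoencoder.save_end_to_end(); when model_name is None, load_model falls back to the last directory listed under exports/. A minimal usage sketch, where the export folder name is hypothetical (save_end_to_end names folders with a UTC timestamp):

    from models.plots import plot_e2e_spectrum, plot_e2e_encoded_output

    plot_e2e_spectrum()                                    # model_name=None -> last entry of os.listdir("exports")
    plot_e2e_encoded_output(model_name="20210101-120000")  # hypothetical export folder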