Tharmetharan Balendran 4 years ago
parent
commit
0b990b8b42
3 files changed with 267 additions and 61 deletions
  1. .gitignore (+4 -1)
  2. models/end_to_end.py (+172 -60)
  3. models/plots.py (+91 -0)

+ 4 - 1
.gitignore

@@ -5,4 +5,7 @@ __pycache__
 
 # Environments
 venv/
-tests/local_test.py
+tests/local_test.py
+
+# Model export files
+models/exports

+ 172 - 60
models/end_to_end.py

@@ -1,5 +1,7 @@
+import json
 import math
-
+import os
+from datetime import datetime as dt
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
@@ -62,8 +64,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         leaky_relu_alpha = 0
         relu_clip_val = 1.0
 
-        # layer configuration
-        encoding_layers = [
+        # Encoding Neural Network
+        self.encoder = tf.keras.Sequential([
             layers.Input(shape=(self.messages_per_block, self.cardinality)),
             layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
             layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
@@ -72,24 +74,91 @@ class EndToEndAutoencoder(tf.keras.Model):
             layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
             # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
             # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-        ]
-        decoding_layers = [
+        ], name="encoding_model")
+
+        # Decoding Neural Network
+        self.decoder = tf.keras.Sequential([
             layers.Dense(2 * self.cardinality),
             layers.LeakyReLU(alpha=leaky_relu_alpha),
             layers.Dense(2 * self.cardinality),
             layers.LeakyReLU(alpha=leaky_relu_alpha),
             layers.Dense(self.cardinality, activation='softmax')
-        ]
+        ], name="decoding_model")
 
-        # Encoding Neural Network
-        self.encoder = tf.keras.Sequential([
-            *encoding_layers
-        ], name="encoding_model")
+    def save_end_to_end(self):
+        # extract all params and save
 
-        # Decoding Neural Network
-        self.decoder = tf.keras.Sequential([
-            *decoding_layers
-        ], name="decoding_model")
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder formatted using python variable instantiation syntax
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, dec_weights = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
 
     def cost(self, y_true, y_pred):
         symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
@@ -296,6 +365,12 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
+        # Plot the magnitude spectrum of the LPF output
+        a = np.abs(np.fft.fft(lpf_out.numpy())).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, a)
+        plt.show()
+
         # Time axis
         t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
@@ -314,7 +389,6 @@ class EndToEndAutoencoder(tf.keras.Model):
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
         plt.show()
-        pass
 
     def call(self, inputs, training=None, mask=None):
         tx = self.encoder(inputs)
@@ -323,52 +397,90 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
-DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
-FIBER_LENGTH_STDDEV = 5
+def load_model(model_name=None):
+    if model_name is None:
+        # Export directories are named with UTC timestamps, so the last sorted entry is the newest
+        models = sorted(os.listdir("exports"))
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
 
 if __name__ == '__main__':
 
-    stddevs = [0, 1, 5, 10]
-    legend = []
-
-    for s in stddevs:
-        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                         dispersion_factor=DISPERSION_FACTOR,
-                                         fiber_length=FIBER_LENGTH,
-                                         fiber_length_stddev=s,
-                                         lpf_cutoff=32e9,
-                                         rx_stddev=0.01,
-                                         sig_avg=0.5,
-                                         enob=10)
-
-        ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                       samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                       messages_per_block=MESSAGES_PER_BLOCK,
-                                       channel=optical_channel,
-                                       custom_loss_fn=True)
-
-        print(ae_model.snr)
-
-        ae_model.train(num_of_blocks=3e5, epochs=5)
-        ae_model.test(length_plot=True, plt_show=False)
-        # plt.legend(['{} +/- {}'.format(FIBER_LENGTH, s)])
-
-        legend.append('{} +/- {}'.format(FIBER_LENGTH, s))
-
-    plt.legend(legend)
-    plt.show()
-    plt.savefig('ber_vs_length.eps', format='eps')
-
-    # ae_model.view_encoder()
-    # ae_model.view_sample_block()
-    # # ae_model.summary()
-    # ae_model.encoder.summary()
-    # ae_model.channel.summary()
-    # ae_model.decoder.summary()
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    force_training = False
+
+    model_save_name = ""
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    if os.path.isfile(param_file_path) and not force_training:
+        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    else:
+        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.save_end_to_end()
+
     pass
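
Note: the enc_weights.py / dec_weights.py files written by save_end_to_end() store the layer weights and biases as plain Python literals, so the trained encoder can be replayed without TensorFlow. A minimal sketch of consuming such an export (not part of the commit; the helper names and example path are illustrative, and the forward pass mirrors encode_stream() above):

import importlib.util

import numpy as np


def load_exported_encoder(path):
    # Load the enc_weights / enc_bias literals from the generated file
    spec = importlib.util.spec_from_file_location("exported_enc", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    weights = [np.array(w) for w in module.enc_weights]
    biases = [np.array(b) for b in module.enc_bias]
    return weights, biases


def encode_with_numpy(x, weights, biases):
    # Same forward pass as encode_stream(): ReLU (LeakyReLU with alpha=0)
    # on the hidden layers, sigmoid on the output layer
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = np.matmul(x, w) + b
        if i == len(weights) - 1:
            x = 1.0 / (1.0 + np.exp(-x))  # sigmoid
        else:
            x = np.maximum(x, 0.0)        # relu
    return x


# Example usage (the timestamped directory name is hypothetical):
# w, b = load_exported_encoder("exports/20210101-120000/enc_weights.py")
# tx = encode_with_numpy(one_hot_messages, w, b)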

+ 91 - 0
models/plots.py

@@ -0,0 +1,91 @@
+import math
+
+import numpy as np
+from matplotlib import pyplot as plt
+from sklearn.preprocessing import OneHotEncoder
+from tensorflow.keras import layers
+
+from custom_layers import DigitizationLayer, OpticalChannel
+from end_to_end import load_model
+
+# plot frequency spectrum of e2e model
+def plot_e2e_spectrum(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=320000,
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+    mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
+
+    a = np.fft.ifft(mul)
+    a2 = np.power(a, 2)
+    b = np.abs(np.fft.fft(a2))
+
+    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+    # plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.plot(freq, b)
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+def plot_e2e_encoded_output(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.axhline(y=0, color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.ylim((-0.1, 0.1))
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+if __name__ == '__main__':
+    # plot_e2e_spectrum()
+    plot_e2e_encoded_output()
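
Note: the mul term built in plot_e2e_spectrum() is the usual chromatic-dispersion phase factor H(f) = exp(0.5j * beta2 * L * (2*pi*f)^2). A minimal sketch (not part of the commit; the function name is illustrative and the defaults reuse the values from the params dict in end_to_end.py) of applying that factor to a time-domain signal and square-law detecting it:

import numpy as np


def apply_dispersion(signal, fs=336e9, beta2=-21.7e-24, length=50):
    # Dispersion acts as a quadratic phase rotation of the signal spectrum
    freq = np.fft.fftfreq(signal.shape[-1], d=1 / fs)
    h = np.exp(0.5j * beta2 * length * np.power(2 * np.pi * freq, 2))
    field_rx = np.fft.ifft(np.fft.fft(signal) * h)
    # Assume square-law (intensity) detection at the receiver
    return np.abs(field_rx) ** 2


# Example usage (the input signal here is hypothetical):
# intensity = apply_dispersion(np.random.rand(320000))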