Tharmetharan Balendran 4 years ago
Parent
Commit
b8ff3d2e85
2 changed files with 112 additions and 31 deletions
  1. models/end_to_end.py (+93 −11)
  2. models/plots.py (+19 −20)

+ 93 - 11
models/end_to_end.py

@@ -148,7 +148,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         return enc_weights, dec_weights

     def encode_stream(self, x):
-        enc_weights, dec_weights = self.extract_weights()
+        enc_weights, _ = self.extract_weights()

         for i in range(len(enc_weights) // 2):
             x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
@@ -160,6 +160,19 @@ class EndToEndAutoencoder(tf.keras.Model):

         return x

+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
     def cost(self, y_true, y_pred):
         symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)

@@ -198,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):

         return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
 
@@ -217,6 +230,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         else:
             loss_fn = losses.CategoricalCrossentropy()

+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
+
         self.compile(optimizer=opt,
                      loss=loss_fn,
                      metrics=['accuracy'],
@@ -225,13 +240,18 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )

-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=epochs,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended.")

     def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
@@ -434,6 +454,52 @@ def load_model(model_name=None):
     return ae_model, params


+if __name__ == 'asd':  # deliberately never true: fiber-length BER sweep kept for reference but disabled
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
+
if __name__ == '__main__':

    params = {"fs": 336e9,
@@ -452,7 +518,7 @@ if __name__ == '__main__':

     force_training = False

-    model_save_name = ""
+    model_save_name = "20210317-124015"
     param_file_path = os.path.join("exports", model_save_name, "params.json")

     if os.path.isfile(param_file_path) and not force_training:
@@ -480,7 +546,23 @@ if __name__ == '__main__':
         ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
         ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
     else:
-        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.train(num_of_blocks=1e4)
         ae_model.save_end_to_end()

+    ae_model.view_encoder()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
+
     pass
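
The two stream helpers added here share one pattern: extract_weights() hands back the Keras weights as a flat [W0, b0, W1, b1, ...] list, and the helper replays them as plain NumPy affine layers, with ReLU between layers and, in decode_stream, softmax on the last layer. Below is a minimal self-contained sketch of that pattern with hypothetical layer sizes (the real sizes come from the trained model); tf.nn.softmax/tf.nn.relu are used because they accept NumPy inputs directly, equivalent to the tf.keras.activations calls in the diff.

    import numpy as np
    import tensorflow as tf

    # Hypothetical weights in the flat [W0, b0, W1, b1] layout that
    # extract_weights() returns; the sizes below are illustrative only.
    rng = np.random.default_rng(0)
    dec_weights = [rng.normal(size=(32, 64)), np.zeros(64),
                   rng.normal(size=(64, 32)), np.zeros(32)]

    x = rng.normal(size=(9, 32))  # one block of nine 32-dimensional vectors
    for i in range(len(dec_weights) // 2):
        x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
        if i == len(dec_weights) // 2 - 1:
            x = tf.nn.softmax(x).numpy()  # final decoder layer: symbol probabilities
        else:
            x = tf.nn.relu(x).numpy()     # hidden layers
    symbols = x.argmax(axis=-1)           # hard decision per symbol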

+ 19 - 20
models/plots.py

@@ -19,31 +19,29 @@ def plot_e2e_spectrum(model_name=None):
     out = enc.fit_transform(rand_int)

     # Encode the list of symbols using the trained encoder
-    a = ae_model.encode_stream(out).flatten()
+    enc_out = ae_model.encode_stream(out).flatten()

     # Pass the output of the encoder through the LPF
     lpf = DigitizationLayer(fs=params["fs"],
                             num_of_samples=320000,
-                            sig_avg=0)(a).numpy()
+                            sig_avg=0)(enc_out).numpy()

     # Plot the frequency spectrum of the signal
     freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
     mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))

     a = np.fft.ifft(mul)
-    a2 = np.power(a, 2)
-    b = np.abs(np.fft.fft(a2))
+    a2 = np.abs(np.power(a, 2))
+    b = np.fft.fft(a2)

-
-    plt.plot(freq, np.fft.fft(lpf), 'x')
-    plt.ylim((-500, 500))
-    plt.xlim((-5e10, 5e10))
-    plt.show()
-
-    # plt.plot(freq, np.fft.fft(lpf), 'x')
-    plt.plot(freq, b)
-    plt.ylim((-500, 500))
+    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
+    plt.title("Spectrum of Modulating Potential at Encoder")
+    plt.ylim((0, 500))
     plt.xlim((-5e10, 5e10))
+    plt.xlabel("Frequency / Hz")
+    plt.ylabel("Magnitude / au")
+    # plt.savefig('nn_encoder_spectrum.eps', format='eps')
+    # plt.savefig('nn_encoder_spectrum.png', format='png')
     plt.show()

 def plot_e2e_encoded_output(model_name=None):
@@ -76,16 +74,17 @@ def plot_e2e_encoded_output(model_name=None):

     for i in range(1, params["messages_per_block"]):
         plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
-    plt.axhline(y=0, color='black')
+
     plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
-    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
-    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
-    plt.ylim((-0.1, 0.1))
-    plt.xlim((t.min(), t.max()))
+    plt.plot(t, lpf_out.numpy().T, label='Modulating potential')
+    plt.plot(t, chan_out.numpy().flatten(), label='DD received signal')
     plt.title(str(val[0, :, 0]))
     plt.legend(loc='upper right')
+    plt.xlabel("Time / s")
+    plt.ylabel("Amplitude / V")
     plt.show()

+
 if __name__ == '__main__':
-    # plot_e2e_spectrum()
-    plot_e2e_encoded_output()
+    plot_e2e_spectrum()
+    # plot_e2e_encoded_output()
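
For reference, plot_e2e_spectrum models a dispersive fiber followed by square-law (direct) detection: mul is the chromatic-dispersion transfer function exp(0.5j * beta2 * L * (2*pi*f)^2), its inverse FFT is the dispersive impulse response, and the detector sees the magnitude-squared of the field before the final FFT. A standalone sketch of that pipeline, reusing parameter values that appear earlier in this commit (fs = 336e9, dispersion_factor = -21.7e-24, fiber_length = 50):

    import numpy as np

    fs = 336e9                      # sample rate from params
    n = 320000                      # num_of_samples used above
    beta2_L = -21.7e-24 * 50        # dispersion_factor * fiber_length

    freq = np.fft.fftfreq(n, d=1 / fs)
    H = np.exp(0.5j * beta2_L * np.power(2 * np.pi * freq, 2))
    h = np.fft.ifft(H)              # dispersive impulse response
    power = np.abs(np.power(h, 2))  # square-law detection: |field|^2
    spectrum = np.fft.fft(power)    # RF spectrum seen after detection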