ソースを参照

minor additions

Tharmetharan Balendran 4 年 前
コミット
b8ff3d2e85
2 ファイル変更112 行追加31 行削除
  1. 93 11
      models/end_to_end.py
  2. 19 20
      models/plots.py

+ 93 - 11
models/end_to_end.py

@@ -148,7 +148,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         return enc_weights, dec_weights
 
     def encode_stream(self, x):
-        enc_weights, dec_weights = self.extract_weights()
+        enc_weights, _ = self.extract_weights()
 
         for i in range(len(enc_weights) // 2):
             x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
@@ -160,6 +160,19 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         return x
 
+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
     def cost(self, y_true, y_pred):
         symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
 
@@ -198,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
@@ -217,6 +230,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         else:
             loss_fn = losses.CategoricalCrossentropy()
 
+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
+
         self.compile(optimizer=opt,
                      loss=loss_fn,
                      metrics=['accuracy'],
@@ -225,13 +240,18 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )
 
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=epochs,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended")
 
     def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
@@ -434,6 +454,52 @@ def load_model(model_name=None):
     return ae_model, params
 
 
+if __name__ == 'asd':
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
+
 if __name__ == '__main__':
 
     params = {"fs": 336e9,
@@ -452,7 +518,7 @@ if __name__ == '__main__':
 
     force_training = False
 
-    model_save_name = ""
+    model_save_name = "20210317-124015"
     param_file_path = os.path.join("exports", model_save_name, "params.json")
 
     if os.path.isfile(param_file_path) and not force_training:
@@ -480,7 +546,23 @@ if __name__ == '__main__':
         ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
         ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
     else:
-        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.train(num_of_blocks=1e4)
         ae_model.save_end_to_end()
 
+    ae_model.view_encoder()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
+
     pass

+ 19 - 20
models/plots.py

@@ -19,31 +19,29 @@ def plot_e2e_spectrum(model_name=None):
     out = enc.fit_transform(rand_int)
 
     # Encode the list of symbols using the trained encoder
-    a = ae_model.encode_stream(out).flatten()
+    enc = ae_model.encode_stream(out).flatten()
 
     # Pass the output of the encoder through LPF
     lpf = DigitizationLayer(fs=params["fs"],
                             num_of_samples=320000,
-                            sig_avg=0)(a).numpy()
+                            sig_avg=0)(enc).numpy()
 
     # Plot the frequency spectrum of the signal
     freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
     mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
 
     a = np.fft.ifft(mul)
-    a2 = np.power(a, 2)
-    b = np.abs(np.fft.fft(a2))
+    a2 = np.abs(np.power(a, 2))
+    b = np.fft.fft(a2)
 
-
-    plt.plot(freq, np.fft.fft(lpf), 'x')
-    plt.ylim((-500, 500))
-    plt.xlim((-5e10, 5e10))
-    plt.show()
-
-    # plt.plot(freq, np.fft.fft(lpf), 'x')
-    plt.plot(freq, b)
-    plt.ylim((-500, 500))
+    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
+    plt.title("Spectrum of Modulating Potential at Encoder")
+    plt.ylim((0, 500))
     plt.xlim((-5e10, 5e10))
+    plt.xlabel("Freuquency / Hz")
+    plt.ylabel("Magnitude / au")
+    # plt.savefig('nn_encoder_spectrum.eps', format='eps')
+    # plt.savefig('nn_encoder_spectrum.png', format='png')
     plt.show()
 
 def plot_e2e_encoded_output(model_name=None):
@@ -76,16 +74,17 @@ def plot_e2e_encoded_output(model_name=None):
 
     for i in range(1, params["messages_per_block"]):
         plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
-    plt.axhline(y=0, color='black')
+
     plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
-    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
-    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
-    plt.ylim((-0.1, 0.1))
-    plt.xlim((t.min(), t.max()))
+    plt.plot(t, lpf_out.numpy().T, label='Modulating potential')
+    plt.plot(t, chan_out.numpy().flatten(), label='DD received signal')
     plt.title(str(val[0, :, 0]))
     plt.legend(loc='upper right')
+    plt.xlabel("Time / s")
+    plt.ylabel("Amplitude / V")
     plt.show()
 
+
 if __name__ == '__main__':
-    # plot_e2e_spectrum()
-    plot_e2e_encoded_output()
+    plot_e2e_spectrum()
+    # plot_e2e_encoded_output()