Tharmetharan Balendran 4 years ago
Parent
Commit
2d60347040
1 changed file with 93 additions and 11 deletions

+ 93 - 11
models/end_to_end.py

@@ -148,7 +148,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         return enc_weights, dec_weights

     def encode_stream(self, x):
-        enc_weights, dec_weights = self.extract_weights()
+        enc_weights, _ = self.extract_weights()

         for i in range(len(enc_weights) // 2):
             x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
@@ -160,6 +160,19 @@ class EndToEndAutoencoder(tf.keras.Model):

         return x

+    def decode_stream(self, x):
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
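
The new decode_stream mirrors encode_stream: it assumes extract_weights() returns the decoder's Dense parameters as a flat [kernel, bias, kernel, bias, ...] list, so each loop iteration applies one layer with a plain NumPy matmul, with ReLU on hidden layers and softmax on the last. A minimal standalone sketch of the same loop with toy weights (the shapes and layer count are illustrative, not taken from this commit):

import numpy as np
import tensorflow as tf

# Toy decoder weights in the [kernel, bias, kernel, bias, ...] layout
# that decode_stream expects; real values come from extract_weights().
dec_weights = [np.random.randn(16, 64).astype("float32"), np.zeros(64, "float32"),
               np.random.randn(64, 32).astype("float32"), np.zeros(32, "float32")]

x = np.random.randn(4, 16).astype("float32")
for i in range(len(dec_weights) // 2):
    x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
    if i == len(dec_weights) // 2 - 1:
        # output layer: softmax over the symbol alphabet
        x = tf.keras.activations.softmax(tf.convert_to_tensor(x)).numpy()
    else:
        # hidden layer: ReLU
        x = tf.keras.activations.relu(x).numpy()

print(x.shape, x.sum(axis=-1))  # (4, 32); each row of softmax outputs sums to ~1.0
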
     def cost(self, y_true, y_pred):
         symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)

@@ -198,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):

         return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.

@@ -217,6 +230,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         else:
             loss_fn = losses.CategoricalCrossentropy()

+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
+
         self.compile(optimizer=opt,
                      loss=loss_fn,
                      metrics=['accuracy'],
@@ -225,13 +240,18 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )

-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=epochs,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
+                  "Setting a higher epoch number and retraining is recommended.")
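
Two related changes here: the default epochs rises from 1 to 50, and an EarlyStopping(monitor='loss', patience=3) callback now bounds training, so fit stops once the training loss fails to improve for three consecutive epochs. Because history.history['loss'] records one entry per epoch actually run, a length equal to epochs means early stopping never fired, which is what the warning above checks. A self-contained illustration of that heuristic on a throwaway model with random data (not from this commit):

import numpy as np
import tensorflow as tf

# Tiny stand-in model and random data, just to exercise the callback.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer='adam', loss='mse')

X = np.random.randn(256, 4).astype("float32")
y = np.random.randn(256, 1).astype("float32")

early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
history = model.fit(X, y, epochs=50, callbacks=[early_stop], verbose=0)

# Fewer recorded losses than requested epochs means EarlyStopping intervened.
print(len(history.history['loss']) < 50)
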

     def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
@@ -434,6 +454,52 @@ def load_model(model_name=None):
     return ae_model, params


+# Intentionally disabled length-sweep experiment: 'asd' never equals '__main__'.
+if __name__ == 'asd':
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
+
 if __name__ == '__main__':

     params = {"fs": 336e9,
@@ -452,7 +518,7 @@ if __name__ == '__main__':

     force_training = False

-    model_save_name = ""
+    model_save_name = "20210317-124015"
     param_file_path = os.path.join("exports", model_save_name, "params.json")

     if os.path.isfile(param_file_path) and not force_training:
@@ -480,7 +546,23 @@ if __name__ == '__main__':
         ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
         ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
     else:
-        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.train(num_of_blocks=1e4)
         ae_model.save_end_to_end()

+    ae_model.view_encoder()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
+
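
The commented-out probe above one-hot encodes nine symbols from the 32-symbol alphabet and runs them through the trained model; as written it would also need the OneHotEncoder import, and plt.plot cannot take the raw (1, 9, 32) output directly. A dependency-free sketch of the same probe, assuming ae_model returns a (1, 9, 32) tensor of per-position softmax outputs:

import numpy as np
import matplotlib.pyplot as plt

# Nine test symbols from the 32-symbol alphabet, one-hot encoded without sklearn.
symbols = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2])
inp_oh = np.eye(32, dtype="float32")[symbols]

out = ae_model(inp_oh.reshape(1, 9, 32)).numpy()

plt.plot(out[0].T)  # one curve per message position
plt.xlabel("Symbol index")
plt.ylabel("Decoder output")
plt.show()
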
     pass