Bläddra i källkod

iterative model training

Tharmetharan Balendran 4 år sedan
förälder
incheckning
9f39857933
3 ändrade filer med 93 tillägg och 35 borttagningar
  1. 3 1
      models/custom_layers.py
  2. 10 10
      models/end_to_end.py
  3. 80 24
      models/new_model.py

+ 3 - 1
models/custom_layers.py

@@ -136,7 +136,9 @@ class OpticalChannel(layers.Layer):
         """
         super(OpticalChannel, self).__init__()
 
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
+        self.rx_stddev = rx_stddev
+
+        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
         self.digitization_layer = DigitizationLayer(fs=fs,
                                                     num_of_samples=num_of_samples,
                                                     lpf_cutoff=lpf_cutoff,

+ 10 - 10
models/end_to_end.py

@@ -55,6 +55,11 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Boolean identifying if bit mapping is to be learnt
         self.bit_mapping = bit_mapping
 
+        # other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5/channel.rx_stddev, 10)
+
         # Model Hyper-parameters
         leaky_relu_alpha = 0
         relu_clip_val = 1.0
@@ -192,12 +197,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  validation_data=(X_test, y_test)
                  )
 
-        plt.plot(history.history['accuracy'])
-        plt.plot(history.history['val_accuracy'])
-        plt.show()
-
-
-    def test(self, num_of_blocks=1e3):
+    def test(self, num_of_blocks=1e4):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -205,17 +205,17 @@ class EndToEndAutoencoder(tf.keras.Model):
         y_pred = tf.argmax(y_out, axis=1)
         y_true = tf.argmax(y_test, axis=1)
 
-        symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
 
         lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
 
         bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
         bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
 
-        bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
 
-        print("SYMBOL ERROR RATE: {}".format(symbol_error_rate))
-        print("BIT ERROR RATE: {}".format(bit_error_rate))
+        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
 
         pass
 

+ 80 - 24
models/new_model.py

@@ -6,6 +6,7 @@ from models.custom_layers import BitsToSymbols, SymbolsToBits
 import numpy as np
 import math
 
+from matplotlib import pyplot as plt
 
 class BitMappingModel(tf.keras.Model):
     def __init__(self,
@@ -33,6 +34,9 @@ class BitMappingModel(tf.keras.Model):
                                              channel=channel,
                                              bit_mapping=False)
 
+        self.bit_error_rate = []
+        self.symbol_error_rate = []
+
     def call(self, inputs, training=None, mask=None):
         x1 = BitsToSymbols(self.cardinality)(inputs)
         x2 = self.e2e_model(x1)
@@ -57,13 +61,38 @@ class BitMappingModel(tf.keras.Model):
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
-        """
+    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
 
-        """
+        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.BinaryCrossentropy(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
         for _ in range(iters):
             self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+
             self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
 
             X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
             X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
@@ -89,32 +118,59 @@ class BitMappingModel(tf.keras.Model):
                      validation_data=(X_test, y_test)
                      )
 
-    def test(self, num_of_blocks=1e4):
-        pass
-
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
 
 SAMPLING_FREQUENCY = 336e9
 CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 24
+SAMPLES_PER_SYMBOL = 32
 MESSAGES_PER_BLOCK = 9
 DISPERSION_FACTOR = -21.7 * 1e-24
 FIBER_LENGTH = 50
 
 if __name__ == '__main__':
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
-
-    model = BitMappingModel(cardinality=CARDINALITY,
-                            samples_per_symbol=SAMPLES_PER_SYMBOL,
-                            messages_per_block=MESSAGES_PER_BLOCK,
-                            channel=optical_channel)
-
-    # a , c = model.generate_random_inputs(num_of_blocks=1)
-    #
-    # a = tf.convert_to_tensor(a, dtype=tf.float32)
-    # b = model(a)
-
-    model.train(iters=1, num_of_blocks=1e4, epochs=1)
-    model.summary()
+
+    distances = [0, 10, 20, 30, 40, 50, 60]
+    ser = []
+    ber = []
+
+    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
+    bit_rate = math.log(CARDINALITY, 2) * baud_rate
+    snr = None
+
+    for d in distances:
+
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=d)
+
+        model = BitMappingModel(cardinality=CARDINALITY,
+                                samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                messages_per_block=MESSAGES_PER_BLOCK,
+                                channel=optical_channel)
+
+        if snr is None:
+            snr = model.e2e_model.snr
+        elif snr != model.e2e_model.snr:
+            print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
+
+        # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+
+        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
+
+        ber.append(model.bit_error_rate[-1])
+        ser.append(model.symbol_error_rate[-1])
+
+        plt.plot(model.bit_error_rate, label='BER')
+        plt.plot(model.symbol_error_rate, label='SER')
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.legend()
+        plt.show()
+        # model.summary()
+
+    plt.plot(ber, label='BER')
+    plt.plot(ser, label='SER')
+    plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    plt.legend(ber)