Kaynağa Gözat

ber vs length plot

Tharmetharan Balendran 4 yıl önce
ebeveyn
işleme
16045ea3e0
3 değiştirilmiş dosya ile 174 ekleme ve 36 silme
  1. 17 10
      models/custom_layers.py
  2. 48 11
      models/end_to_end.py
  3. 109 15
      models/new_model.py

+ 17 - 10
models/custom_layers.py

@@ -6,10 +6,11 @@ import itertools
 
 
 class BitsToSymbols(layers.Layer):
-    def __init__(self, cardinality):
+    def __init__(self, cardinality, messages_per_block):
         super(BitsToSymbols, self).__init__()
 
         self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
 
         n = int(math.log(self.cardinality, 2))
         self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
@@ -17,7 +18,7 @@ class BitsToSymbols(layers.Layer):
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
         out = tf.one_hot(idx, self.cardinality)
-        return layers.Reshape((9, 32))(out)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
 
 
 class SymbolsToBits(layers.Layer):
@@ -136,19 +137,25 @@ class OpticalChannel(layers.Layer):
         """
         super(OpticalChannel, self).__init__()
 
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = fiber_length
+        self.lpf_cutoff = lpf_cutoff
         self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
 
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    sig_avg=sig_avg,
-                                                    enob=enob)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
         self.flatten_layer = layers.Flatten()
 
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex64)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
 
     def call(self, inputs, **kwargs):
         # DAC LPF and noise

+ 48 - 11
models/end_to_end.py

@@ -58,7 +58,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         # other parameters/metrics
         self.symbol_error_rate = None
         self.bit_error_rate = None
-        self.snr = 20 * math.log(0.5/channel.rx_stddev, 10)
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
 
         # Model Hyper-parameters
         leaky_relu_alpha = 0
@@ -190,14 +190,14 @@ class EndToEndAutoencoder(tf.keras.Model):
                      )
 
         history = self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=epochs,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
-
-    def test(self, num_of_blocks=1e4):
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+    def test(self, num_of_blocks=1e4, length_plot=False):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -207,13 +207,50 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
 
-        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
-
         bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
         bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
 
         self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
 
+        if (length_plot):
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test channel (variable length)")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+
+                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            plt.show()
+
         print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
         print("BIT ERROR RATE: {}".format(self.bit_error_rate))
 

+ 109 - 15
models/new_model.py

@@ -8,6 +8,7 @@ import math
 
 from matplotlib import pyplot as plt
 
+
 class BitMappingModel(tf.keras.Model):
     def __init__(self,
                  cardinality,
@@ -38,7 +39,7 @@ class BitMappingModel(tf.keras.Model):
         self.symbol_error_rate = []
 
     def call(self, inputs, training=None, mask=None):
-        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
         x2 = self.e2e_model(x1)
         out = SymbolsToBits(self.cardinality)(x2)
         return out
@@ -71,7 +72,8 @@ class BitMappingModel(tf.keras.Model):
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     # loss=losses.BinaryCrossentropy(),
+                     loss=losses.MeanSquaredError(),
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
@@ -87,7 +89,9 @@ class BitMappingModel(tf.keras.Model):
                  )
 
     def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
-        for _ in range(iters):
+        for i in range(int(0.5*iters)):
+            print("Loop {}/{}".format(i, iters))
+
             self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
 
             self.e2e_model.test()
@@ -104,6 +108,7 @@ class BitMappingModel(tf.keras.Model):
 
             self.compile(optimizer=opt,
                          loss=losses.BinaryCrossentropy(),
+                         # loss=losses.MeanSquaredError(),
                          metrics=['accuracy'],
                          loss_weights=None,
                          weighted_metrics=None,
@@ -122,16 +127,64 @@ class BitMappingModel(tf.keras.Model):
             self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
             self.bit_error_rate.append(self.e2e_model.bit_error_rate)
 
+        for i in range(int(0.5*iters)):
+
+            X_train, y_train = self.generate_random_ inputs(int(1e5 * train_size))
+            X_test, y_test = self.generate_random_inputs(int(1e5 * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         # loss=losses.MeanSquaredError(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+
 SAMPLING_FREQUENCY = 336e9
-CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 32
-MESSAGES_PER_BLOCK = 9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
 DISPERSION_FACTOR = -21.7 * 1e-24
 FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'asd':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
 
 if __name__ == '__main__':
 
-    distances = [0, 10, 20, 30, 40, 50, 60]
+    distances = [50]
     ser = []
     ber = []
 
@@ -144,7 +197,9 @@ if __name__ == '__main__':
         optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                          num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                          dispersion_factor=DISPERSION_FACTOR,
-                                         fiber_length=d)
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
 
         model = BitMappingModel(cardinality=CARDINALITY,
                                 samples_per_symbol=SAMPLES_PER_SYMBOL,
@@ -156,21 +211,60 @@ if __name__ == '__main__':
         elif snr != model.e2e_model.snr:
             print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
 
-        # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
 
         model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
 
+        model.e2e_model.test(length_plot=True)
+
         ber.append(model.bit_error_rate[-1])
         ser.append(model.symbol_error_rate[-1])
 
-        plt.plot(model.bit_error_rate, label='BER')
-        plt.plot(model.symbol_error_rate, label='SER')
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        bit_mapping=False)
+
+        ber1 = []
+        ser1 = []
+
+        for i in range(int(len(model.bit_error_rate))):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
         plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
         plt.legend()
         plt.show()
         # model.summary()
 
-    plt.plot(ber, label='BER')
-    plt.plot(ser, label='SER')
-    plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
-    plt.legend(ber)
+    # plt.plot(ber, label='BER')
+    # plt.plot(ser, label='SER')
+    # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    # plt.legend(ber)