Kaynağa Gözat

variable length training added

Tharmetharan Balendran 4 yıl önce
ebeveyn
işleme
8224f88367
2 değiştirilmiş dosya ile 70 ekleme ve 26 silme
  1. 11 3
      models/custom_layers.py
  2. 59 23
      models/end_to_end.py

+ 11 - 3
models/custom_layers.py

@@ -119,6 +119,7 @@ class OpticalChannel(layers.Layer):
                  num_of_samples,
                  dispersion_factor,
                  fiber_length,
+                 fiber_length_stddev=0,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
                  sig_avg=0.5,
@@ -140,13 +141,15 @@ class OpticalChannel(layers.Layer):
         self.fs = fs
         self.num_of_samples = num_of_samples
         self.dispersion_factor = dispersion_factor
-        self.fiber_length = fiber_length
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
         self.lpf_cutoff = lpf_cutoff
         self.rx_stddev = rx_stddev
         self.sig_avg = sig_avg
         self.enob = enob
 
         self.noise_layer = layers.GaussianNoise(self.rx_stddev)
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
         self.digitization_layer = DigitizationLayer(fs=self.fs,
                                                     num_of_samples=self.num_of_samples,
                                                     lpf_cutoff=self.lpf_cutoff,
@@ -155,7 +158,7 @@ class OpticalChannel(layers.Layer):
         self.flatten_layer = layers.Flatten()
 
         self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
-        self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
 
     def call(self, inputs, **kwargs):
         # DAC LPF and noise
@@ -164,7 +167,12 @@ class OpticalChannel(layers.Layer):
         # Chromatic Dispersion
         complex_val = tf.cast(dac_out, dtype=tf.complex64)
         val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
+
+        len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
         disp_t = tf.signal.ifft(disp_f)
 
         # Squared-Law Detection

+ 59 - 23
models/end_to_end.py

@@ -115,6 +115,18 @@ class EndToEndAutoencoder(tf.keras.Model):
             *decoding_layers
         ], name="decoding_model")
 
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1
+
+        return symbol_cost + a * bit_cost
+
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
         A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
@@ -179,7 +191,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         if self.bit_mapping:
             loss_fn = losses.BinaryCrossentropy()
         else:
-            loss_fn = losses.CategoricalCrossentropy()
+            # loss_fn = losses.CategoricalCrossentropy()
+            loss_fn = self.cost
 
         self.compile(optimizer=opt,
                      loss=loss_fn,
@@ -197,7 +210,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                            validation_data=(X_test, y_test)
                            )
 
-    def test(self, num_of_blocks=1e4, length_plot=False):
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -249,7 +262,8 @@ class EndToEndAutoencoder(tf.keras.Model):
 
             plt.plot(lengths, ber_l)
             plt.yscale('log')
-            plt.show()
+            if plt_show:
+                plt.show()
 
         print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
         print("BIT ERROR RATE: {}".format(self.bit_error_rate))
@@ -370,26 +384,48 @@ CARDINALITY = 32
 SAMPLES_PER_SYMBOL = 32
 MESSAGES_PER_BLOCK = 9
 DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 0
+FIBER_LENGTH = 50
+FIBER_LENGTH_STDDEV = 5
+
 
 if __name__ == '__main__':
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
-
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel,
-                                   bit_mapping=False)
-
-    ae_model.train(num_of_blocks=1e5, epochs=5)
-    ae_model.test()
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    # ae_model.summary()
-    ae_model.encoder.summary()
-    ae_model.channel.summary()
-    ae_model.decoder.summary()
+
+    stddevs = [0, 1, 5, 10]
+    legend = []
+
+    for s in stddevs:
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=FIBER_LENGTH,
+                                         fiber_length_stddev=s,
+                                         lpf_cutoff=32e9,
+                                         rx_stddev=0.01,
+                                         sig_avg=0.5,
+                                         enob=10)
+
+        ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                       samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                       messages_per_block=MESSAGES_PER_BLOCK,
+                                       channel=optical_channel,
+                                       bit_mapping=False)
+
+        print(ae_model.snr)
+
+        ae_model.train(num_of_blocks=3e5, epochs=5)
+        ae_model.test(length_plot=True, plt_show=False)
+        # plt.legend(['{} +/- {}'.format(FIBER_LENGTH, s)])
+
+        legend.append('{} +/- {}'.format(FIBER_LENGTH, s))
+
+    plt.legend(legend)
+    plt.show()
+    plt.savefig('ber_vs_length.eps', format='eps')
+
+    # ae_model.view_encoder()
+    # ae_model.view_sample_block()
+    # # ae_model.summary()
+    # ae_model.encoder.summary()
+    # ae_model.channel.summary()
+    # ae_model.decoder.summary()
     pass