Quellcode durchsuchen

iterative learning

Tharmetharan Balendran vor 4 Jahren
Ursprung
Commit
a75063d665
3 geänderte Dateien mit 229 neuen und 52 gelöschten Zeilen
  1. 14 8
      models/custom_layers.py
  2. 95 44
      models/end_to_end.py
  3. 120 0
      models/new_model.py

+ 14 - 8
models/custom_layers.py

@@ -16,7 +16,8 @@ class BitsToSymbols(layers.Layer):
 
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
-        return tf.one_hot(idx, self.cardinality)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((9, 32))(out)
 
 
 class SymbolsToBits(layers.Layer):
@@ -27,10 +28,10 @@ class SymbolsToBits(layers.Layer):
         lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
 
         # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
-        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32))
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
 
     def call(self, inputs, **kwargs):
-        return tf.matmul(self.all_syms, inputs)
+        return tf.matmul(inputs, self.all_syms)
 
 
 class ExtractCentralMessage(layers.Layer):
@@ -76,7 +77,8 @@ class DigitizationLayer(layers.Layer):
                  fs,
                  num_of_samples,
                  lpf_cutoff=32e9,
-                 q_stddev=0.1):
+                 sig_avg=0.5,
+                 enob=10):
         """
         This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
         artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
@@ -88,7 +90,9 @@ class DigitizationLayer(layers.Layer):
         """
         super(DigitizationLayer, self).__init__()
 
-        self.noise_layer = layers.GaussianNoise(q_stddev)
+        stddev = 3*(sig_avg**2)*(10**((-6.02*enob + 1.76)/10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
         freq = np.fft.fftfreq(num_of_samples, d=1/fs)
         temp = np.ones(freq.shape)
 
@@ -116,7 +120,8 @@ class OpticalChannel(layers.Layer):
                  fiber_length,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
-                 q_stddev=0.01):
+                 sig_avg=0.5,
+                 enob=10):
         """
         A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
         ADC/DAC as well as additive white gaussian noise in optical communication channels.
@@ -127,7 +132,7 @@ class OpticalChannel(layers.Layer):
         :param fiber_length: Length of fiber to model in km
         :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
         :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        :param sig_avg: Average signal amplitude
         """
         super(OpticalChannel, self).__init__()
 
@@ -135,7 +140,8 @@ class OpticalChannel(layers.Layer):
         self.digitization_layer = DigitizationLayer(fs=fs,
                                                     num_of_samples=num_of_samples,
                                                     lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
+                                                    sig_avg=sig_avg,
+                                                    enob=enob)
         self.flatten_layer = layers.Flatten()
 
         self.fs = fs

+ 95 - 44
models/end_to_end.py

@@ -3,12 +3,14 @@ import math
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
 from tensorflow.keras import backend as K
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
 import itertools
 
+
 class EndToEndAutoencoder(tf.keras.Model):
     def __init__(self,
                  cardinality,
@@ -53,27 +55,29 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Boolean identifying if bit mapping is to be learnt
         self.bit_mapping = bit_mapping
 
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
+
         # Layer configuration for the case when bit mapping is to be learnt
         if self.bit_mapping:
             encoding_layers = [
                 layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
                 BitsToSymbols(self.cardinality),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                 layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
             ]
             decoding_layers = [
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                 # layers.Dense(2 * self.cardinality),
                 # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid')
             ]
 
         # layer configuration for the case when only symbol mapping is to be learnt
@@ -81,21 +85,19 @@ class EndToEndAutoencoder(tf.keras.Model):
             encoding_layers = [
                 layers.Input(shape=(self.messages_per_block, self.cardinality)),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
             ]
             decoding_layers = [
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                layers.Dense(self.cardinality, activation='softmax')
             ]
 
         # Encoding Neural Network
@@ -130,7 +132,7 @@ class EndToEndAutoencoder(tf.keras.Model):
             out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
 
             if return_vals:
-                #TODO
+                return out_arr, out_arr, out_arr[:, mid_idx, :]
 
         else:
             rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
@@ -145,7 +147,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-2):
+    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
@@ -159,8 +161,23 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
+        # TODO: Investigate different optimizers (with different learning rates and other parameters)
+        # SGD
+        # RMSprop
+        # Adam
+        # Adadelta
+        # Adagrad
+        # Adamax
+        # Nadam
+        # Ftrl
+
+        if self.bit_mapping:
+            loss_fn = losses.BinaryCrossentropy()
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
@@ -168,33 +185,65 @@ class EndToEndAutoencoder(tf.keras.Model):
                      )
 
         history = self.fit(x=X_train,
-                           y=y_train,
-                           batch_size=batch_size,
-                           epochs=1,
-                           shuffle=True,
-                           validation_data=(X_test, y_test)
-                           )
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
 
         plt.plot(history.history['accuracy'])
         plt.plot(history.history['val_accuracy'])
         plt.show()
 
 
+    def test(self, num_of_blocks=1e3):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        print("SYMBOL ERROR RATE: {}".format(symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(bit_error_rate))
+
+        pass
+
     def view_encoder(self):
         '''
         A method that views the learnt encoder for each distint message. This is displayed as a plot with a subplot for
         each message/symbol.
         '''
 
-        # Generate inputs for encoder
-        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        idx = 0
-        for msg in messages:
-            msg[mid_idx, idx] = 1
-            idx += 1
+        if self.bit_mapping:
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
+            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx] = lst[idx]
+                idx += 1
+
+        else:
+            # Generate inputs for encoder
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx, idx] = 1
+                idx += 1
 
         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
@@ -218,7 +267,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
 
@@ -242,11 +291,12 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                 num_of_samples=self.messages_per_block * self.samples_per_symbol,
-                                q_stddev=0)
+                                sig_avg=0)
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
@@ -264,6 +314,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
@@ -276,16 +327,15 @@ class EndToEndAutoencoder(tf.keras.Model):
         outputs = self.decoder(rx)
         return outputs
 
+
 SAMPLING_FREQUENCY = 336e9
 CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 24
+SAMPLES_PER_SYMBOL = 32
 MESSAGES_PER_BLOCK = 9
 DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
-
+FIBER_LENGTH = 0
 
 if __name__ == '__main__':
-
     optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                      num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                      dispersion_factor=DISPERSION_FACTOR,
@@ -295,11 +345,12 @@ if __name__ == '__main__':
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=optical_channel,
-                                   bit_mapping=True)
+                                   bit_mapping=False)
 
-    ae_model.train(num_of_blocks=1e6)
-    # ae_model.view_encoder()
-    # ae_model.view_sample_block()
+    ae_model.train(num_of_blocks=1e5, epochs=5)
+    ae_model.test()
+    ae_model.view_encoder()
+    ae_model.view_sample_block()
     # ae_model.summary()
     ae_model.encoder.summary()
     ae_model.channel.summary()

+ 120 - 0
models/new_model.py

@@ -0,0 +1,120 @@
+import tensorflow as tf
+from tensorflow.keras import layers, losses
+from models.custom_layers import ExtractCentralMessage, OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.custom_layers import BitsToSymbols, SymbolsToBits
+import numpy as np
+import math
+
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             bit_mapping=False)
+
+    def call(self, inputs, training=None, mask=None):
+        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+
+        """
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out = rand_int
+
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        """
+
+        """
+        for _ in range(iters):
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+            self.e2e_model.test()
+
+            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+    def test(self, num_of_blocks=1e4):
+        pass
+
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 24
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+
+if __name__ == '__main__':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    # a , c = model.generate_random_inputs(num_of_blocks=1)
+    #
+    # a = tf.convert_to_tensor(a, dtype=tf.float32)
+    # b = model(a)
+
+    model.train(iters=1, num_of_blocks=1e4, epochs=1)
+    model.summary()