Browse Source

AE training and testing

Tharmetharan Balendran 5 years ago
parent
commit
25393c2ae8
1 changed file with 122 additions and 46 deletions
  1. 122 46
      models/end_to_end.py

+ 122 - 46
models/end_to_end.py

@@ -1,15 +1,14 @@
-# End to end Autoencoder completely implemented in Keras/TF
-
 import keras
-import math
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.preprocessing import OneHotEncoder
+
 
-from keras import layers
+from keras import layers, losses
 
 
-class ExtractCentralMessage(keras.layers.Layer):
+class ExtractCentralMessage(layers.Layer):
     def __init__(self, neighbouring_blocks, samples_per_symbol):
         super(ExtractCentralMessage, self).__init__()
 
@@ -25,37 +24,56 @@ class ExtractCentralMessage(keras.layers.Layer):
         return tf.matmul(inputs, self.w)
 
 
-class AwgnChannel(keras.layers.Layer):
+class AwgnChannel(layers.Layer):
     def __init__(self, stddev=0.1):
         super(AwgnChannel, self).__init__()
-        self.stddev = stddev
         self.noise_layer = layers.GaussianNoise(stddev)
+        self.flatten_layer = layers.Flatten()
 
     def call(self, inputs):
-        serialized = layers.Flatten(inputs)
+        serialized = self.flatten_layer(inputs)
         return self.noise_layer.call(serialized, training=True)
 
 
-class OpticalChannel(keras.layers.Layer):
+class DigitizationLayer(layers.Layer):
+    def __init__(self, stddev=0.1):
+        super(DigitizationLayer, self).__init__()
+        self.noise_layer = layers.GaussianNoise(stddev)
+
+    def call(self, inputs):
+        # TODO:
+        #  Low-pass filter (convolution with filter h(t))
+        return self.noise_layer.call(inputs, training=True)
+
+
+class OpticalChannel(layers.Layer):
     def __init__(self, fs, stddev=0.1):
         super(OpticalChannel, self).__init__()
         self.noise_layer = layers.GaussianNoise(stddev)
+        self.digitization_layer = DigitizationLayer()
+        self.flatten_layer = layers.Flatten()
         self.fs = fs
 
     def call(self, inputs):
-        # TODO:
-        #  Low-pass filter & digitization noise for DAC and ADC (probably as an external layer called twice)
-
         # Serializing outputs of all blocks
-        serialized = layers.Flatten(inputs)
+        serialized = self.flatten_layer(inputs)
+
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(serialized)
+
+        # TODO:
+        #  Chromatic Dispersion (fft -> phase shift -> ifft)
 
         # Squared-Law Detection
-        squared = tf.square(tf.abs(serialized))
+        pd_out = tf.square(tf.abs(dac_out))
 
-        # Adding gaussian noise
-        noisy = self.noise_layer.call(squared, training=True)
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer.call(pd_out, training=True)
 
-        return noisy
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -64,44 +82,92 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  neighbouring_blocks,
                  oversampling,
-                 channel_name='awgn'):
-
-        # Number of leading/following messages
+                 channel):
+        super(EndToEndAutoencoder, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+        # Labelled N in paper
         if neighbouring_blocks % 2 == 0:
             neighbouring_blocks += 1
+        self.neighbouring_blocks = neighbouring_blocks
         # Oversampling rate
-        oversampling = int(oversampling)
-        # Transmission Channel : ['awgn', 'optical']
-        channel_name = channel_name.strip().lower()
+        self.oversampling = int(oversampling)
+        # Channel Model Layer
+        if isinstance(channel, layers.Layer):
+            self.channel = channel
+        else:
+            raise TypeError("Channel must be a subclass of keras.layers.layer!")
 
+        # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            layers.Input(shape=(neighbouring_blocks, cardinality)),
-            layers.Dense(2 * cardinality, activation='relu'),
-            layers.Dense(2 * cardinality, activation='relu'),
-            layers.Dense(samples_per_symbol),
+            layers.Input(shape=(self.neighbouring_blocks, self.cardinality)),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(self.samples_per_symbol),
             layers.ReLU(max_value=1.0)
         ])
 
-        # Channel Model Layer
-        if channel_name == 'optical':
-            self.channel = OpticalChannel()
-        elif channel_name == 'awgn':
-            self.channel = AwgnChannel()
-        else:
-            raise TypeError("{} is not an accepted Channel Model.".format(channel_name))
-
-        self.encoder = tf.keras.Sequential([
-            ExtractCentralMessage(neighbouring_blocks, samples_per_symbol),
-            layers.Dense(samples_per_symbol, activation='relu'),
-            layers.Dense(2 * cardinality, activation='relu'),
-            layers.Dense(2 * cardinality, activation='relu'),
-            layers.Dense(cardinality, activation='softmax')
+        # Decoding Neural Network
+        self.decoder = tf.keras.Sequential([
+            ExtractCentralMessage(self.neighbouring_blocks, self.samples_per_symbol),
+            layers.Dense(self.samples_per_symbol, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(self.cardinality, activation='softmax')
         ])
 
+    def generate_random_inputs(self, num_of_blocks):
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.neighbouring_blocks, 1))
+
+        # cat = np.reshape(np.arange(self.cardinality), (1, -1))
+        enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
+
+        out = enc.fit_transform(rand_int)
+        out_arr = np.reshape(out, (num_of_blocks, self.neighbouring_blocks, self.cardinality))
+
+        mid_idx = int((self.neighbouring_blocks-1)/2)
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+
+    def train(self, num_of_blocks=1e6, train_size=0.8):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+
+        self.compile(optimizer='adam',
+                     loss=losses.BinaryCrossentropy(),
+                     metrics=None,
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=None
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=None,
+                 epochs=1,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+
+
     def view_encoder(self):
-        # TODO
-        #  Visualize encoder
-        pass
+        messages = np.zeros((self.cardinality, self.neighbouring_blocks, self.cardinality))
+
+        mid_idx = int((self.neighbouring_blocks-1)/2)
+
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1
+
+        encoded = self.encoder(messages)
+        return messages, encoded[:, mid_idx, :]
+
 
     def call(self, x):
         tx = self.encoder(x)
@@ -111,6 +177,16 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
 if __name__ == '__main__':
-    # TODO
-    #  training/testing of autoencoder
+    tx_channel = AwgnChannel(stddev=0.1)
+
+    model = EndToEndAutoencoder(cardinality=8,
+                                samples_per_symbol=10,
+                                neighbouring_blocks=5,
+                                oversampling=4,
+                                channel=tx_channel)
+
+    model.train()
+
+    model.view_encoder()
+
     pass