Przeglądaj źródła

Working implementation of end-to-end AE

Tharmetharan Balendran 5 lat temu
rodzic
commit
be6243754e
1 zmienionych plików z 224 dodań i 72 usunięć
  1. 224 72
      models/end_to_end.py

+ 224 - 72
models/end_to_end.py

@@ -1,74 +1,129 @@
+import math
+
 import keras
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from matplotlib import collections as matcoll
 from sklearn.preprocessing import OneHotEncoder
-
-
 from keras import layers, losses
 
 
 class ExtractCentralMessage(layers.Layer):
-    def __init__(self, neighbouring_blocks, samples_per_symbol):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
         super(ExtractCentralMessage, self).__init__()
 
-        temp_w = np.zeros((neighbouring_blocks * samples_per_symbol, samples_per_symbol))
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
         i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
-        end = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
         temp_w[begin:end, :] = i
 
         self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
 
-    def call(self, inputs):
+    def call(self, inputs, **kwargs):
         return tf.matmul(inputs, self.w)
 
 
 class AwgnChannel(layers.Layer):
-    def __init__(self, stddev=0.1):
+    def __init__(self, rx_stddev=0.1):
+        """
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        """
         super(AwgnChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(stddev)
-        self.flatten_layer = layers.Flatten()
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
 
-    def call(self, inputs):
-        serialized = self.flatten_layer(inputs)
-        return self.noise_layer.call(serialized, training=True)
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
 
 
 class DigitizationLayer(layers.Layer):
-    def __init__(self, stddev=0.1):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 q_stddev=0.1):
+        """
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
         super(DigitizationLayer, self).__init__()
-        self.noise_layer = layers.GaussianNoise(stddev)
 
-    def call(self, inputs):
-        # TODO:
-        #  Low-pass filter (convolution with filter h(t))
-        return self.noise_layer.call(inputs, training=True)
+        self.noise_layer = layers.GaussianNoise(q_stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
 
 
 class OpticalChannel(layers.Layer):
-    def __init__(self, fs, stddev=0.1):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 q_stddev=0.01):
+        """
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
         super(OpticalChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(stddev)
-        self.digitization_layer = DigitizationLayer()
+
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+        self.digitization_layer = DigitizationLayer(fs=fs,
+                                                    num_of_samples=num_of_samples,
+                                                    lpf_cutoff=lpf_cutoff,
+                                                    q_stddev=q_stddev)
         self.flatten_layer = layers.Flatten()
-        self.fs = fs
 
-    def call(self, inputs):
-        # Serializing outputs of all blocks
-        serialized = self.flatten_layer(inputs)
+        self.fs = fs
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
+        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
 
+    def call(self, inputs, **kwargs):
         # DAC LPF and noise
-        dac_out = self.digitization_layer(serialized)
+        dac_out = self.digitization_layer(inputs)
 
-        # TODO:
-        #  Chromatic Dispersion (fft -> phase shift -> ifft)
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex128)
+        val_f = tf.signal.fft(complex_val)
+        disp_f = tf.math.multiply(val_f, self.multiplier)
+        disp_t = tf.signal.ifft(disp_f)
 
         # Squared-Law Detection
-        pd_out = tf.square(tf.abs(dac_out))
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
 
         # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(pd_out, training=True)
+        rx_signal = self.noise_layer.call(real_val, training=True)
 
         # ADC LPF and noise
         adc_out = self.digitization_layer(rx_signal)
@@ -80,9 +135,14 @@ class EndToEndAutoencoder(tf.keras.Model):
     def __init__(self,
                  cardinality,
                  samples_per_symbol,
-                 neighbouring_blocks,
-                 oversampling,
+                 messages_per_block,
                  channel):
+        """
+        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        :param messages_per_block: Total number of messages in transmission block
+        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
+        """
         super(EndToEndAutoencoder, self).__init__()
 
         # Labelled M in paper
@@ -90,20 +150,22 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
         # Labelled N in paper
-        if neighbouring_blocks % 2 == 0:
-            neighbouring_blocks += 1
-        self.neighbouring_blocks = neighbouring_blocks
-        # Oversampling rate
-        self.oversampling = int(oversampling)
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
-            self.channel = channel
+            self.channel = tf.keras.Sequential([
+                layers.Flatten(),
+                channel,
+                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+            ])
         else:
             raise TypeError("Channel must be a subclass of keras.layers.layer!")
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            layers.Input(shape=(self.neighbouring_blocks, self.cardinality)),
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(self.samples_per_symbol),
@@ -112,81 +174,171 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            ExtractCentralMessage(self.neighbouring_blocks, self.samples_per_symbol),
             layers.Dense(self.samples_per_symbol, activation='relu'),
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(self.cardinality, activation='softmax')
         ])
 
-    def generate_random_inputs(self, num_of_blocks):
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.neighbouring_blocks, 1))
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
+        consecutively to model ISI. The central message in a block is returned as the label for training.
+        :param return_vals: If true, the raw decimal values of the input sequence will be returned
+        """
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
-        # cat = np.reshape(np.arange(self.cardinality), (1, -1))
-        enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
+        cat = [np.arange(self.cardinality)]
+        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
 
         out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.neighbouring_blocks, self.cardinality))
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
-        mid_idx = int((self.neighbouring_blocks-1)/2)
+        mid_idx = int((self.messages_per_block-1)/2)
 
-        return out_arr, out_arr[:, mid_idx, :]
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]
 
+        return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, train_size=0.8):
+    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        """
+        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
+        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
+        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
+        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
+        """
         X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
 
-        self.compile(optimizer='adam',
+        opt = keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
                      loss=losses.BinaryCrossentropy(),
-                     metrics=None,
+                     metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
-                     run_eagerly=None
+                     run_eagerly=False
                      )
 
         self.fit(x=X_train,
                  y=y_train,
-                 batch_size=None,
+                 batch_size=batch_size,
                  epochs=1,
                  shuffle=True,
                  validation_data=(X_test, y_test)
                  )
 
-
-
     def view_encoder(self):
-        messages = np.zeros((self.cardinality, self.neighbouring_blocks, self.cardinality))
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
-        mid_idx = int((self.neighbouring_blocks-1)/2)
+        mid_idx = int((self.messages_per_block-1)/2)
 
         idx = 0
         for msg in messages:
             msg[mid_idx, idx] = 1
             idx += 1
 
+        # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
-        return messages, encoded[:, mid_idx, :]
+        enc_messages = encoded[:, mid_idx, :]
 
+        # Compute subplot grid layout
+        i = 0
+        while 2**i < self.cardinality**0.5:
+            i += 1
 
-    def call(self, x):
-        tx = self.encoder(x)
-        rx = self.channel(tx)
-        y = self.decoder(rx)
-        return y
+        num_x = int(2**i)
+        num_y = int(self.cardinality / num_x)
 
+        # Plot all symbols
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
 
-if __name__ == '__main__':
-    tx_channel = AwgnChannel(stddev=0.1)
+        t = np.arange(self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t/self.channel.layers[1].fs
+
+        sym_idx = 0
+        for y in range(num_y):
+            for x in range(num_x):
+                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
+                sym_idx += 1
+
+        for ax in axs.flat:
+            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
+
+        for ax in axs.flat:
+            ax.label_outer()
+
+        plt.show()
+        pass
+
+    def view_sample_block(self):
+        # Generate a random block of messages
+        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+        # Encode and flatten the messages
+        enc = self.encoder(inp)
+        flat_enc = layers.Flatten()(enc)
 
-    model = EndToEndAutoencoder(cardinality=8,
-                                samples_per_symbol=10,
-                                neighbouring_blocks=5,
-                                oversampling=4,
-                                channel=tx_channel)
+        # Instantiate LPF layer
+        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
+                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
+                                q_stddev=0)
 
-    model.train()
+        # Apply LPF
+        lpf_out = lpf(flat_enc)
+
+        # Time axis
+        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t / self.channel.layers[1].fs
+
+        # Plot the concatenated symbols before and after LPF
+        plt.figure(figsize=(2*self.messages_per_block, 6))
+
+        for i in range(1, self.messages_per_block):
+            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+
+        plt.plot(t, flat_enc.numpy().T, 'x')
+        plt.plot(t, lpf_out.numpy().T)
+        plt.ylim((0, 1))
+        plt.xlim((t.min(), t.max()))
+        plt.title(str(val[0, :, 0]))
+        plt.show()
+        pass
+
+    def call(self, inputs, training=None, mask=None):
+        tx = self.encoder(inputs)
+        rx = self.channel(tx)
+        outputs = self.decoder(rx)
+        return outputs
+
+
+if __name__ == '__main__':
 
-    model.view_encoder()
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 24
+    MESSAGES_PER_BLOCK = 9
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH)
+
+    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                   messages_per_block=MESSAGES_PER_BLOCK,
+                                   channel=optical_channel)
+
+    ae_model.train(num_of_blocks=1e6, batch_size=100)
+    ae_model.view_encoder()
+    ae_model.view_sample_block()
 
     pass