
Refactoring and LSTM layers added

Tharmetharan Balendran 5 years ago
parent
commit
80a5361d42
2 changed files with 202 additions and 151 deletions
  1. +140 -0   models/custom_layers.py
  2. +62 -151  models/end_to_end.py

+ 140 - 0
models/custom_layers.py

@@ -0,0 +1,140 @@
+from tensorflow.keras import layers
+import tensorflow as tf
+import math
+import numpy as np
+
+
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A Keras layer that extracts the central message (symbol) in a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.samples_per_symbol = samples_per_symbol
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        out = tf.matmul(inputs, self.w)
+        return tf.reshape(out, shape=(1, 1, self.samples_per_symbol))
+        # TODO: this won't work with dense layers; need to move to a separate layer
+
+
+class AwgnChannel(layers.Layer):
+    def __init__(self, rx_stddev=0.1):
+        """
+        An additive white Gaussian noise channel model. The GaussianNoise layer is used so that a fresh noise
+        realization is drawn on every call, rather than the same noise being applied each time.
+
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        """
+        super(AwgnChannel, self).__init__()
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 q_stddev=0.1):
+        """
+        This layer simulates the finite bandwidth of the hardware by means of a low-pass filter. In addition,
+        artefacts caused by quantization are modelled by adding white Gaussian noise with the given standard deviation.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        self.noise_layer = layers.GaussianNoise(q_stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
+
+
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 q_stddev=0.01):
+        """
+        A channel model that simulates chromatic dispersion, non-linear photodiode detection, the finite bandwidth
+        of the ADC/DAC, and additive white Gaussian noise in optical communication channels.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+        self.digitization_layer = DigitizationLayer(fs=fs,
+                                                    num_of_samples=num_of_samples,
+                                                    lpf_cutoff=lpf_cutoff,
+                                                    q_stddev=q_stddev)
+        self.flatten_layer = layers.Flatten()
+
+        self.fs = fs
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex64)
+        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
+
+    def call(self, inputs, **kwargs):
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_val)
+        disp_f = tf.math.multiply(val_f, self.multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer.call(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out
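
Note: the weight matrix built in ExtractCentralMessage.__init__ is a hard-wired selector. A minimal NumPy sketch (illustrative sizes, not taken from the commit) showing that the matmul is equivalent to slicing out the central symbol's samples:

    import numpy as np

    messages_per_block = 5   # assumed odd, so a unique central message exists
    samples_per_symbol = 4

    # Same construction as ExtractCentralMessage.__init__
    w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
    begin = samples_per_symbol * ((messages_per_block - 1) // 2)
    end = samples_per_symbol * ((messages_per_block + 1) // 2)
    w[begin:end, :] = np.identity(samples_per_symbol)

    block = np.random.randn(1, messages_per_block * samples_per_symbol)
    assert np.allclose(block @ w, block[:, begin:end])  # same central samples

Keeping the selection as a fixed matmul makes it an ordinary linear operation inside the Keras graph; the hard-coded batch-1 reshape that follows it is the fragile part flagged by the TODO.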

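For context on OpticalChannel.call: the layer applies the standard frequency-domain model of chromatic dispersion, H(f) = exp(0.5j * beta2 * L * (2*pi*f)^2), followed by square-law photodiode detection. A standalone NumPy sketch of the same maths, with assumed parameter values (fs, beta2, fibre length and pulse width are illustrative, not from the commit):

    import numpy as np

    fs = 336e9            # sampling frequency in Hz (assumed)
    n = 2048              # number of samples
    beta2 = -21.7e-24     # dispersion factor in s^2/km (assumed)
    length = 50.0         # fibre length in km (assumed)

    t = np.arange(n) / fs
    pulse = np.exp(-((t - t[n // 2]) / 25e-12) ** 2)   # Gaussian test pulse

    # Dispersion multiplier, matching the expression in OpticalChannel.__init__
    f = np.fft.fftfreq(n, d=1 / fs)
    h = np.exp(0.5j * beta2 * length * np.square(2 * np.pi * f))

    dispersed = np.fft.ifft(np.fft.fft(pulse) * h)

    # Square-law detection: the photodiode measures intensity, not field
    detected = np.abs(dispersed) ** 2

The |.|^2 step discards the optical phase, which makes the channel non-linear and non-invertible; that is the case the end-to-end autoencoder is meant to learn around.
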
+ 62 - 151
models/end_to_end.py

@@ -5,139 +5,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        A keras layer that extracts the central message(symbol) in a block.
-
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class AwgnChannel(layers.Layer):
-    def __init__(self, rx_stddev=0.1):
-        """
-        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
-        being applied every time the call function is called.
-
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        """
-        super(AwgnChannel, self).__init__()
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-
-    def call(self, inputs, **kwargs):
-        return self.noise_layer.call(inputs, training=True)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
-        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
-        ADC/DAC as well as additive white gaussian noise in optical communication channels.
-
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -145,7 +13,8 @@ class EndToEndAutoencoder(tf.keras.Model):
                  cardinality,
                  samples_per_symbol,
                  messages_per_block,
-                 channel):
+                 channel,
+                 recurrent=False):
         """
         The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation, as this is essential for modelling inter-symbol
@@ -172,26 +41,43 @@ class EndToEndAutoencoder(tf.keras.Model):
                 layers.Flatten(),
                 channel,
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
-            ])
+            ], name="channel_model")
         else:
             raise TypeError("Channel must be a subclass of keras.layers.layer!")
+        self.recurrent = recurrent
+
+        if recurrent:
+            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality), batch_size=1)
+            # encoding_layers = [
+            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
+            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
+            # ]
+            decoding_layers = [
+                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
+                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
+            ]
+        else:
+            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality))
+            decoding_layers = [
+                layers.Dense(2 * self.cardinality, activation='relu'),
+                layers.Dense(2 * self.cardinality, activation='relu')
+            ]
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            input_layer,
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(2 * self.cardinality, activation='relu'),
             layers.Dense(self.samples_per_symbol),
             layers.ReLU(max_value=1.0)
-        ])
+        ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
             layers.Dense(self.samples_per_symbol, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
+            *decoding_layers,
             layers.Dense(self.cardinality, activation='softmax')
-        ])
+        ], name="decoding_model")
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
@@ -201,21 +87,37 @@ class EndToEndAutoencoder(tf.keras.Model):
         consecutively to model ISI. The central message in a block is returned as the label for training.
         :param return_vals: If true, the raw decimal values of the input sequence will be returned
         """
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
         cat = [np.arange(self.cardinality)]
         enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
 
-        out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+        mid_idx = int((self.messages_per_block - 1) / 2)
 
-        mid_idx = int((self.messages_per_block-1)/2)
+        if self.recurrent and not return_vals:
+            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks+self.messages_per_block-1, 1))
+
+            rand_enc = enc.fit_transform(rand_int)
+
+            out = []
+
+            for i in range(num_of_blocks):
+                out.append(rand_enc[i:i+self.messages_per_block])
 
-        if return_vals:
-            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-            return out_val, out_arr, out_arr[:, mid_idx, :]
+            out = np.array(out)
 
-        return out_arr, out_arr[:, mid_idx, :]
+            return out, out[:, mid_idx, :]
+
+        else:
+            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
+            out = enc.fit_transform(rand_int)
+            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+
+            if return_vals:
+                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+                return out_val, out_arr, out_arr[:, mid_idx, :]
+
+            return out_arr, out_arr[:, mid_idx, :]
 
     def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
         """
@@ -239,11 +141,18 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )
 
+        shuffle = True
+        if self.recurrent and batch_size is None:
+            # If recurrent (stateful) layers are present in the model, the training data is processed one block at
+            # a time without shuffling, preserving the temporal order that the LSTM state depends on.
+            batch_size = 1
+            shuffle = False
+
         self.fit(x=X_train,
                  y=y_train,
                  batch_size=batch_size,
                  epochs=1,
-                 shuffle=True,
+                 shuffle=shuffle,
                  validation_data=(X_test, y_test)
                  )
 
@@ -360,10 +269,12 @@ if __name__ == '__main__':
     ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel)
+                                   channel=optical_channel,
+                                   recurrent=True)
 
-    ae_model.train(num_of_blocks=1e6, batch_size=100)
+    ae_model.train(num_of_blocks=1e5)
     ae_model.view_encoder()
     ae_model.view_sample_block()
+    ae_model.summary()
 
     pass
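
A note on the recurrent=True path: stateful LSTMs carry their hidden state from one batch to the next, so training must present blocks one at a time and in sequence order, which is why train() falls back to batch_size=1 and shuffle=False. A minimal standalone sketch of the pattern (hypothetical toy model, not the autoencoder itself):

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import layers

    cardinality, messages_per_block = 4, 5

    model = tf.keras.Sequential([
        layers.Input(shape=(messages_per_block, cardinality), batch_size=1),
        layers.LSTM(2 * cardinality, stateful=True),  # state persists across batches
        layers.Dense(cardinality, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy')

    # Overlapping blocks built with the same sliding window as
    # generate_random_inputs: consecutive blocks share messages, so
    # shuffling would destroy exactly the order the state depends on.
    seq = np.eye(cardinality)[np.random.randint(cardinality, size=20)]
    x = np.stack([seq[i:i + messages_per_block] for i in range(16)])
    y = x[:, (messages_per_block - 1) // 2, :]  # central message as label

    model.fit(x, y, batch_size=1, shuffle=False, epochs=1)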