
bit-symbol mapping attempts

Tharmetharan Balendran 4 years ago
parent
commit
7e2c83ee4c
3 changed files with 202 additions and 83 deletions
  1. models/custom_layers.py   (+30, -3)
  2. models/end_to_end.py      (+105, -78)
  3. tests/misc_test.py        (+67, -2)

+ 30 - 3
models/custom_layers.py

@@ -2,6 +2,35 @@ from tensorflow.keras import layers
 import tensorflow as tf
 import math
 import numpy as np
+import itertools
+
+
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        return tf.one_hot(idx, self.cardinality)
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32))
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(self.all_syms, inputs)
 
 
 class ExtractCentralMessage(layers.Layer):
@@ -24,9 +53,7 @@ class ExtractCentralMessage(layers.Layer):
         self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
 
     def call(self, inputs, **kwargs):
-        out = tf.matmul(inputs, self.w)
-        return tf.reshape(out, shape=(1, 1, self.samples_per_symbol))
-        # TODO: this won't work with dense layers need to move to separate layer
+        return tf.matmul(inputs, self.w)
 
 
 class AwgnChannel(layers.Layer):
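
As a side note (not part of this commit): a minimal standalone sketch of the powers-of-two mapping that the new BitsToSymbols and SymbolsToBits layers are built around, assuming M = 8 (so n = 3 bits). Names and shapes below are illustrative only and do not call into the layers themselves.

    import itertools
    import numpy as np
    import tensorflow as tf

    M = 8
    n = int(np.log2(M))

    # Powers of two, MSB first, as in BitsToSymbols.pows: [4, 2, 1] for n = 3.
    pows = tf.constant((2.0 ** np.arange(n - 1, -1, -1)).reshape(-1, 1), dtype=tf.float32)

    # Two example bit vectors, shape (batch, n).
    bits = tf.constant([[0., 1., 1.],
                        [1., 0., 0.]])

    # The dot product with the powers gives the integer symbol index: 011 -> 3, 100 -> 4.
    idx = tf.cast(tf.squeeze(tf.matmul(bits, pows), axis=-1), tf.int32)
    one_hot = tf.one_hot(idx, M)                      # shape (batch, M)

    # The inverse direction: a lexicographic table of all n-bit patterns; multiplying a
    # one-hot symbol row by it recovers the original bits. (SymbolsToBits stores the
    # transposed table and multiplies from the other side, which is the same mapping.)
    all_syms = tf.constant(list(itertools.product([0, 1], repeat=n)), dtype=tf.float32)  # (M, n)
    bits_back = tf.matmul(one_hot, all_syms)          # shape (batch, n)

    print(idx.numpy())        # [3 4]
    print(bits_back.numpy())  # [[0. 1. 1.] [1. 0. 0.]]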

+ 105 - 78
models/end_to_end.py

@@ -5,8 +5,9 @@ import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer
-
+from tensorflow.keras import backend as K
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols
+import itertools
 
 class EndToEndAutoencoder(tf.keras.Model):
     def __init__(self,
@@ -14,7 +15,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  messages_per_block,
                  channel,
-                 recurrent=False):
+                 bit_mapping=False):
         """
         The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -29,12 +30,16 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
+
         # Labelled N in paper
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
@@ -44,39 +49,63 @@ class EndToEndAutoencoder(tf.keras.Model):
             ], name="channel_model")
         else:
             raise TypeError("Channel must be a subclass of keras.layers.layer!")
-        self.recurrent = recurrent
-
-        if recurrent:
-            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality), batch_size=1)
-            # encoding_layers = [
-            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
-            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
-            # ]
+
+        # Boolean identifying if bit mapping is to be learnt
+        self.bit_mapping = bit_mapping
+
+        # Layer configuration for the case when bit mapping is to be learnt
+        if self.bit_mapping:
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
+                BitsToSymbols(self.cardinality),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+            ]
             decoding_layers = [
-                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
-                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                # layers.Dense(2 * self.cardinality),
+                # layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
             ]
+
+        # Layer configuration for the case when only symbol mapping is to be learnt
         else:
-            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality))
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.cardinality)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+            ]
             decoding_layers = [
-                layers.Dense(2 * self.cardinality, activation='relu'),
-                layers.Dense(2 * self.cardinality, activation='relu')
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
             ]
 
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            input_layer,
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
+            *encoding_layers
         ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            *decoding_layers,
-            layers.Dense(self.cardinality, activation='softmax')
+            *decoding_layers
         ], name="decoding_model")
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
@@ -93,33 +122,30 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         mid_idx = int((self.messages_per_block - 1) / 2)
 
-        if self.recurrent and not return_vals:
-            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks+self.messages_per_block-1, 1))
-
-            rand_enc = enc.fit_transform(rand_int)
+        if self.bit_mapping:
+            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
 
-            out = []
+            out = rand_int
 
-            for i in range(num_of_blocks):
-                out.append(rand_enc[i:i+self.messages_per_block])
+            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
 
-            out = np.array(out)
-
-            return out, out[:, mid_idx, :]
+            if return_vals:
+                raise NotImplementedError  # TODO: return_vals is not yet supported when bit_mapping=True
 
         else:
             rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
             out = enc.fit_transform(rand_int)
+
             out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
 
             if return_vals:
                 out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                 return out_val, out_arr, out_arr[:, mid_idx, :]
 
-            return out_arr, out_arr[:, mid_idx, :]
+        return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-2):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
@@ -128,8 +154,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
@@ -141,30 +167,29 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )
 
-        shuffle = True
-        if self.recurrent and batch_size is None:
-            # If recurrent layers are present in the model then the training data is considered one at a time without
-            # shuffling of the data. This preserves order in the data.
-            batch_size = 1
-            shuffle = False
-
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=1,
-                 shuffle=shuffle,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=1,
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        plt.plot(history.history['accuracy'])
+        plt.plot(history.history['val_accuracy'])
+        plt.show()
+
 
     def view_encoder(self):
         '''
-        A method that views the learnt encoder for each distint message. This is displayed as a plot with  asubplot for
-        each image.
+        A method that views the learnt encoder for each distinct message. This is displayed as a plot with a subplot for
+        each message/symbol.
         '''
+
         # Generate inputs for encoder
         messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
 
-        mid_idx = int((self.messages_per_block-1)/2)
+        mid_idx = int((self.messages_per_block - 1) / 2)
 
         idx = 0
         for msg in messages:
@@ -177,18 +202,18 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1
 
-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
 
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
 
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs
 
         sym_idx = 0
         for y in range(num_y):
@@ -220,22 +245,22 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                 q_stddev=0)
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs
 
         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
 
         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
@@ -251,18 +276,18 @@ class EndToEndAutoencoder(tf.keras.Model):
         outputs = self.decoder(rx)
         return outputs
 
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 24
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
 
-if __name__ == '__main__':
 
-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+if __name__ == '__main__':
 
     optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                      dispersion_factor=DISPERSION_FACTOR,
                                      fiber_length=FIBER_LENGTH)
 
@@ -270,11 +295,13 @@ if __name__ == '__main__':
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=optical_channel,
-                                   recurrent=True)
-
-    ae_model.train(num_of_blocks=1e5)
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    ae_model.summary()
-
+                                   bit_mapping=True)
+
+    ae_model.train(num_of_blocks=1e6)
+    # ae_model.view_encoder()
+    # ae_model.view_sample_block()
+    # ae_model.summary()
+    ae_model.encoder.summary()
+    ae_model.channel.summary()
+    ae_model.decoder.summary()
     pass
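
For reference (not part of the diff): a standalone sketch of the array shapes the bit_mapping branch of generate_random_inputs produces, using the module-level constants above. This mirrors the reshape and central-message indexing in the method; it does not instantiate the model.

    import numpy as np

    CARDINALITY = 32
    MESSAGES_PER_BLOCK = 9
    BITS_PER_SYMBOL = int(np.log2(CARDINALITY))       # 5 bits per symbol
    num_of_blocks = 4                                 # small toy value

    # Random bits reshaped into (blocks, messages per block, bits per symbol).
    rand_bits = np.random.randint(2, size=(num_of_blocks * MESSAGES_PER_BLOCK * BITS_PER_SYMBOL, 1))
    blocks = rand_bits.reshape(num_of_blocks, MESSAGES_PER_BLOCK, BITS_PER_SYMBOL)

    # The training target is the bit vector of the central message in each block.
    mid_idx = (MESSAGES_PER_BLOCK - 1) // 2
    targets = blocks[:, mid_idx, :]

    print(blocks.shape)   # (4, 9, 5)
    print(targets.shape)  # (4, 5)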

+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+import math
+import itertools
+import tensorflow as tf
+from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
+from matplotlib import pyplot as plt
 
 
 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@ def test_bit_matrix_one_hot():
 
 
 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    # cardinality = 8
+    # messages_per_block = 3
+    # num_of_blocks = 10
+    # bits_per_symbol = 3
+    #
+    # #-----------------------------------
+    #
+    # mid_idx = int((messages_per_block - 1) / 2)
+    #
+    # ################################################################################################################
+    #
+    # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+    # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
+    #
+    # # out = enc.fit_transform(rand_int)
+    # out = rand_int
+    #
+    # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+    # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
+    #
+    # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
+    #
+    #
+    # n = int(math.log(cardinality, 2))
+    # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+    #
+    # pows_np = pows.numpy()
+    #
+    # a = np.asarray([0, 1, 1]).reshape(1, -1)
+    #
+    # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     q_stddev=0)
+
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()
+
+
+    pass
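
A further check that could live in this test file (hypothetical, not in the commit): verify that the powers-of-two encoding and the all-patterns lookup table used by BitsToSymbols/SymbolsToBits are consistent inverses for every one of the M bit patterns. Pure NumPy, no TensorFlow required.

    import itertools
    import numpy as np

    M = 32
    n = int(np.log2(M))

    # All n-bit patterns in lexicographic order, and the MSB-first powers of two.
    all_bits = np.array(list(itertools.product([0, 1], repeat=n)), dtype=np.float32)   # (M, n)
    pows = (2.0 ** np.arange(n - 1, -1, -1)).reshape(-1, 1).astype(np.float32)         # (n, 1)

    idx = (all_bits @ pows).astype(np.int32).flatten()    # integer index of each bit pattern
    one_hot = np.eye(M, dtype=np.float32)[idx]            # (M, M) one-hot rows
    recovered = one_hot @ all_bits                        # should reproduce all_bits exactly

    assert np.array_equal(idx, np.arange(M))
    assert np.array_equal(recovered, all_bits)
    print("bit <-> symbol mapping is self-consistent for all", M, "patterns")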