Browse files

bit-symbol mapping attempts

Tharmetharan Balendran 4 years ago
parent
commit
7e2c83ee4c
3 changed files with 202 additions and 83 deletions
  1. + 30 - 3
      models/custom_layers.py
  2. + 105 - 78
      models/end_to_end.py
  3. + 67 - 2
      tests/misc_test.py

+ 30 - 3
models/custom_layers.py

@@ -2,6 +2,35 @@ from tensorflow.keras import layers
 import tensorflow as tf
 import math
 import numpy as np
+import itertools
+
+
+class BitsToSymbols(layers.Layer):
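+    """Maps length-n bit vectors to one-hot symbol vectors of size 2**n.
+
+    For example, with cardinality 8 the bits [0, 1, 1] are read as the
+    base-2 integer 0*4 + 1*2 + 1*1 = 3 and one-hot encoded to length 8.
+    """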
+    def __init__(self, cardinality):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+
+        n = int(math.log(self.cardinality, 2))
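+        # Column vector of descending powers of two: [2**(n-1), ..., 2, 1]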
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        # Squeeze the trailing singleton axis so the output is (..., cardinality)
+        return tf.one_hot(tf.squeeze(idx, axis=-1), self.cardinality)
+
+
+class SymbolsToBits(layers.Layer):
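+    """Maps symbol scores to (soft) bit values; not yet wired into the model.
+
+    all_syms stacks every length-n bit pattern, so weighting the 2**n symbols
+    by their probabilities yields the expected value of each bit.
+    """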
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32))
+
+    def call(self, inputs, **kwargs):
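+        # Note: as written this expects symbol scores as columns, shape (M, k),
+        # and returns expected bit values of shape (n, k).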
+        return tf.matmul(self.all_syms, inputs)


 class ExtractCentralMessage(layers.Layer):
@@ -24,9 +53,7 @@ class ExtractCentralMessage(layers.Layer):
         self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

     def call(self, inputs, **kwargs):
-        out = tf.matmul(inputs, self.w)
-        return tf.reshape(out, shape=(1, 1, self.samples_per_symbol))
-        # TODO: this won't work with dense layers need to move to separate layer
+        return tf.matmul(inputs, self.w)


 class AwgnChannel(layers.Layer):
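
A quick round-trip check of the new mapping layer (a minimal sketch, not part of this commit, assuming cardinality 8, i.e. 3 bits per symbol):

    import numpy as np
    import tensorflow as tf
    from models.custom_layers import BitsToSymbols

    bits = tf.constant([[[0., 1., 1.], [1., 0., 0.]]])   # shape (1, 2, 3)
    one_hot = BitsToSymbols(8)(bits)                     # shape (1, 2, 8)
    print(np.argmax(one_hot.numpy(), axis=-1))           # [[3 4]]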

+ 105 - 78
models/end_to_end.py

@@ -5,8 +5,9 @@ import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer
-
+from tensorflow.keras import backend as K
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols
+import itertools
 
 
 class EndToEndAutoencoder(tf.keras.Model):
     def __init__(self,
@@ -14,7 +15,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  messages_per_block,
                  channel,
-                 recurrent=False):
+                 bit_mapping=False):
         """
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -29,12 +30,16 @@ class EndToEndAutoencoder(tf.keras.Model):

         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
+
         # Labelled N in paper
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
@@ -44,39 +49,63 @@ class EndToEndAutoencoder(tf.keras.Model):
             ], name="channel_model")
         else:
             raise TypeError("Channel must be a subclass of keras.layers.Layer!")
-        self.recurrent = recurrent
-
-        if recurrent:
-            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality), batch_size=1)
-            # encoding_layers = [
-            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
-            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
-            # ]
+
+        # Boolean identifying if bit mapping is to be learnt
+        self.bit_mapping = bit_mapping
+
+        # Layer configuration for the case when bit mapping is to be learnt
+        if self.bit_mapping:
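+            # Encoder input is a block of bit vectors, shape (N, bits_per_symbol);
+            # BitsToSymbols lifts each to a one-hot symbol before the dense stack.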
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
+                BitsToSymbols(self.cardinality),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+            ]
             decoding_layers = [
-                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
-                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                # layers.Dense(2 * self.cardinality),
+                # layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
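+                # Sigmoid head: one independent probability per recovered bit.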
+                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
             ]
+
+        # Layer configuration for the case when only symbol mapping is to be learnt
         else:
-            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality))
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.cardinality)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+            ]
             decoding_layers = [
-                layers.Dense(2 * self.cardinality, activation='relu'),
-                layers.Dense(2 * self.cardinality, activation='relu')
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.cardinality),
+                layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
             ]

         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            input_layer,
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
+            *encoding_layers
         ], name="encoding_model")

         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            *decoding_layers,
-            layers.Dense(self.cardinality, activation='softmax')
+            *decoding_layers
         ], name="decoding_model")

     def generate_random_inputs(self, num_of_blocks, return_vals=False):
@@ -93,33 +122,30 @@ class EndToEndAutoencoder(tf.keras.Model):

         mid_idx = int((self.messages_per_block - 1) / 2)

-        if self.recurrent and not return_vals:
-            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks+self.messages_per_block-1, 1))
-
-            rand_enc = enc.fit_transform(rand_int)
+        if self.bit_mapping:
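+            # Draw i.i.d. uniform bits directly; no one-hot encoding is needed.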
+            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))

-            out = []
+            out = rand_int

-            for i in range(num_of_blocks):
-                out.append(rand_enc[i:i+self.messages_per_block])
+            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))

-            out = np.array(out)
-
-            return out, out[:, mid_idx, :]
+            if return_vals:
+                # TODO: also return the integer symbol values here, mirroring
+                # the one-hot branch below
+                raise NotImplementedError

         else:
             rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))

             out = enc.fit_transform(rand_int)
+
             out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))

             if return_vals:
                 out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                 return out_val, out_arr, out_arr[:, mid_idx, :]

-            return out_arr, out_arr[:, mid_idx, :]
+        return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-2):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
 
 
@@ -128,8 +154,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))

         opt = tf.keras.optimizers.Adam(learning_rate=lr)
 
 
@@ -141,30 +167,29 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )

-        shuffle = True
-        if self.recurrent and batch_size is None:
-            # If recurrent layers are present in the model then the training data is considered one at a time without
-            # shuffling of the data. This preserves order in the data.
-            batch_size = 1
-            shuffle = False
-
-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=1,
-                 shuffle=shuffle,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=1,
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        plt.plot(history.history['accuracy'])
+        plt.plot(history.history['val_accuracy'])
+        plt.show()
+

     def view_encoder(self):
         '''
-        A method that views the learnt encoder for each distint message. This is displayed as a plot with  asubplot for
-        each image.
+        A method that views the learnt encoder for each distinct message. This is displayed as a plot with a subplot for
+        each message/symbol.
         '''
+
         # Generate inputs for encoder
         messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))

-        mid_idx = int((self.messages_per_block-1)/2)
+        mid_idx = int((self.messages_per_block - 1) / 2)

         idx = 0
         for msg in messages:
@@ -177,18 +202,18 @@ class EndToEndAutoencoder(tf.keras.Model):

         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1

-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)

         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))

         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs

         sym_idx = 0
         for y in range(num_y):
@@ -220,22 +245,22 @@ class EndToEndAutoencoder(tf.keras.Model):

         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                 q_stddev=0)

         # Apply LPF
         lpf_out = lpf(flat_enc)

         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs

         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))

         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')

         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
         plt.plot(t, lpf_out.numpy().T)
@@ -251,18 +276,18 @@ class EndToEndAutoencoder(tf.keras.Model):
         outputs = self.decoder(rx)
         return outputs


+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 24
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50

-if __name__ == '__main__':

-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+if __name__ == '__main__':

     optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                      dispersion_factor=DISPERSION_FACTOR,
                                      fiber_length=FIBER_LENGTH)

@@ -270,11 +295,13 @@ if __name__ == '__main__':
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=optical_channel,
-                                   recurrent=True)
-
-    ae_model.train(num_of_blocks=1e5)
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    ae_model.summary()
-
+                                   bit_mapping=True)
+
+    ae_model.train(num_of_blocks=1e6)
+    # ae_model.view_encoder()
+    # ae_model.view_sample_block()
+    # ae_model.summary()
+    ae_model.encoder.summary()
+    ae_model.channel.summary()
+    ae_model.decoder.summary()
     pass
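
A quick way to gauge the learnt bit mapping after training (a minimal sketch, not part of this commit, assuming the bit_mapping model's call returns one sigmoid probability per bit of the central message):

    X, y = ae_model.generate_random_inputs(int(1e4))
    y_prob = ae_model.predict(X)          # per-bit probabilities, shape (blocks, bits)
    y_hat = (y_prob > 0.5).astype(int)    # hard bit decisions
    print("Estimated BER:", np.mean(y_hat != y))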

+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+import math
+import itertools
+import tensorflow as tf
+from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
+from matplotlib import pyplot as plt


 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@


 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    # cardinality = 8
+    # messages_per_block = 3
+    # num_of_blocks = 10
+    # bits_per_symbol = 3
+    #
+    # #-----------------------------------
+    #
+    # mid_idx = int((messages_per_block - 1) / 2)
+    #
+    # ################################################################################################################
+    #
+    # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+    # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
+    #
+    # # out = enc.fit_transform(rand_int)
+    # out = rand_int
+    #
+    # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+    # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
+    #
+    # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
+    #
+    #
+    # n = int(math.log(cardinality, 2))
+    # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+    #
+    # pows_np = pows.numpy()
+    #
+    # a = np.asarray([0, 1, 1]).reshape(1, -1)
+    #
+    # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     q_stddev=0)
+
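+    # Hold each random 4-level symbol for SAMPLES_PER_SYMBOL samples, then
+    # compare the waveform before and after the dispersive optical channel.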
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()
+
+
+    pass