Tharmetharan Balendran, 4 years ago
Parent
Commit a75063d665
3 changed files with 229 additions and 52 deletions
  1. models/custom_layers.py (+14 -8)
  2. models/end_to_end.py (+95 -44)
  3. models/new_model.py (+120 -0)

+ 14 - 8
models/custom_layers.py

@@ -16,7 +16,8 @@ class BitsToSymbols(layers.Layer):
 
 
     def call(self, inputs, **kwargs):
         idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
-        return tf.one_hot(idx, self.cardinality)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((9, 32))(out)


 class SymbolsToBits(layers.Layer):
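Note on the BitsToSymbols change above: each row of bits is still turned into a symbol index by a dot product with powers of two and then one-hot encoded; the added Reshape appears to pin the static output shape, at the cost of hard-coding the block size and cardinality (9 and 32, matching MESSAGES_PER_BLOCK and CARDINALITY in end_to_end.py). A minimal standalone sketch of the same computation, assuming self.pows holds the powers of two (the constructor is not part of this hunk):

    import numpy as np
    import tensorflow as tf

    # Toy check: a (batch, 9, 5) block of bits becomes (batch, 9, 32) one-hot symbols.
    cardinality = 32                                  # assumed: CARDINALITY
    bits_per_symbol = 5                               # log2(32)
    pows = tf.constant(2.0 ** np.arange(bits_per_symbol), dtype=tf.float32)

    bits = tf.cast(tf.random.uniform((4, 9, bits_per_symbol), maxval=2, dtype=tf.int32), tf.float32)
    idx = tf.cast(tf.tensordot(bits, pows, axes=1), tf.int32)   # (4, 9) symbol indices
    one_hot = tf.one_hot(idx, cardinality)                      # (4, 9, 32) already
    print(one_hot.shape)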
@@ -27,10 +28,10 @@ class SymbolsToBits(layers.Layer):
         lst = [list(i) for i in itertools.product([0, 1], repeat=n)]

         # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
-        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32))
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)

     def call(self, inputs, **kwargs):
-        return tf.matmul(self.all_syms, inputs)
+        return tf.matmul(inputs, self.all_syms)


 class ExtractCentralMessage(layers.Layer):
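Note on SymbolsToBits: with the transpose dropped, self.all_syms is the (cardinality, bits_per_symbol) table of bit patterns, so matmul(inputs, self.all_syms) maps one-hot symbol vectors (or softmax outputs) to bit vectors (or probability-weighted bit values). A small sketch with the table built the same way as in the layer:

    import itertools
    import tensorflow as tf

    bits_per_symbol = 5
    cardinality = 2 ** bits_per_symbol
    # (32, 5) table: row i holds the MSB-first binary expansion of i.
    all_syms = tf.constant([list(i) for i in itertools.product([0, 1], repeat=bits_per_symbol)],
                           dtype=tf.float32)

    one_hot = tf.one_hot([3, 17, 31], cardinality)   # (3, 32) one-hot symbols
    bits = tf.matmul(one_hot, all_syms)              # (3, 5) bit vectors for 3, 17, 31
    print(bits.numpy())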
@@ -76,7 +77,8 @@ class DigitizationLayer(layers.Layer):
                  fs,
                  num_of_samples,
                  lpf_cutoff=32e9,
-                 q_stddev=0.1):
+                 sig_avg=0.5,
+                 enob=10):
         """
         This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
         artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
@@ -88,7 +90,9 @@ class DigitizationLayer(layers.Layer):
         """
         """
         super(DigitizationLayer, self).__init__()
         super(DigitizationLayer, self).__init__()
 
 
-        self.noise_layer = layers.GaussianNoise(q_stddev)
+        stddev = 3*(sig_avg**2)*(10**((-6.02*enob + 1.76)/10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
         freq = np.fft.fftfreq(num_of_samples, d=1/fs)
         freq = np.fft.fftfreq(num_of_samples, d=1/fs)
         temp = np.ones(freq.shape)
         temp = np.ones(freq.shape)
 
 
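For context on the new stddev expression: the commit ties the Gaussian "quantization" noise to an effective number of bits (ENOB) via the usual ideal-quantizer rule of thumb SQNR_dB ≈ 6.02·ENOB + 1.76. The sketch below shows only that textbook relation, with sig_power and the helper name as illustrative assumptions; it is not the exact expression committed above (which scales 3*sig_avg**2 and passes the result straight to GaussianNoise):

    import numpy as np

    def quantization_noise_stddev(sig_power, enob):
        # Ideal quantizer: SQNR_dB = 6.02*ENOB + 1.76, noise power = signal power / SQNR.
        sqnr_db = 6.02 * enob + 1.76
        noise_power = sig_power * 10 ** (-sqnr_db / 10)
        return np.sqrt(noise_power)

    # e.g. a signal with mean square 0.25 seen by a 10-bit-effective converter
    print(quantization_noise_stddev(0.25, 10))   # ~4e-4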
@@ -116,7 +120,8 @@ class OpticalChannel(layers.Layer):
                  fiber_length,
                  lpf_cutoff=32e9,
                  rx_stddev=0.01,
-                 q_stddev=0.01):
+                 sig_avg=0.5,
+                 enob=10):
         """
         A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
         ADC/DAC as well as additive white gaussian noise in optical communication channels.
@@ -127,7 +132,7 @@ class OpticalChannel(layers.Layer):
         :param fiber_length: Length of fiber to model in km
         :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
         :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        :param sig_avg: Average signal amplitude
         """
         super(OpticalChannel, self).__init__()

@@ -135,7 +140,8 @@ class OpticalChannel(layers.Layer):
         self.digitization_layer = DigitizationLayer(fs=fs,
                                                     num_of_samples=num_of_samples,
                                                     lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
+                                                    sig_avg=sig_avg,
+                                                    enob=enob)
         self.flatten_layer = layers.Flatten()

         self.fs = fs
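Usage sketch for the updated OpticalChannel signature, assuming the constants used in end_to_end.py (fs = 336e9, 9 messages of 32 samples, dispersion factor -21.7e-24, zero fiber length); the quantization noise is now parameterised by sig_avg and enob instead of q_stddev:

    from models.custom_layers import OpticalChannel

    channel = OpticalChannel(fs=336e9,
                             num_of_samples=9 * 32,
                             dispersion_factor=-21.7 * 1e-24,
                             fiber_length=0,
                             lpf_cutoff=32e9,
                             rx_stddev=0.01,
                             sig_avg=0.5,
                             enob=10)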

+ 95 - 44
models/end_to_end.py

@@ -3,12 +3,14 @@ import math
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
 from tensorflow.keras import backend as K
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
 import itertools

+
 class EndToEndAutoencoder(tf.keras.Model):
     def __init__(self,
                  cardinality,
@@ -53,27 +55,29 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Boolean identifying if bit mapping is to be learnt
         self.bit_mapping = bit_mapping

+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
+
         # Layer configuration for the case when bit mapping is to be learnt
         if self.bit_mapping:
             encoding_layers = [
                 layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
                 BitsToSymbols(self.cardinality),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                 layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
             ]
             decoding_layers = [
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                 # layers.Dense(2 * self.cardinality),
                 # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid')
             ]

         # layer configuration for the case when only symbol mapping is to be learnt
@@ -81,21 +85,19 @@ class EndToEndAutoencoder(tf.keras.Model):
             encoding_layers = [
                 layers.Input(shape=(self.messages_per_block, self.cardinality)),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                 layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
             ]
             decoding_layers = [
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                 layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                layers.Dense(self.cardinality, activation='softmax')
             ]

         # Encoding Neural Network
@@ -130,7 +132,7 @@ class EndToEndAutoencoder(tf.keras.Model):
             out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))

             if return_vals:
-                #TODO
+                return out_arr, out_arr, out_arr[:, mid_idx, :]

         else:
             rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
@@ -145,7 +147,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-2):
+    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
         """
         Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.

@@ -159,8 +161,23 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         opt = tf.keras.optimizers.Adam(learning_rate=lr)

+        # TODO: Investigate different optimizers (with different learning rates and other parameters)
+        # SGD
+        # RMSprop
+        # Adam
+        # Adadelta
+        # Adagrad
+        # Adamax
+        # Nadam
+        # Ftrl
+
+        if self.bit_mapping:
+            loss_fn = losses.BinaryCrossentropy()
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
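The loss now follows the decoder's output layer: per-bit sigmoid outputs pair with binary cross-entropy, while the one-hot symbol decoder ending in softmax pairs with categorical cross-entropy. A small standalone illustration of the two pairings (toy tensors; bits_per_symbol = 5 and cardinality = 32 assumed):

    import tensorflow as tf
    from tensorflow.keras import losses

    # Bit mapping: targets are 5 independent bits, predictions are per-bit sigmoids.
    bits_true = tf.constant([[1., 0., 1., 1., 0.]])
    bits_pred = tf.constant([[0.9, 0.1, 0.8, 0.7, 0.2]])
    print(losses.BinaryCrossentropy()(bits_true, bits_pred).numpy())

    # Symbol mapping: targets are one-hot over 32 symbols, predictions are a softmax.
    sym_true = tf.one_hot([3], 32)
    sym_pred = tf.nn.softmax(tf.random.normal((1, 32)))
    print(losses.CategoricalCrossentropy()(sym_true, sym_pred).numpy())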
@@ -168,33 +185,65 @@ class EndToEndAutoencoder(tf.keras.Model):
                      )

         history = self.fit(x=X_train,
-                           y=y_train,
-                           batch_size=batch_size,
-                           epochs=1,
-                           shuffle=True,
-                           validation_data=(X_test, y_test)
-                           )
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )

         plt.plot(history.history['accuracy'])
         plt.plot(history.history['val_accuracy'])
         plt.show()


+    def test(self, num_of_blocks=1e3):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        print("SYMBOL ERROR RATE: {}".format(symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(bit_error_rate))
+
+        pass
+
     def view_encoder(self):
         '''
         A method that views the learnt encoder for each distint message. This is displayed as a plot with a subplot for
         each message/symbol.
         '''

-        # Generate inputs for encoder
-        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
         mid_idx = int((self.messages_per_block - 1) / 2)

-        idx = 0
-        for msg in messages:
-            msg[mid_idx, idx] = 1
-            idx += 1
+        if self.bit_mapping:
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
+            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx] = lst[idx]
+                idx += 1
+
+        else:
+            # Generate inputs for encoder
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx, idx] = 1
+                idx += 1

         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
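On the new test() method above: symbol errors come from comparing argmaxes of the decoder output against the one-hot targets, and bit errors from expanding both through SymbolsToBits before comparing. A toy version of the same bookkeeping in plain NumPy (4 hypothetical symbols of 2 bits each, made-up arrays):

    import numpy as np

    bit_table = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])   # symbol index -> bits

    y_true = np.array([0, 1, 2, 3])    # transmitted symbol indices
    y_pred = np.array([0, 1, 3, 3])    # decoded symbol indices (one symbol error)

    ser = np.mean(y_true != y_pred)                          # 0.25
    ber = np.mean(bit_table[y_true] != bit_table[y_pred])    # 1 wrong bit of 8 = 0.125
    print(ser, ber)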
@@ -218,7 +267,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
 
 
@@ -242,11 +291,12 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)

         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                 num_of_samples=self.messages_per_block * self.samples_per_symbol,
-                                q_stddev=0)
+                                sig_avg=0)

         # Apply LPF
         lpf_out = lpf(flat_enc)
@@ -264,6 +314,7 @@ class EndToEndAutoencoder(tf.keras.Model):
 
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
@@ -276,16 +327,15 @@ class EndToEndAutoencoder(tf.keras.Model):
         outputs = self.decoder(rx)
         return outputs

+
 SAMPLING_FREQUENCY = 336e9
 CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 24
+SAMPLES_PER_SYMBOL = 32
 MESSAGES_PER_BLOCK = 9
 DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
-
+FIBER_LENGTH = 0

 if __name__ == '__main__':
-
     optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                      num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                      dispersion_factor=DISPERSION_FACTOR,
@@ -295,11 +345,12 @@ if __name__ == '__main__':
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=optical_channel,
-                                   bit_mapping=True)
+                                   bit_mapping=False)

-    ae_model.train(num_of_blocks=1e6)
-    # ae_model.view_encoder()
-    # ae_model.view_sample_block()
+    ae_model.train(num_of_blocks=1e5, epochs=5)
+    ae_model.test()
+    ae_model.view_encoder()
+    ae_model.view_sample_block()
     # ae_model.summary()
     ae_model.encoder.summary()
     ae_model.channel.summary()

+ 120 - 0
models/new_model.py

@@ -0,0 +1,120 @@
+import tensorflow as tf
+from tensorflow.keras import layers, losses
+from models.custom_layers import ExtractCentralMessage, OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.custom_layers import BitsToSymbols, SymbolsToBits
+import numpy as np
+import math
+
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             bit_mapping=False)
+
+    def call(self, inputs, training=None, mask=None):
+        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+
+        """
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out = rand_int
+
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        """
+
+        """
+        for _ in range(iters):
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+            self.e2e_model.test()
+
+            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+    def test(self, num_of_blocks=1e4):
+        pass
+
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 24
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+
+if __name__ == '__main__':
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    # a , c = model.generate_random_inputs(num_of_blocks=1)
+    #
+    # a = tf.convert_to_tensor(a, dtype=tf.float32)
+    # b = model(a)
+
+    model.train(iters=1, num_of_blocks=1e4, epochs=1)
+    model.summary()
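A hedged usage sketch of the new wrapper, mirroring the commented-out lines above (shapes assume CARDINALITY = 32, so 5 bits per message, and MESSAGES_PER_BLOCK = 9); `model` is the BitMappingModel built in __main__:

    import tensorflow as tf

    bits_in, _ = model.generate_random_inputs(num_of_blocks=2)
    bits_in = tf.convert_to_tensor(bits_in, dtype=tf.float32)   # (2, 9, 5) random bits

    bit_probs = model(bits_in)   # BitsToSymbols -> autoencoder -> SymbolsToBits
    print(bit_probs.shape)       # expected (2, 5): per-bit values for the central message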