
Merge branch 'temp_e2e' into Standalone_NN_devel

Min, 4 years ago
Parent commit 25c4fc2a23
12 changed files with 1343 additions and 242 deletions
  1. graphs.py (+21 -0)
  2. misc.py (+9 -1)
  3. models/autoencoder.py (+32 -7)
  4. models/binary_net.py (+518 -0)
  5. models/data.py (+22 -0)
  6. models/end_to_end.py (+194 -168)
  7. models/gray_code.py (+41 -0)
  8. models/layers.py (+155 -18)
  9. models/new_model.py (+175 -0)
  10. models/quantized_net.py (+6 -4)
  11. tests/min_test.py (+103 -42)
  12. tests/misc_test.py (+67 -2)

graphs.py (+21 -0)

@@ -86,3 +86,24 @@ def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=1
     ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
     SNR = -ber_x + av_sig_pow
     return SNR, ber_y
+
+
+def show_train_history(history, title="", save=None):
+    from matplotlib import pyplot as plt
+
+    epochs = range(1, len(history.epoch) + 1)
+    if 'loss' in history.history:
+        plt.plot(epochs, history.history['loss'], label='Training Loss')
+    if 'accuracy' in history.history:
+        plt.plot(epochs, history.history['accuracy'], label='Training Accuracy')
+    if 'val_loss' in history.history:
+        plt.plot(epochs, history.history['val_loss'], label='Validation Loss')
+    if 'val_accuracy' in history.history:
+        plt.plot(epochs, history.history['val_accuracy'], label='Validation Accuracy')
+    plt.xlabel('Epochs')
+    plt.ylabel('Loss/Accuracy' if 'accuracy' in history.history else 'Loss')
+    plt.legend()
+    plt.title(title)
+    if save is not None:
+        plt.savefig(save)
+    plt.show()
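
For reference, a minimal usage sketch of show_train_history (the toy model and data here are purely illustrative; it only assumes graphs.py is importable):

    import numpy as np
    from tensorflow import keras
    from graphs import show_train_history

    model = keras.Sequential([keras.layers.Dense(1, activation='sigmoid', input_shape=(4,))])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    x = np.random.rand(256, 4)
    y = np.random.randint(2, size=(256, 1))

    # fit() returns the History object the plotting helper expects
    history = model.fit(x, y, validation_split=0.3, epochs=5)
    show_train_history(history, title="Toy run", save="toy_history.png")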

misc.py (+9 -1)

@@ -2,7 +2,7 @@
 import numpy as np
 import math
 import matplotlib.pyplot as plt
-
+import pickle
 
 def display_alphabet(alphabet, values=None, a_vals=False, title="Alphabet constellation diagram"):
     rect = polar2rect(alphabet)
@@ -105,3 +105,11 @@ def generate_random_bit_array(size):
     return arr
 
 
+def pickle_save(obj, fname):
+    with open(fname, 'wb') as f:
+        pickle.dump(obj, f)
+
+
+def pickle_load(fname):
+    with open(fname, 'rb') as f:
+        return pickle.load(f)
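
A quick round-trip sketch of the two pickle helpers, assuming misc.py is on the path (the file name is arbitrary):

    from misc import pickle_save, pickle_load

    payload = {'alphabet': [1 + 1j, -1 - 1j], 'note': 'demo'}
    pickle_save(payload, 'payload.pkl')        # writes the object with pickle.dump
    assert pickle_load('payload.pkl') == payload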

models/autoencoder.py (+32 -7)

@@ -67,7 +67,16 @@ class AutoencoderDemod(defs.Demodulator):
 
 
 class Autoencoder(Model):
-    def __init__(self, N, channel, signal_dim=2, parallel=1, all_onehot=True, bipolar=True, encoder=None, decoder=None):
+    def __init__(self, N, channel,
+                 signal_dim=2,
+                 parallel=1,
+                 all_onehot=True,
+                 bipolar=True,
+                 encoder=None,
+                 decoder=None,
+                 data_generator=None,
+                 cost=None
+                 ):
         super(Autoencoder, self).__init__()
         self.N = N
         self.parallel = parallel
@@ -116,6 +125,13 @@ class Autoencoder(Model):
                 raise ValueError("Channel is not a keras layer")
             self.channel.add(channel)
 
+        self.data_generator = data_generator
+        if data_generator is None:
+            self.data_generator = BinaryOneHotGenerator
+
+        self.cost = cost
+        if cost is None:
+            self.cost = losses.MeanSquaredError()
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
 
         # [
@@ -181,10 +197,10 @@ class Autoencoder(Model):
 
         print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
 
-    def train(self, epoch_size=3e3, epochs=5):
+    def train(self, epoch_size=3e3, epochs=5, callbacks=None, optimizer='adam', metrics=None):
         m = self.N * self.parallel
-        x_train = BinaryOneHotGenerator(size=epoch_size, shape=m)
-        x_test = BinaryOneHotGenerator(size=epoch_size*.3, shape=m)
+        x_train = self.data_generator(size=epoch_size, shape=m)
+        x_test = self.data_generator(size=epoch_size*.3, shape=m)
 
         # test_samples = epoch_size
         # if test_samples % m:
@@ -194,12 +210,21 @@ class Autoencoder(Model):
         # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
         if not self.compiled:
-            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
+            self.compile(
+                optimizer=optimizer,
+                loss=self.cost,
+                metrics=metrics
+            )
             self.compiled = True
             # self.build((self._input_shape, -1))
             # self.summary()
 
-        self.fit(x_train, shuffle=False, validation_data=x_test, epochs=epochs)
+        history = self.fit(
+            x_train, shuffle=False,
+            validation_data=x_test, epochs=epochs,
+            callbacks=callbacks,
+        )
+        return history
         # encoded_data = self.encoder(x_test_ho)
         # decoded_data = self.decoder(encoded_data).numpy()
 
@@ -215,7 +240,7 @@ class Autoencoder(Model):
 
 
 def view_encoder(encoder, N, samples=1000, title="Autoencoder generated alphabet"):
-    test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
+    test_values = misc.generate_random_bit_array(samples*N).reshape((-1, N))
     test_values_ho = misc.bit_matrix2one_hot(test_values)
     mvector = np.array([2 ** i for i in range(N)], dtype=int)
     symbols = (test_values * mvector).sum(axis=1)
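
A sketch of how the new injection points compose, assuming the AwgnChannel layer from models/layers.py and these import paths (both assumptions about this repository's layout):

    from tensorflow.keras import losses
    from models.autoencoder import Autoencoder
    from models.layers import AwgnChannel

    # data_generator defaults to BinaryOneHotGenerator and cost to MeanSquaredError;
    # here only the cost is overridden.
    ae = Autoencoder(N=2, channel=AwgnChannel(rx_stddev=0.1),
                     cost=losses.CategoricalCrossentropy())
    history = ae.train(epoch_size=3e3, epochs=5, optimizer='adam', metrics=['accuracy'])

The returned History object can be passed straight to show_train_history in graphs.py.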

models/binary_net.py (+518 -0)

@@ -0,0 +1,518 @@
+"""
+Adapted from https://github.com/uranusx86/BinaryNet-on-tensorflow
+
+"""
+
+# coding=UTF-8
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.python.framework import tensor_shape, ops
+from tensorflow.python.ops import standard_ops, nn, variable_scope, math_ops, control_flow_ops, state_ops, resource_variable_ops
+from tensorflow.python.eager import context
+from tensorflow.python.training import optimizer, training_ops
+import numpy as np
+
+# Note: a class that defines @property getters/setters must inherit from object
+
+all_layers = []
+
+
+def hard_sigmoid(x):
+    return tf.clip_by_value((x + 1.) / 2., 0, 1)
+
+
+def round_through(x):
+    """
+    Element-wise rounding to the closest integer with full gradient propagation.
+    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182):
+    an op that behaves as f(x) in the forward pass,
+    but as g(x) in the backward pass.
+    """
+    rounded = tf.round(x)
+    return x + tf.stop_gradient(rounded - x)
+
+
+# The neurons' activations binarization function
+# It behaves like the sign function during forward propagation
+# And like:
+#   hard_tanh(x) = 2*hard_sigmoid(x)-1
+# during back propagation
+def binary_tanh_unit(x):
+    return 2. * round_through(hard_sigmoid(x)) - 1.
+
+
+def binary_sigmoid_unit(x):
+    return round_through(hard_sigmoid(x))
+
+
+# The weights' binarization function,
+# taken directly from the BinaryConnect github repository
+# (which was made available by its authors)
+def binarization(W, H, binary=True, deterministic=False, stochastic=False, srng=None):
+    dim = W.get_shape().as_list()
+
+    # (deterministic == True) <-> test-time <-> inference-time
+    if not binary or (deterministic and stochastic):
+        # print("not binary")
+        Wb = W
+
+    else:
+        # [-1,1] -> [0,1]
+        # Wb = hard_sigmoid(W/H)
+        # Wb = T.clip(W/H,-1,1)
+
+        # Stochastic BinaryConnect
+        '''
+        if stochastic:
+            # print("stoch")
+            Wb = tf.cast(srng.binomial(n=1, p=Wb, size=tf.shape(Wb)), tf.float32)
+        '''
+
+        # Deterministic BinaryConnect (round to nearest)
+        # else:
+        # print("det")
+        # Wb = tf.round(Wb)
+
+        # 0 or 1 -> -1 or 1
+        # Wb = tf.where(tf.equal(Wb, 1.0), tf.ones_like(W), -tf.ones_like(W))  # cant differential
+        Wb = H * binary_tanh_unit(W / H)
+
+    return Wb
+
+
+class DenseBinaryLayer(keras.layers.Dense):
+    def __init__(self, output_dim,
+                 activation=None,
+                 use_bias=True,
+                 binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
+                 kernel_initializer=tf.glorot_normal_initializer(),
+                 bias_initializer=tf.zeros_initializer(),
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 activity_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 trainable=True,
+                 name=None,
+                 **kwargs):
+        super(DenseBinaryLayer, self).__init__(
+            units=output_dim,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=kernel_initializer,
+            bias_initializer=bias_initializer,
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer,
+            activity_regularizer=activity_regularizer,
+            kernel_constraint=kernel_constraint,
+            bias_constraint=bias_constraint,
+            trainable=trainable,
+            name=name,
+            **kwargs
+        )
+
+        self.binary = binary
+        self.stochastic = stochastic
+
+        self.H = H
+        self.W_LR_scale = W_LR_scale
+
+        all_layers.append(self)
+
+    def build(self, input_shape):
+        num_inputs = tensor_shape.TensorShape(input_shape).as_list()[-1]
+        num_units = self.units
+        print(num_units)
+
+        if self.H == "Glorot":
+            self.H = np.float32(np.sqrt(1.5 / (num_inputs + num_units)))  # weight init method
+        self.W_LR_scale = np.float32(1. / np.sqrt(1.5 / (num_inputs + num_units)))  # each layer learning rate
+        print("H = ", self.H)
+        print("LR scale = ", self.W_LR_scale)
+
+        self.kernel_initializer = tf.random_uniform_initializer(-self.H, self.H)
+        self.kernel_constraint = lambda w: tf.clip_by_value(w, -self.H, self.H)
+
+        '''
+        self.b_kernel = self.add_variable('binary_weight',
+                                    shape=[input_shape[-1], self.units],
+                                    initializer=self.kernel_initializer,
+                                    regularizer=None,
+                                    constraint=None,
+                                    dtype=self.dtype,
+                                    trainable=False)  # add_variable must execute before call build()
+        '''
+        self.b_kernel = self.add_variable('binary_weight',
+                                          shape=[input_shape[-1], self.units],
+                                          initializer=tf.random_uniform_initializer(-self.H, self.H),
+                                          regularizer=None,
+                                          constraint=None,
+                                          dtype=self.dtype,
+                                          trainable=False)
+
+        super(DenseBinaryLayer, self).build(input_shape)
+
+        # tf.add_to_collection('real', self.trainable_variables)
+        # tf.add_to_collection(self.name + '_binary', self.kernel)  # layer-wise group
+        # tf.add_to_collection('binary', self.kernel)  # global group
+
+    def call(self, inputs):
+        inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
+        shape = inputs.get_shape().as_list()
+
+        # binarization weight
+        self.b_kernel = binarization(self.kernel, self.H)
+        # r_kernel = self.kernel
+        # self.kernel = self.b_kernel
+
+        print("shape: ", len(shape))
+        if len(shape) > 2:
+            # Broadcasting is required for the inputs.
+            outputs = standard_ops.tensordot(inputs, self.b_kernel, [[len(shape) - 1], [0]])
+            # Reshape the output back to the original ndim of the input.
+            if context.in_graph_mode():
+                output_shape = shape[:-1] + [self.units]
+                outputs.set_shape(output_shape)
+        else:
+            outputs = standard_ops.matmul(inputs, self.b_kernel)
+
+        # restore weight
+        # self.kernel = r_kernel
+
+        if self.use_bias:
+            outputs = nn.bias_add(outputs, self.bias)
+        if self.activation is not None:
+            return self.activation(outputs)
+        return outputs
+
+
+# Functional interface for the DenseBinaryLayer class.
+def dense_binary(
+        inputs, units,
+        activation=None,
+        use_bias=True,
+        binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
+        kernel_initializer=tf.glorot_normal_initializer(),
+        bias_initializer=tf.zeros_initializer(),
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        activity_regularizer=None,
+        kernel_constraint=None,
+        bias_constraint=None,
+        trainable=True,
+        name=None,
+        reuse=None):
+    layer = DenseBinaryLayer(units,
+                             activation=activation,
+                             use_bias=use_bias,
+                             binary=binary, stochastic=stochastic, H=H, W_LR_scale=W_LR_scale,
+                             kernel_initializer=kernel_initializer,
+                             bias_initializer=bias_initializer,
+                             kernel_regularizer=kernel_regularizer,
+                             bias_regularizer=bias_regularizer,
+                             activity_regularizer=activity_regularizer,
+                             kernel_constraint=kernel_constraint,
+                             bias_constraint=bias_constraint,
+                             trainable=trainable,
+                             name=name,
+                             dtype=inputs.dtype.base_dtype,
+                             _scope=name,
+                             _reuse=reuse)
+    return layer.apply(inputs)
+
+
+# Not yet binarized
+class BatchNormalization(keras.layers.BatchNormalization):
+    def __init__(self,
+                 axis=-1,
+                 momentum=0.99,
+                 epsilon=1e-3,
+                 center=True,
+                 scale=True,
+                 beta_initializer=tf.zeros_initializer(),
+                 gamma_initializer=tf.ones_initializer(),
+                 moving_mean_initializer=tf.zeros_initializer(),
+                 moving_variance_initializer=tf.ones_initializer(),
+                 beta_regularizer=None,
+                 gamma_regularizer=None,
+                 beta_constraint=None,
+                 gamma_constraint=None,
+                 renorm=False,
+                 renorm_clipping=None,
+                 renorm_momentum=0.99,
+                 fused=None,
+                 trainable=True,
+                 name=None,
+                 **kwargs):
+        super(BatchNormalization, self).__init__(
+            axis=axis,
+            momentum=momentum,
+            epsilon=epsilon,
+            center=center,
+            scale=scale,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            moving_mean_initializer=moving_mean_initializer,
+            moving_variance_initializer=moving_variance_initializer,
+            beta_regularizer=beta_regularizer,
+            gamma_regularizer=gamma_regularizer,
+            beta_constraint=beta_constraint,
+            gamma_constraint=gamma_constraint,
+            renorm=renorm,
+            renorm_clipping=renorm_clipping,
+            renorm_momentum=renorm_momentum,
+            fused=fused,
+            trainable=trainable,
+            name=name,
+            **kwargs)
+        # all_layers.append(self)
+
+    def build(self, input_shape):
+        super(BatchNormalization, self).build(input_shape)
+        self.W_LR_scale = np.float32(1.)
+
+
+# Functional interface for the batch normalization layer.
+def batch_normalization(
+        inputs,
+        axis=-1,
+        momentum=0.99,
+        epsilon=1e-3,
+        center=True,
+        scale=True,
+        beta_initializer=tf.zeros_initializer(),
+        gamma_initializer=tf.ones_initializer(),
+        moving_mean_initializer=tf.zeros_initializer(),
+        moving_variance_initializer=tf.ones_initializer(),
+        beta_regularizer=None,
+        gamma_regularizer=None,
+        beta_constraint=None,
+        gamma_constraint=None,
+        training=False,
+        trainable=True,
+        name=None,
+        reuse=None,
+        renorm=False,
+        renorm_clipping=None,
+        renorm_momentum=0.99,
+        fused=None):
+    layer = BatchNormalization(
+        axis=axis,
+        momentum=momentum,
+        epsilon=epsilon,
+        center=center,
+        scale=scale,
+        beta_initializer=beta_initializer,
+        gamma_initializer=gamma_initializer,
+        moving_mean_initializer=moving_mean_initializer,
+        moving_variance_initializer=moving_variance_initializer,
+        beta_regularizer=beta_regularizer,
+        gamma_regularizer=gamma_regularizer,
+        beta_constraint=beta_constraint,
+        gamma_constraint=gamma_constraint,
+        renorm=renorm,
+        renorm_clipping=renorm_clipping,
+        renorm_momentum=renorm_momentum,
+        fused=fused,
+        trainable=trainable,
+        name=name,
+        dtype=inputs.dtype.base_dtype,
+        _reuse=reuse,
+        _scope=name
+    )
+    return layer.apply(inputs, training=training)
+
+
+class AdamOptimizer(optimizer.Optimizer):
+    """Optimizer that implements the Adam algorithm.
+    See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+    ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+    """
+
+    def __init__(self, weight_scale, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+                 use_locking=False, name="Adam"):
+        super(AdamOptimizer, self).__init__(use_locking, name)
+        self._lr = learning_rate
+        self._beta1 = beta1
+        self._beta2 = beta2
+        self._epsilon = epsilon
+
+        # BNN weight scale factor
+        self._weight_scale = weight_scale
+
+        # Tensor versions of the constructor arguments, created in _prepare().
+        self._lr_t = None
+        self._beta1_t = None
+        self._beta2_t = None
+        self._epsilon_t = None
+
+        # Variables to accumulate the powers of the beta parameters.
+        # Created in _create_slots when we know the variables to optimize.
+        self._beta1_power = None
+        self._beta2_power = None
+
+        # Created in SparseApply if needed.
+        self._updated_lr = None
+
+    def _get_beta_accumulators(self):
+        return self._beta1_power, self._beta2_power
+
+    def _non_slot_variables(self):
+        return self._get_beta_accumulators()
+
+    def _create_slots(self, var_list):
+        first_var = min(var_list, key=lambda x: x.name)
+
+        create_new = self._beta1_power is None
+        if not create_new and context.in_graph_mode():
+            create_new = (self._beta1_power.graph is not first_var.graph)
+
+        if create_new:
+            with ops.colocate_with(first_var):
+                self._beta1_power = variable_scope.variable(self._beta1,
+                                                            name="beta1_power",
+                                                            trainable=False)
+                self._beta2_power = variable_scope.variable(self._beta2,
+                                                            name="beta2_power",
+                                                            trainable=False)
+        # Create slots for the first and second moments.
+        for v in var_list:
+            self._zeros_slot(v, "m", self._name)
+            self._zeros_slot(v, "v", self._name)
+
+    def _prepare(self):
+        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
+        self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
+        self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
+        self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+
+    def _apply_dense(self, grad, var):
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+
+        # for BNN kernel
+        # the original version's weight clipping is new_w = old_w + scale*(new_w - old_w)
+        # and the Adam update is new_w = old_w - lr_t * m_t / (sqrt(v_t) + epsilon)
+        # so substituting the Adam update into the weight clipping gives
+        # new_w = old_w - (scale * lr_t * m_t) / (sqrt(v_t) + epsilon)
+        scale = self._weight_scale[var.name] / 4
+
+        return training_ops.apply_adam(
+            var, m, v,
+            math_ops.cast(self._beta1_power, var.dtype.base_dtype),
+            math_ops.cast(self._beta2_power, var.dtype.base_dtype),
+            math_ops.cast(self._lr_t * scale, var.dtype.base_dtype),
+            math_ops.cast(self._beta1_t, var.dtype.base_dtype),
+            math_ops.cast(self._beta2_t, var.dtype.base_dtype),
+            math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
+            grad, use_locking=self._use_locking).op
+
+    def _resource_apply_dense(self, grad, var):
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+
+        return training_ops.resource_apply_adam(
+            var.handle, m.handle, v.handle,
+            math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
+            math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
+            math_ops.cast(self._lr_t, grad.dtype.base_dtype),
+            math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
+            math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
+            math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
+            grad, use_locking=self._use_locking)
+
+    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
+        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
+        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+        # m_t = beta1 * m + (1 - beta1) * g_t
+        m = self.get_slot(var, "m")
+        m_scaled_g_values = grad * (1 - beta1_t)
+        m_t = state_ops.assign(m, m * beta1_t,
+                               use_locking=self._use_locking)
+        with ops.control_dependencies([m_t]):
+            m_t = scatter_add(m, indices, m_scaled_g_values)
+        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+        v = self.get_slot(var, "v")
+        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
+        with ops.control_dependencies([v_t]):
+            v_t = scatter_add(v, indices, v_scaled_g_values)
+        v_sqrt = math_ops.sqrt(v_t)
+        var_update = state_ops.assign_sub(var,
+                                          lr * m_t / (v_sqrt + epsilon_t),
+                                          use_locking=self._use_locking)
+        return control_flow_ops.group(*[var_update, m_t, v_t])
+
+    def _apply_sparse(self, grad, var):
+        return self._apply_sparse_shared(
+            grad.values, var, grad.indices,
+            lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
+                x, i, v, use_locking=self._use_locking))
+
+    def _resource_scatter_add(self, x, i, v):
+        with ops.control_dependencies(
+                [resource_variable_ops.resource_scatter_add(
+                    x.handle, i, v)]):
+            return x.value()
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        return self._apply_sparse_shared(
+            grad, var, indices, self._resource_scatter_add)
+
+    def _finish(self, update_ops, name_scope):
+        # Update the power accumulators.
+        with ops.control_dependencies(update_ops):
+            with ops.colocate_with(self._beta1_power):
+                update_beta1 = self._beta1_power.assign(
+                    self._beta1_power * self._beta1_t,
+                    use_locking=self._use_locking)
+                update_beta2 = self._beta2_power.assign(
+                    self._beta2_power * self._beta2_t,
+                    use_locking=self._use_locking)
+        return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
+                                      name=name_scope)
+
+
+def get_all_layers():
+    return all_layers
+
+
+def get_all_LR_scale():
+    return {layer.kernel.name: layer.W_LR_scale for layer in get_all_layers()}
+
+
+# This function computes the gradient of the binary weights
+def compute_grads(loss, opt):
+    layers = get_all_layers()
+    grads_list = []
+    update_weights = []
+
+    for layer in layers:
+
+        # refer to self.params[self.W]=set(['binary'])
+        # The list can optionally be filtered by specifying tags as keyword arguments.
+        # For example,
+        # ``trainable=True`` will only return trainable parameters, and
+        # ``regularizable=True`` will only return parameters that can be regularized
+        # function return, e.g. [W, b] for dense layer
+        params = tf.get_collection(layer.name + "_binary")
+        if params:
+            # print(params[0].name)
+            # theano.grad(cost, wrt) -> d(cost)/d(wrt)
+            # wrt – with respect to which we want gradients
+            # http://blog.csdn.net/shouhuxianjian/article/details/46517143
+            # http://blog.csdn.net/qq_33232071/article/details/52806630
+            # grad = opt.compute_gradients(loss, layer.b_kernel)  # origin version
+            grad = opt.compute_gradients(loss, params[0])  # modify
+            print("grad: ", grad)
+            grads_list.append(grad[0][0])
+            update_weights.extend(params)
+
+    print(grads_list)
+    print(update_weights)
+    return zip(grads_list, update_weights)
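
The round_through / binary_tanh_unit pair above is a straight-through estimator: the forward pass produces hard {-1, +1} values while gradients flow as if through hard_tanh(x) = 2*hard_sigmoid(x) - 1. A small eager-mode check with those helpers in scope (illustrative only, since the surrounding module targets the TF1 APIs):

    import tensorflow as tf

    x = tf.Variable([-0.7, -0.2, 0.3, 0.9])
    with tf.GradientTape() as tape:
        y = binary_tanh_unit(x)        # forward: [-1., -1., 1., 1.]
        loss = tf.reduce_sum(y * y)
    # The forward loss is constant in the hard values, yet the estimator
    # still passes a nonzero surrogate gradient back to x.
    print(tape.gradient(loss, x))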

models/data.py (+22 -0)

@@ -30,3 +30,25 @@ class BinaryOneHotGenerator(Sequence):
 
     def __getitem__(self, idx):
         return self.x, self.x
+
+
+class BinaryGenerator(Sequence):
+    def __init__(self, size=1e5, shape=2, dtype=tf.bool):
+        size = int(size)
+        if size % shape:
+            size += shape - (size % shape)
+        self.size = size
+        self.shape = shape
+        self.x = None
+        self.dtype = dtype
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
+        self.x = tf.convert_to_tensor(x, dtype=self.dtype)
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        return self.x, self.x
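
A usage sketch for the new generator; note that, like BinaryOneHotGenerator above, __getitem__ ignores the batch index and returns the whole epoch's tensor as an (x, x) autoencoder pair, regenerated in on_epoch_end:

    from models.data import BinaryGenerator

    gen = BinaryGenerator(size=1e4, shape=8)   # 10000 random bits, rows of 8
    x, y = gen[0]
    print(x.shape)                             # (1250, 8), dtype bool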

models/end_to_end.py (+194 -168)

@@ -1,122 +1,14 @@
+import itertools
 import math
 
-from tensorflow import keras
 import tensorflow as tf
 import numpy as np
 import matplotlib.pyplot as plt
-from matplotlib import collections as matcoll
+from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
 
-
-class ExtractCentralMessage(layers.Layer):
-    def __init__(self, messages_per_block, samples_per_symbol):
-        """
-        :param messages_per_block: Total number of messages in transmission block
-        :param samples_per_symbol: Number of samples per transmitted symbol
-        """
-        super(ExtractCentralMessage, self).__init__()
-
-        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
-        i = np.identity(samples_per_symbol)
-        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
-        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
-        temp_w[begin:end, :] = i
-
-        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
-
-    def call(self, inputs, **kwargs):
-        return tf.matmul(inputs, self.w)
-
-
-class DigitizationLayer(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 lpf_cutoff=32e9,
-                 q_stddev=0.1):
-        """
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(DigitizationLayer, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(q_stddev)
-        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
-        temp = np.ones(freq.shape)
-
-        for idx, val in np.ndenumerate(freq):
-            if np.abs(val) > lpf_cutoff:
-                temp[idx] = 0
-
-        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
-
-    def call(self, inputs, **kwargs):
-        complex_in = tf.cast(inputs, dtype=tf.complex64)
-        val_f = tf.signal.fft(complex_in)
-        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
-        filtered_t = tf.signal.ifft(filtered_f)
-        real_t = tf.cast(filtered_t, dtype=tf.float32)
-        noisy = self.noise_layer.call(real_t, training=True)
-        return noisy
-
-
-class OpticalChannel(layers.Layer):
-    def __init__(self,
-                 fs,
-                 num_of_samples,
-                 dispersion_factor,
-                 fiber_length,
-                 lpf_cutoff=32e9,
-                 rx_stddev=0.01,
-                 q_stddev=0.01):
-        """
-        :param fs: Sampling frequency of the simulation in Hz
-        :param num_of_samples: Total number of samples in the input
-        :param dispersion_factor: Dispersion factor in s^2/km
-        :param fiber_length: Length of fiber to model in km
-        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
-        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
-        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
-        """
-        super(OpticalChannel, self).__init__()
-
-        self.noise_layer = layers.GaussianNoise(rx_stddev)
-        self.digitization_layer = DigitizationLayer(fs=fs,
-                                                    num_of_samples=num_of_samples,
-                                                    lpf_cutoff=lpf_cutoff,
-                                                    q_stddev=q_stddev)
-        self.flatten_layer = layers.Flatten()
-
-        self.fs = fs
-        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
-        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
-
-    def call(self, inputs, **kwargs):
-        # DAC LPF and noise
-        dac_out = self.digitization_layer(inputs)
-
-        # Chromatic Dispersion
-        complex_val = tf.cast(dac_out, dtype=tf.complex128)
-        val_f = tf.signal.fft(complex_val)
-        disp_f = tf.math.multiply(val_f, self.multiplier)
-        disp_t = tf.signal.ifft(disp_f)
-
-        # Squared-Law Detection
-        pd_out = tf.square(tf.abs(disp_t))
-
-        # Casting back to floatx
-        real_val = tf.cast(pd_out, dtype=tf.float32)
-
-        # Adding photo-diode receiver noise
-        rx_signal = self.noise_layer.call(real_val, training=True)
-
-        # ADC LPF and noise
-        adc_out = self.digitization_layer(rx_signal)
-
-        return adc_out
+from layers import ExtractCentralMessage, BitsToSymbols, SymbolsToBits, OpticalChannel, DigitizationLayer
 
 
 class EndToEndAutoencoder(tf.keras.Model):
@@ -124,8 +16,13 @@ class EndToEndAutoencoder(tf.keras.Model):
                  cardinality,
                  samples_per_symbol,
                  messages_per_block,
-                 channel):
+                 channel,
+                 bit_mapping=False):
         """
+        The autoencoder that aims to find an encoding of the input messages. Note that a "block" consists of
+        multiple consecutive "messages"; this introduces memory into the simulation, which is essential for
+        modelling inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
+
         :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
         :param samples_per_symbol: Number of samples per transmitted symbol
         :param messages_per_block: Total number of messages in transmission block
@@ -135,99 +32,218 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Labelled M in paper
         self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol
+
         # Labelled N in paper
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
+
         # Channel Model Layer
         if isinstance(channel, layers.Layer):
             self.channel = tf.keras.Sequential([
                 layers.Flatten(),
                 channel,
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
-            ])
+            ], name="channel_model")
         else:
             raise TypeError("Channel must be a subclass of keras.layers.layer!")
 
+        # Boolean identifying if bit mapping is to be learnt
+        self.bit_mapping = bit_mapping
+
+        # other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5/channel.rx_stddev, 10)
+
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
+
+        # Layer configuration for the case when bit mapping is to be learnt
+        if self.bit_mapping:
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
+                BitsToSymbols(self.cardinality),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+            ]
+            decoding_layers = [
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                # layers.Dense(2 * self.cardinality),
+                # layers.LeakyReLU(alpha=0.01),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid')
+            ]
+
+        # layer configuration for the case when only symbol mapping is to be learnt
+        else:
+            encoding_layers = [
+                layers.Input(shape=(self.messages_per_block, self.cardinality)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+            ]
+            decoding_layers = [
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                layers.Dense(2 * self.cardinality),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                layers.Dense(self.cardinality, activation='softmax')
+            ]
+
         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
-            layers.Input(shape=(self.messages_per_block, self.cardinality)),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.samples_per_symbol),
-            layers.ReLU(max_value=1.0)
-        ])
+            *encoding_layers
+        ], name="encoding_model")
 
         # Decoding Neural Network
         self.decoder = tf.keras.Sequential([
-            layers.Dense(self.samples_per_symbol, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(2 * self.cardinality, activation='relu'),
-            layers.Dense(self.cardinality, activation='softmax')
-        ])
+            *decoding_layers
+        ], name="decoding_model")
 
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
+        Generates the random input blocks used to build the train/test data: one-hot encoded messages by
+        default, or raw bit vectors when bit mapping is being learnt.
+
         :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
         consecutively to model ISI. The central message in a block is returned as the label for training.
         :param return_vals: If true, the raw decimal values of the input sequence will be returned
         """
-        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
         cat = [np.arange(self.cardinality)]
         enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
 
-        out = enc.fit_transform(rand_int)
-        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        if self.bit_mapping:
+            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
 
-        mid_idx = int((self.messages_per_block-1)/2)
+            out = rand_int
+
+            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+            if return_vals:
+                return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        else:
+            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
 
-        if return_vals:
-            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-            return out_val, out_arr, out_arr[:, mid_idx, :]
+            out = enc.fit_transform(rand_int)
+
+            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+
+            if return_vals:
+                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+                return out_val, out_arr, out_arr[:, mid_idx, :]
 
         return out_arr, out_arr[:, mid_idx, :]
 
-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
         """
+        Trains the autoencoder. Further configuration of the loss function, optimizer, etc. can be made here.
+
+        :param epochs: Number of complete passes over the training dataset
         :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
         :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
         :param train_size: Float less than 1 representing the proportion of the dataset to use for training
         :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
         """
-        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
-        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
-
-        opt = keras.optimizers.Adam(learning_rate=lr)
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        # TODO: Investigate different optimizers (with different learning rates and other parameters)
+        # SGD
+        # RMSprop
+        # Adam
+        # Adadelta
+        # Adagrad
+        # Adamax
+        # Nadam
+        # Ftrl
+
+        if self.bit_mapping:
+            loss_fn = losses.BinaryCrossentropy()
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
 
         self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                      metrics=['accuracy'],
                      loss_weights=None,
                      weighted_metrics=None,
                      run_eagerly=False
                      )
 
-        self.fit(x=X_train,
+        history = self.fit(x=X_train,
                  y=y_train,
                  batch_size=batch_size,
-                 epochs=1,
+                 epochs=epochs,
                  shuffle=True,
                  validation_data=(X_test, y_test)
                 )
+        return history
 
+    def test(self, num_of_blocks=1e4):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
+
+        pass
+
     def view_encoder(self):
-        # Generate inputs for encoder
-        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+        '''
+        Views the learnt encoding for each distinct message, displayed as a figure with one
+        subplot per message/symbol.
+        '''
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
 
-        mid_idx = int((self.messages_per_block-1)/2)
+        if self.bit_mapping:
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
+            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
 
-        idx = 0
-        for msg in messages:
-            msg[mid_idx, idx] = 1
-            idx += 1
+            idx = 0
+            for msg in messages:
+                msg[mid_idx] = lst[idx]
+                idx += 1
+
+        else:
+            # Generate inputs for encoder
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx, idx] = 1
+                idx += 1
 
         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
@@ -235,23 +251,23 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         # Compute subplot grid layout
         i = 0
-        while 2**i < self.cardinality**0.5:
+        while 2 ** i < self.cardinality ** 0.5:
             i += 1
 
-        num_x = int(2**i)
+        num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
 
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
 
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
-            t = t/self.channel.layers[1].fs
+            t = t / self.channel.layers[1].fs
 
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                 axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                 sym_idx += 1
 
@@ -265,34 +281,40 @@ class EndToEndAutoencoder(tf.keras.Model):
         pass
 
     def view_sample_block(self):
+        '''
+        Generates a random block of input messages and encodes it. The encoded output is also passed through a
+        DigitizationLayer with no quantization noise, to isolate the effect of the low-pass filtering.
+        '''
         # Generate a random block of messages
         val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
 
         # Encode and flatten the messages
         enc = self.encoder(inp)
         flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
 
         # Instantiate LPF layer
         lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
-                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
-                                q_stddev=0)
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
+                                sig_avg=0)
 
         # Apply LPF
         lpf_out = lpf(flat_enc)
 
         # Time axis
-        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
             t = t / self.channel.layers[1].fs
 
         # Plot the concatenated symbols before and after LPF
-        plt.figure(figsize=(2*self.messages_per_block, 6))
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
 
         for i in range(1, self.messages_per_block):
-            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
 
         plt.plot(t, flat_enc.numpy().T, 'x')
         plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
         plt.ylim((0, 1))
         plt.xlim((t.min(), t.max()))
         plt.title(str(val[0, :, 0]))
@@ -306,27 +328,31 @@ class EndToEndAutoencoder(tf.keras.Model):
         return outputs
 
 
-if __name__ == '__main__':
-
-    SAMPLING_FREQUENCY = 336e9
-    CARDINALITY = 32
-    SAMPLES_PER_SYMBOL = 24
-    MESSAGES_PER_BLOCK = 9
-    DISPERSION_FACTOR = -21.7 * 1e-24
-    FIBER_LENGTH = 50
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 32
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 0
 
+if __name__ == '__main__':
     optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                      dispersion_factor=DISPERSION_FACTOR,
                                      fiber_length=FIBER_LENGTH)
 
     ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel)
+                                   channel=optical_channel,
+                                   bit_mapping=False)
 
-    ae_model.train(num_of_blocks=1e6, batch_size=100)
+    ae_model.train(num_of_blocks=1e5, epochs=5)
+    ae_model.test()
     ae_model.view_encoder()
     ae_model.view_sample_block()
-
+    # ae_model.summary()
+    ae_model.encoder.summary()
+    ae_model.channel.summary()
+    ae_model.decoder.summary()
     pass
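
A hypothetical follow-up sketch, reusing the module-level constants above to sweep fiber length and watch dispersion degrade the bit error rate (slow; illustration only):

    for length in (0, 10, 25, 50):
        channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                 num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                 dispersion_factor=DISPERSION_FACTOR,
                                 fiber_length=length)
        model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=channel)
        model.train(num_of_blocks=1e5, epochs=5)
        model.test()   # populates model.bit_error_rate / model.symbol_error_rate
        print(length, model.bit_error_rate)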

models/gray_code.py (+41 -0)

@@ -0,0 +1,41 @@
+from scipy.spatial import Delaunay, Voronoi, voronoi_plot_2d
+import matplotlib.pyplot as plt
+import numpy as np
+import basic
+
+
+def get_gray_code(n: int):
+    return n ^ (n >> 1)
+
+
+def difference(sym0: int, sym1: int):
+    return bit_count(sym0 ^ sym1)
+
+
+def bit_count(i: int):
+    """
+    Hamming weight algorithm: counts the number of set bits in a 32-bit integer
+    """
+    assert 0 <= i < 0x100000000
+    i = i - ((i >> 1) & 0x55555555)
+    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
+    return (((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) & 0xffffffff) >> 24
+
+
+def compute_optimal(points, show_graph=False):
+    # Work in progress: intended to assign Gray-coded labels to the given
+    # constellation points; the assignment itself is not implemented yet.
+    available = set(range(len(points)))
+    map = {}
+
+    vor = Voronoi(points)
+
+    if show_graph:
+        voronoi_plot_2d(vor)
+        plt.show()
+    pass
+
+
+if __name__ == '__main__':
+    a = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])
+    # a = basic.load_alphabet('16qam', polar=False)
+    compute_optimal(a, show_graph=True)
+    pass
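
By construction, consecutive Gray codes differ in exactly one bit, which difference() confirms (assuming gray_code.py is importable):

    from gray_code import get_gray_code, difference

    for n in range(7):
        g0, g1 = get_gray_code(n), get_gray_code(n + 1)
        assert difference(g0, g1) == 1          # adjacent codes differ by one bit
        print(n, format(g0, '03b'), '->', format(g1, '03b'))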

models/layers.py (+155 -18)

@@ -6,11 +6,15 @@ import itertools
 from tensorflow.keras import layers
 import tensorflow as tf
 import numpy as np
+import math
 
 
 class AwgnChannel(layers.Layer):
     def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
         """
+        An additive white Gaussian noise channel model. The GaussianNoise layer is used so that a fresh noise
+        realization is drawn on every call rather than the same noise being applied each time.
+
         :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
         """
         super(AwgnChannel, self).__init__(**kwargs)
@@ -23,6 +27,35 @@ class AwgnChannel(layers.Layer):
         return self.noise_layer.call(inputs, training=True)
 
 
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        out = tf.one_hot(idx, self.cardinality)
+        # Drop the singleton axis left by the (bits x pows) product instead of
+        # hard-coding the block shape, so any block size works.
+        return tf.squeeze(out, axis=2)
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.all_syms)
+
+
 class ScaleAndOffset(layers.Layer):
     """
     Scales and offsets a tensor
@@ -37,27 +70,131 @@ class ScaleAndOffset(layers.Layer):
         return inputs * self.scale + self.offset
 
 
-class BitsToSymbol(layers.Layer):
-    def __init__(self, cardinality, **kwargs):
-        super().__init__(**kwargs)
-        self.cardinality = cardinality
-        n = int(np.log(self.cardinality, 2))
-        self.powers = tf.convert_to_tensor(
-            np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
-            dtype=tf.float32
-        )
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A Keras layer that extracts the central message (symbol) from a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.samples_per_symbol = samples_per_symbol
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.w)
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        This layer simulates the finite bandwidth of the hardware by means of a low-pass filter. In addition,
+        artefacts caused by quantization are modelled by additive white Gaussian noise whose standard deviation
+        is derived from the effective number of bits (ENOB).
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param sig_avg: Average signal amplitude, used to scale the quantization noise
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
 
     def call(self, inputs, **kwargs):
-        idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
-        return tf.one_hot(idx, self.cardinality)
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
 
 
-class SymbolToBits(layers.Layer):
-    def __init__(self, cardinality, **kwargs):
-        super().__init__(**kwargs)
-        n = int(np.log(cardinality, 2))
-        l = [list(i) for i in itertools.product([0, 1], repeat=n)]
-        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32))
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        A channel model that simulates chromatic dispersion, non-linear (square-law) photodiode detection, the
+        finite bandwidth of the ADC/DAC, and additive white Gaussian noise in an optical communication channel.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param sig_avg: Average signal amplitude
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.rx_stddev = rx_stddev
+
+        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
+        self.digitization_layer = DigitizationLayer(
+            fs=fs,
+            num_of_samples=num_of_samples,
+            lpf_cutoff=lpf_cutoff,
+            sig_avg=sig_avg,
+            enob=enob
+        )
+        self.flatten_layer = layers.Flatten()
+
+        self.fs = fs
+        self.freq = tf.convert_to_tensor(
+            np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
+        self.multiplier = tf.math.exp(
+            0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))
 
     def call(self, inputs, **kwargs):
-        return tf.matmul(self.all_syms, inputs)
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_val)
+        disp_f = tf.math.multiply(val_f, self.multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out

+ 175 - 0
models/new_model.py

@@ -0,0 +1,175 @@
+import tensorflow as tf
+from tensorflow.keras import losses
+from layers import OpticalChannel, BitsToSymbols, SymbolsToBits
+from end_to_end import EndToEndAutoencoder
+import numpy as np
+import math
+
+from matplotlib import pyplot as plt
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log2(self.cardinality))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper; forced to be odd so the block has a well-defined centre message
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             bit_mapping=False)
+
+        self.bit_error_rate = []
+        self.symbol_error_rate = []
+
+    def call(self, inputs, training=None, mask=None):
+        # The bit/symbol mapping layers are stateless (no trainable weights),
+        # so constructing them per call is safe
+        x1 = BitsToSymbols(self.cardinality)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+
+        """
+
+        mid_idx = (self.messages_per_block - 1) // 2  # index of the centre message
+
+        rand_bits = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out_arr = np.reshape(rand_bits, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.BinaryCrossentropy(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        for _ in range(iters):
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+            # The supervised bit-mapping pass is identical to train(), so reuse it
+            self.train(epochs=epochs, num_of_blocks=num_of_blocks, batch_size=batch_size,
+                       train_size=train_size, lr=lr)
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 32
+SAMPLES_PER_SYMBOL = 32
+MESSAGES_PER_BLOCK = 9
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+
+if __name__ == '__main__':
+
+    distances = [0, 10, 20, 30, 40, 50, 60]
+    ser = []
+    ber = []
+
+    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)  # symbol rate in GBd
+    bit_rate = math.log2(CARDINALITY) * baud_rate  # bit rate in Gbps
+    snr = None
+
+    for d in distances:
+
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=d)
+
+        model = BitMappingModel(cardinality=CARDINALITY,
+                                samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                messages_per_block=MESSAGES_PER_BLOCK,
+                                channel=optical_channel)
+
+        if snr is None:
+            snr = model.e2e_model.snr
+        elif snr != model.e2e_model.snr:
+            print("WARNING: SNR changed between fiber lengths; results are not directly comparable")
+
+        # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+
+        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
+
+        ber.append(model.bit_error_rate[-1])
+        ser.append(model.symbol_error_rate[-1])
+
+        # plt.plot(model.bit_error_rate, label='BER')
+        # plt.plot(model.symbol_error_rate, label='SER')
+        # plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        # plt.legend()
+        # plt.show()
+        # model.summary()
+
+    # plt.plot(ber, label='BER')
+    # plt.plot(ser, label='SER')
+    # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    # plt.legend()

+ 6 - 4
models/quantized_net.py

@@ -208,6 +208,7 @@ class QuantizedNeuralNetwork:
         else:
             # Define functions which will give you the output of the previous hidden layer
             # for both networks.
             prev_trained_output = Kfunction(
                 [self.trained_net_layers[0].input],
                 [self.trained_net_layers[layer_idx - 1].output],
@@ -219,7 +220,8 @@ class QuantizedNeuralNetwork:
             input_layer = self.trained_net_layers[0]
             input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
             batch = zeros((self.batch_size, *input_shape))
-
             # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
             # to reconsider how much memory you preallocate for batch, wX, and qX.
             feed_foward_batch_size = 500
@@ -275,12 +277,12 @@ class QuantizedNeuralNetwork:
         rad = self.alphabet_scalar * median(abs(W.flatten()))
 
         for neuron_idx in range(N_ell_plus_1):
-            self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
+            # self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
             tic = time()
             qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
             Q[:, neuron_idx] = qNeuron.q
 
-            self._log(f"\tdone. {time() - tic :.2f} seconds.")
+            # self._log(f"\tdone. {time() - tic :.2f} seconds.")
 
             self._update_weights(layer_idx, Q)
 
@@ -296,4 +298,4 @@ class QuantizedNeuralNetwork:
                 # Only quantize dense layers.
                 self._log(f"Quantizing layer {layer_idx}...")
                 self._quantize_layer(layer_idx)
-                self._log(f"done. {layer_idx}...")
+                self._log(f"done #{layer_idx} {layer}...")

+ 103 - 42
tests/min_test.py

@@ -8,9 +8,10 @@ from itertools import chain
 from sys import stdout
 
 from tensorflow.python.framework.errors_impl import NotFoundError
+from tensorflow.keras import backend as K
 
 import defs
-from graphs import get_SNR, get_AWGN_ber
+from graphs import get_SNR, get_AWGN_ber, show_train_history
 from models import basic
 from models.autoencoder import Autoencoder, view_encoder
 import matplotlib.pyplot as plt
@@ -19,13 +20,12 @@ import misc
 import numpy as np
 
 from models.basic import AlphabetDemod, AlphabetMod
+from models.data import BinaryGenerator
+from models.layers import BitsToSymbols, SymbolsToBits
 from models.optical_channel import OpticalChannel
 from models.quantized_net import QuantizedNeuralNetwork
 
 
-
-
-
 def _test_optics_autoencoder():
     ch = OpticalChannel(
         noise_level=-10,
@@ -71,7 +71,6 @@ def _test_optics_autoencoder():
     plt.show()
 
 
-
 def _test_autoencoder_pretrain():
     # aenc = Autoencoder(4, -25)
     # aenc.train(samples=1e6)
@@ -232,9 +231,10 @@ def _test_autoencoder_perf():
     aenc.train(epoch_size=1e3, epochs=10)
     # #
     m = aenc.N * aenc.parallel
-    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100*m).reshape((-1, m)))
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
     x_train_enc = aenc.encoder(x_train)
     x_train = tf.cast(x_train, tf.float32)
+
     #
     # plt.plot(*get_SNR(
     #     aenc.get_modulator(),
@@ -261,6 +261,7 @@ def _test_autoencoder_perf():
             def representative_data_gen():
                 for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
                     yield [input_value]
+
             converter.representative_dataset = representative_data_gen
         tflite_model = converter.convert()
         tflite_models_dir = pathlib.Path("/tmp/tflite/")
@@ -394,23 +395,29 @@ def _test_autoencoder_perf():
 def _test_autoencoder_perf2():
     aenc = Autoencoder(2, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='2Bit AE')
 
     aenc = Autoencoder(3, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='3Bit AE')
 
     aenc = Autoencoder(4, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='4Bit AE')
 
     aenc = Autoencoder(5, -20)
     aenc.train(samples=3e6)
-    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='5Bit AE')
 
     for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
         try:
-            plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15,), '-', label=a.upper())
+            plt.plot(
+                *get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                         start=-5, stop=15, ), '-', label=a.upper())
         except KeyboardInterrupt:
             break
         except Exception:
@@ -455,8 +462,8 @@ def _test_autoencoder_perf_qnn():
     view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
 
     batch_size = 25000
-    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size*m).reshape((-1, m)))
-    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000*m).reshape((-1, m)))
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
+    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))
     bits = [np.log2(i) for i in (32,)][0]
     alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
     num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
@@ -479,45 +486,99 @@ def _test_autoencoder_perf_qnn():
     qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
     view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
 
+    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
+    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
 
 
+class BitAwareAutoencoder(Autoencoder):
+    def __init__(self, N, channel, **kwargs):
+        # self.cost here resolves to the method defined below, so the base class uses it as the loss
+        super().__init__(N, channel, **kwargs, cost=self.cost)
+        self.BITS = 2 ** N - 1  # maximum symbol index, used to normalise argmax outputs to [0, 1]
+        #  data_generator=BinaryGenerator,
+        # self.b2s_layer = BitsToSymbols(2**N)
+        # self.s2b_layer = SymbolsToBits(2**N)
+
+    def cost(self, y_true, y_pred):
+        y = tf.cast(y_true, dtype=tf.float32)
+        # Normalised symbol indices along the class axis; note argmax carries no gradient,
+        # so this term only adds a penalty to the reported loss value
+        z0 = tf.math.argmax(y, axis=-1) / self.BITS
+        z1 = tf.math.argmax(y_pred, axis=-1) / self.BITS
+
+        error0 = y - y_pred
+        sqr_error0 = K.square(error0)  # squared error on the one-hot vectors
+        mean_sqr_error0 = K.mean(sqr_error0)  # mean squared error
+        sme0 = K.sqrt(mean_sqr_error0)  # root-mean-square error
+
+        error1 = z0 - z1
+        sqr_error1 = K.square(error1)
+        mean_sqr_error1 = K.mean(sqr_error1)
+        sme1 = K.sqrt(mean_sqr_error1)  # RMSE between the normalised symbol indices
+        return sme0 + tf.cast(sme1 * 300, dtype=tf.float32)  # heavily weight the symbol-index term
+
+    # def call(self, x, **kwargs):
+    #     x1 = self.b2s_layer(x)
+    #     y = self.encoder(x1)
+    #     z = self.channel(y)
+    #     z1 = self.decoder(z)
+    #     return self.s2b_layer(z1)
+
+
+def _bit_aware_test():
+    aenc = BitAwareAutoencoder(6, -50, bipolar=True)
 
-    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
+    try:
+        aenc.load_weights('ae_bitaware')
+    except NotFoundError:
+        pass
 
+    # try:
+    #     hist = aenc.train(
+    #         epochs=70,
+    #         epoch_size=1e3,
+    #         optimizer='adam',
+    #         # metrics=[tf.keras.metrics.Accuracy()],
+    #         # callbacks=[tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, min_delta=0.001)]
+    #     )
+    #     show_train_history(hist, "Autoencoder training history")
+    # except KeyboardInterrupt:
+    #     aenc.save_weights('ae_bitaware')
+    #     exit(0)
+    #
+    # aenc.save_weights('ae_bitaware')
 
-    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
-    pass
+    view_encoder(aenc.encoder, 6, title="6-bit autoencoder alphabet")
 
-if __name__ == '__main__':
+    print("Computing BER/SNR for autoencoder")
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=1000000, steps=40,
+        start=-5, stop=15), '-', label='6Bit AE')
 
+    print("Computing BER/SNR for QAM16")
+    plt.plot(*get_SNR(
+        AlphabetMod('64qam', 10e6),
+        AlphabetDemod('64qam', 10e6),
+        ber_func=get_AWGN_ber,
+        samples=1000000,
+        steps=40,
+        start=-5,
+        stop=15,
+    ), '-', label='16qam AWGN')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.ylabel('BER')
+    plt.title("16QAM vs autoencoder")
+    plt.show()
 
-    # plt.plot(*get_SNR(
-    #     AlphabetMod('16qam', 10e6),
-    #     AlphabetDemod('16qam', 10e6),
-    #     ber_func=get_AWGN_ber,
-    #     samples=100000,
-    #     steps=50,
-    #     start=-5,
-    #     stop=15,
-    # ), '-', label='16qam AWGN')
-    #
-    # plt.plot(*get_SNR(
-    #     AlphabetMod('16qam', 10e6),
-    #     AlphabetDemod('16qam', 10e6),
-    #     samples=100000,
-    #     steps=50,
-    #     start=-5,
-    #     stop=15,
-    # ), '-', label='16qam OPTICAL')
-    #
-    # plt.yscale('log')
-    # plt.grid()
-    # plt.xlabel('SNR dB')
-    # plt.show()
+
+if __name__ == '__main__':
+    _bit_aware_test()
 
     # _test_autoencoder_perf()
-    _test_autoencoder_perf_qnn()
+    # _test_autoencoder_perf_qnn()
     # _test_autoencoder_perf2()
     # _test_autoencoder_pretrain()
     # _test_optics_autoencoder()
-

+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+from models.layers import OpticalChannel
+from matplotlib import pyplot as plt
 
 
 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@ def test_bit_matrix_one_hot():
 
 
 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     enob=16)  # high ENOB so quantization noise is negligible
+
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()