Bläddra i källkod

Merge remote-tracking branch 'origin/master' into Standalone_NN_devel

Min 4 år sedan
förälder
incheckning
63ed86d5e6
14 ändrade filer med 1736 tillägg och 235 borttagningar
  1. 8 1
      .gitignore
  2. 5 0
      alphabets/4pam.a
  3. 33 6
      defs.py
  4. 88 0
      graphs.py
  5. 51 53
      main.py
  6. 3 0
      misc.py
  7. 183 75
      models/autoencoder.py
  8. 51 85
      models/basic.py
  9. 32 0
      models/data.py
  10. 332 0
      models/end_to_end.py
  11. 63 0
      models/layers.py
  12. 65 15
      models/optical_channel.py
  13. 299 0
      models/quantized_net.py
  14. 523 0
      tests/min_test.py

+ 8 - 1
.gitignore

@@ -4,4 +4,11 @@ __pycache__
 *.pyo
 
 # Environments
-venv/
+venv/
+
+# Anything else
+*.log
+checkpoint
+*.index
+*.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]
+other/

+ 5 - 0
alphabets/4pam.a

@@ -0,0 +1,5 @@
+R
+00,0,0
+01,0.25,0
+10,0.5,0
+11,1,0

+ 33 - 6
defs.py

@@ -1,5 +1,30 @@
 import math
 import numpy as np
+import tensorflow as tf
+
+class Signal:
+
+    @property
+    def rect_x(self) -> np.ndarray:
+        return self.rect[:, 0]
+
+    @property
+    def rect_y(self) -> np.ndarray:
+        return self.rect[:, 1]
+
+    @property
+    def rect(self) -> np.ndarray:
+        raise NotImplemented("Not implemented")
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        raise NotImplemented("Not implemented")
+
+    def set_rect(self, mat: np.ndarray):
+        raise NotImplemented("Not implemented")
+
+    @property
+    def apf(self):
+        raise NotImplemented("Not implemented")
 
 
 class COMComponent:
@@ -13,12 +38,14 @@ class Channel(COMComponent):
     This model is just empty therefore just bypasses any input to output
     """
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> Signal:
+        raise NotImplemented("Need to define forward function")
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
         """
-        :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
-        :return: affected tuple of (amplitude, phase, frequency)
+        Forward operation optimised for tensorflow tensors
         """
-        raise NotImplemented("Need to define forward function")
+        raise NotImplemented("Need to define forward_tensor function")
 
 
 class ModComponent(COMComponent):
@@ -30,7 +57,7 @@ class ModComponent(COMComponent):
 
 class Modulator(ModComponent):
 
-    def forward(self, binary: np.ndarray) -> np.ndarray:
+    def forward(self, binary: np.ndarray) -> Signal:
         """
         :param binary: raw bytes as input (most be dtype=bool)
         :return: amplitude, phase, frequency
@@ -40,7 +67,7 @@ class Modulator(ModComponent):
 
 class Demodulator(ModComponent):
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> np.ndarray:
         """
         :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
         :return: binary resulting values (dtype=bool)

+ 88 - 0
graphs.py

@@ -0,0 +1,88 @@
+import math
+import os
+from multiprocessing import Pool
+
+from sklearn.metrics import accuracy_score
+
+from defs import Modulator, Demodulator, Channel
+from models.basic import AWGNChannel
+from misc import generate_random_bit_array
+from models.optical_channel import OpticalChannel
+import matplotlib.pyplot as plt
+import numpy as np
+
+CPU_COUNT = os.environ.get("CPU_COUNT", os.cpu_count())
+
+
+def show_constellation(mod: Modulator, chan: Channel, demod: Demodulator, samples=1000):
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+
+    plt.plot(x_chan.rect_x[x], x_chan.rect_y[x], '+')
+    plt.plot(x_chan.rect_x[:, 0][~x], x_chan.rect_y[:, 1][~x], '+')
+    plt.plot(x_mod.rect_x[:, 0], x_mod.rect_y[:, 1], 'ro')
+    axes = plt.gca()
+    axes.set_xlim([-2, +2])
+    axes.set_ylim([-2, +2])
+    plt.grid()
+    plt.show()
+    print('Accuracy : ' + str())
+
+
+def get_ber(mod, chan, demod, samples=1000):
+    if samples % mod.N:
+        samples += mod.N - (samples % mod.N)
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+    return 1 - accuracy_score(x, x_demod)
+
+
+def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    for noise in ber_x:
+        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
+    return ber_x, ber_y
+
+
+def __calc_ber(packed):
+    # This function has to be outside get_Optical_ber in order to be pickled by pool
+    mod, demod, noise, length, pulse_shape, samples = packed
+    tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                length=length, pulse_shape=pulse_shape, sqrt_out=True)
+    return get_ber(mod, tx_channel, demod, samples=samples)
+
+
+def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
+    with Pool(CPU_COUNT) as pool:
+        packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
+        for i, ber in enumerate(pool.imap(__calc_ber, packed_args)):
+            ber_y.append(ber)
+            i += 1  # just offset by 1
+            print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i * 100 / len(ber_x):6.2f}%)", end='')
+    print()
+    return ber_x, ber_y
+
+
+def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
+    """
+    SNR for optics and RF should be calculated the same, that is A^2
+    Because P∝V² and P∝I²
+    """
+    x_mod = mod.forward(generate_random_bit_array(samples * mod.N))
+    sig_power = [A ** 2 for A in x_mod.amplitude]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+
+    noise_start = -start + av_sig_pow
+    noise_stop = -stop + av_sig_pow
+    ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
+    SNR = -ber_x + av_sig_pow
+    return SNR, ber_y

+ 51 - 53
main.py

@@ -1,48 +1,7 @@
 import matplotlib.pyplot as plt
 
-import numpy as np
-from sklearn.metrics import accuracy_score
-from models import basic
+import graphs
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
-import misc
-from models.autoencoder import Autoencoder, view_encoder
-
-
-def show_constellation(mod, chan, demod, samples=1000):
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-
-    x_mod_rect = misc.polar2rect(x_mod)
-    x_chan_rect = misc.polar2rect(x_chan)
-    plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
-    plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
-    plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
-    axes = plt.gca()
-    axes.set_xlim([-2, +2])
-    axes.set_ylim([-2, +2])
-    plt.grid()
-    plt.show()
-    print('Accuracy : ' + str())
-
-
-def get_ber(mod, chan, demod, samples=1000):
-    if samples % mod.N:
-        samples += mod.N - (samples % mod.N)
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-    return 1 - accuracy_score(x, x_demod)
-
-
-def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
-    ber_x = np.linspace(start, stop, steps)
-    ber_y = []
-    for noise in ber_x:
-        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
-    return ber_x, ber_y
 
 
 if __name__ == '__main__':
@@ -110,21 +69,60 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=20e3,
-            steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    # for l in np.logspace(start=0, stop=3, num=6):
+    #     plt.plot(*misc.get_SNR(
+    #         AlphabetMod('4pam', 10e6),
+    #         AlphabetDemod('4pam', 10e6),
+    #         samples=2000,
+    #         steps=200,
+    #         start=-5,
+    #         stop=20,
+    #         length=l,
+    #         pulse_shape='rcos'
+    #     ), '-', label=(str(int(l))+'km'))
+    #
+    # plt.yscale('log')
+    # # plt.gca().invert_xaxis()
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # # plt.ylabel('BER')
+    # plt.title("BER against Fiber length")
+    # plt.legend()
+    # plt.show()
+
+    for ps in ['rect', 'rcos', 'rrcos']:
+        plt.plot(*graphs.get_SNR(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=30000,
+            steps=100,
+            start=-5,
+            stop=20,
+            length=1,
+            pulse_shape=ps
+        ), '-', label=ps)
 
     plt.yscale('log')
-    plt.gca().invert_xaxis()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.show()
 
-    pass
+    pass

+ 3 - 0
misc.py

@@ -1,3 +1,4 @@
+
 import numpy as np
 import math
 import matplotlib.pyplot as plt
@@ -102,3 +103,5 @@ def generate_random_bit_array(size):
     arr = np.concatenate(p)
     np.random.shuffle(arr)
     return arr
+
+

+ 183 - 75
models/autoencoder.py

@@ -3,11 +3,19 @@ import numpy as np
 import tensorflow as tf
 
 from sklearn.metrics import accuracy_score
+from sklearn.model_selection import train_test_split
 from tensorflow.keras import layers, losses
 from tensorflow.keras.models import Model
-from functools import partial
+from tensorflow.python.keras.layers import LeakyReLU, ReLU
+
+# from functools import partial
 import misc
 import defs
+from models import basic
+import os
+# from tensorflow_model_optimization.python.core.quantization.keras import quantize, quantize_aware_activation
+from models.data import BinaryOneHotGenerator
+from models import layers as custom_layers
 
 latent_dim = 64
 
@@ -15,61 +23,98 @@ print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GP
 
 
 class AutoencoderMod(defs.Modulator):
-    def __init__(self, autoencoder):
-        super().__init__(2**autoencoder.N)
+    def __init__(self, autoencoder, encoder=None):
+        super().__init__(2 ** autoencoder.N)
         self.autoencoder = autoencoder
+        self.encoder = encoder or autoencoder.encoder
 
-    def forward(self, binary: np.ndarray) -> np.ndarray:
-        reshaped = binary.reshape((-1, self.N))
+    def forward(self, binary: np.ndarray):
+        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
         reshaped_ho = misc.bit_matrix2one_hot(reshaped)
-        encoded = self.autoencoder.encoder(reshaped_ho)
+        encoded = self.encoder(reshaped_ho)
         x = encoded.numpy()
-        x2 = x * 2 - 1
+        if self.autoencoder.bipolar:
+            x = x * 2 - 1
+
+        if self.autoencoder.parallel > 1:
+            x = x.reshape((-1, self.autoencoder.signal_dim))
 
-        f = np.zeros(x2.shape[0])
-        x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
-        return x3
+        f = np.zeros(x.shape[0])
+        if self.autoencoder.signal_dim <= 1:
+            p = np.zeros(x.shape[0])
+        else:
+            p = x[:, 1]
+        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
+        return basic.RFSignal(x3)
 
 
 class AutoencoderDemod(defs.Demodulator):
-    def __init__(self, autoencoder):
-        super().__init__(2**autoencoder.N)
+    def __init__(self, autoencoder, decoder=None):
+        super().__init__(2 ** autoencoder.N)
         self.autoencoder = autoencoder
-
-    def forward(self, values: np.ndarray) -> np.ndarray:
-        rect = misc.polar2rect(values[:, [0, 1]])
-        decoded = self.autoencoder.decoder(rect).numpy()
-        result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
+        self.decoder = decoder or autoencoder.decoder
+
+    def forward(self, values: defs.Signal) -> np.ndarray:
+        if self.autoencoder.signal_dim <= 1:
+            val = values.rect_x
+        else:
+            val = values.rect
+        if self.autoencoder.parallel > 1:
+            val = val.reshape((-1, self.autoencoder.parallel))
+        decoded = self.decoder(val).numpy()
+        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
         return result.reshape(-1, )
 
 
 class Autoencoder(Model):
-    def __init__(self, N, noise):
+    def __init__(self, N, channel, signal_dim=2, parallel=1, all_onehot=True, bipolar=True, encoder=None, decoder=None):
         super(Autoencoder, self).__init__()
         self.N = N
-        self.encoder = tf.keras.Sequential()
-        self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        # self.encoder.add(layers.Dropout(0.2))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.encoder.add(layers.Dense(units=2, activation="sigmoid"))
-        # self.encoder.add(layers.ReLU(max_value=1.0))
-
-        self.decoder = tf.keras.Sequential()
-        self.decoder.add(tf.keras.Input(shape=(2,)))
-        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
-        # self.decoder.add(layers.Dropout(0.2))
-        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
-
+        self.parallel = parallel
+        self.signal_dim = signal_dim
+        self.bipolar = bipolar
+        self._input_shape = 2 ** (N * parallel) if all_onehot else (2 ** N) * parallel
+        if encoder is None:
+            self.encoder = tf.keras.Sequential()
+            self.encoder.add(layers.Input(shape=(self._input_shape,)))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.encoder.add(layers.Dropout(0.2))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            self.encoder.add(layers.Dense(units=signal_dim * parallel, activation="sigmoid"))
+            # self.encoder.add(layers.ReLU(max_value=1.0))
+            # self.encoder = quantize.quantize_model(self.encoder)
+        else:
+            self.encoder = encoder
+
+        if decoder is None:
+            self.decoder = tf.keras.Sequential()
+            self.decoder.add(tf.keras.Input(shape=(signal_dim * parallel,)))
+            self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # leaky relu with alpha=1 gives by far best results
+            self.decoder.add(LeakyReLU(alpha=1))
+            self.decoder.add(layers.Dense(units=self._input_shape, activation="softmax"))
+        else:
+            self.decoder = decoder
         # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
 
         self.mod = None
         self.demod = None
         self.compiled = False
 
-        # Divide by 2 because encoder outputs values between 0 and 1 instead of -1 and 1
-        self.noise = 10 ** (noise / 10)  # / 2
+        self.channel = tf.keras.Sequential()
+        if self.bipolar:
+            self.channel.add(custom_layers.ScaleAndOffset(2, -1, input_shape=(signal_dim * parallel,)))
+
+        if isinstance(channel, int) or isinstance(channel, float):
+            self.channel.add(custom_layers.AwgnChannel(noise_dB=channel, input_shape=(signal_dim * parallel,)))
+        else:
+            if not isinstance(channel, tf.keras.layers.Layer):
+                raise ValueError("Channel is not a keras layer")
+            self.channel.add(channel)
 
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
 
@@ -83,35 +128,78 @@ class Autoencoder(Model):
         #     layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
         #     layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
         # ])
+    @property
+    def all_layers(self):
+        return self.encoder.layers + self.decoder.layers #self.channel.layers +
 
     def call(self, x, **kwargs):
-        encoded = self.encoder(x)
-        encoded = encoded * 2 - 1
-        # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
-        # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
-        noise = np.random.normal(0, 1, (1, 2)) * self.noise
-        noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
-        decoded = self.decoder(encoded + noisy)
-        return decoded
-
-    def train(self, samples=1e6):
-        if samples % self.N:
-            samples += self.N - (samples % self.N)
-        x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
-        x_train_ho = misc.bit_matrix2one_hot(x_train)
-
-        test_samples = samples * 0.3
-        if test_samples % self.N:
-            test_samples += self.N - (test_samples % self.N)
-        x_test_array = misc.generate_random_bit_array(test_samples)
-        x_test = x_test_array.reshape((-1, self.N))
-        x_test_ho = misc.bit_matrix2one_hot(x_test)
+        y = self.encoder(x)
+        z = self.channel(y)
+        return self.decoder(z)
+
+    def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
+        alphabet = basic.load_alphabet(modulation, polar=False)
+
+        if not alphabet.shape[0] == self.N ** 2:
+            raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
+
+        x_train = np.random.randint(self.N ** 2, size=int(sample_size * train_size))
+        y_train = alphabet[x_train]
+        x_train_ho = np.zeros((int(sample_size * train_size), self.N ** 2))
+        for idx, x in np.ndenumerate(x_train):
+            x_train_ho[idx, x] = 1
+
+        x_test = np.random.randint(self.N ** 2, size=int(sample_size * (1 - train_size)))
+        y_test = alphabet[x_test]
+        x_test_ho = np.zeros((int(sample_size * (1 - train_size)), self.N ** 2))
+        for idx, x in np.ndenumerate(x_test):
+            x_test_ho[idx, x] = 1
+
+        self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.encoder.fit(x_train_ho, y_train,
+                         epochs=epochs,
+                         batch_size=batch_size,
+                         shuffle=shuffle,
+                         validation_data=(x_test_ho, y_test))
+
+    def fit_decoder(self, modulation, samples):
+        samples = int(samples * 1.3)
+        demod = basic.AlphabetDemod(modulation, 0)
+        x = np.random.rand(samples, 2) * 2 - 1
+        x = x.reshape((-1, 2))
+        f = np.zeros(x.shape[0])
+        xf = np.c_[x[:, 0], x[:, 1], f]
+        y = demod.forward(basic.RFSignal(misc.rect2polar(xf)))
+        y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
+
+        X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
+        self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
+        y_pred = self.decoder(X_test).numpy()
+        y_pred2 = np.zeros(y_test.shape, dtype=bool)
+        y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
+
+        print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
+
+    def train(self, epoch_size=3e3, epochs=5):
+        m = self.N * self.parallel
+        x_train = BinaryOneHotGenerator(size=epoch_size, shape=m)
+        x_test = BinaryOneHotGenerator(size=epoch_size*.3, shape=m)
+
+        # test_samples = epoch_size
+        # if test_samples % m:
+        #     test_samples += m - (test_samples % m)
+        # x_test_array = misc.generate_random_bit_array(test_samples)
+        # x_test = x_test_array.reshape((-1, m))
+        # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
         if not self.compiled:
             self.compile(optimizer='adam', loss=losses.MeanSquaredError())
             self.compiled = True
+            # self.build((self._input_shape, -1))
+            # self.summary()
 
-        self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
+        self.fit(x_train, shuffle=False, validation_data=x_test, epochs=epochs)
         # encoded_data = self.encoder(x_test_ho)
         # decoded_data = self.decoder(encoded_data).numpy()
 
@@ -126,12 +214,14 @@ class Autoencoder(Model):
         return self.demod
 
 
-def view_encoder(encoder, N, samples=1000):
+def view_encoder(encoder, N, samples=1000, title="Autoencoder generated alphabet"):
     test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
     test_values_ho = misc.bit_matrix2one_hot(test_values)
     mvector = np.array([2 ** i for i in range(N)], dtype=int)
     symbols = (test_values * mvector).sum(axis=1)
     encoded = encoder(test_values_ho).numpy()
+    if encoded.shape[1] == 1:
+        encoded = np.c_[encoded, np.zeros(encoded.shape[0])]
     # encoded = misc.polar2rect(encoded)
     for i in range(2 ** N):
         xy = encoded[symbols == i]
@@ -139,7 +229,7 @@ def view_encoder(encoder, N, samples=1000):
         plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
     plt.xlabel('Real')
     plt.ylabel('Imaginary')
-    plt.title("Autoencoder generated alphabet")
+    plt.title(title)
     # plt.legend()
     plt.show()
 
@@ -157,26 +247,44 @@ if __name__ == '__main__':
 
     n = 4
 
-    samples = 1e6
-    x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
-    x_train_ho = misc.bit_matrix2one_hot(x_train)
-    x_test_array = misc.generate_random_bit_array(samples * 0.3)
-    x_test = x_test_array.reshape((-1, n))
-    x_test_ho = misc.bit_matrix2one_hot(x_test)
+    # samples = 1e6
+    # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
+    # x_train_ho = misc.bit_matrix2one_hot(x_train)
+    # x_test_array = misc.generate_random_bit_array(samples * 0.3)
+    # x_test = x_test_array.reshape((-1, n))
+    # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
-    autoencoder = Autoencoder(n, -8)
+    autoencoder = Autoencoder(n, -15)
     autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
 
-    autoencoder.fit(x_train_ho, x_train_ho,
-                    epochs=1,
-                    shuffle=False,
-                    validation_data=(x_test_ho, x_test_ho))
-
-    encoded_data = autoencoder.encoder(x_test_ho)
-    decoded_data = autoencoder.decoder(encoded_data).numpy()
+    # autoencoder.fit_encoder(modulation='16qam',
+    #                         sample_size=2e6,
+    #                         train_size=0.8,
+    #                         epochs=1,
+    #                         batch_size=256,
+    #                         shuffle=True)
 
-    result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
-    print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
+    # view_encoder(autoencoder.encoder, n)
+    # autoencoder.fit_decoder(modulation='16qam', samples=2e6)
+    autoencoder.train()
     view_encoder(autoencoder.encoder, n)
 
+    # view_encoder(autoencoder.encoder, n)
+    # view_encoder(autoencoder.encoder, n)
+
+
+    # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
+    #
+    # autoencoder.fit(x_train_ho, x_train_ho,
+    #                 epochs=1,
+    #                 shuffle=False,
+    #                 validation_data=(x_test_ho, x_test_ho))
+    #
+    # encoded_data = autoencoder.encoder(x_test_ho)
+    # decoded_data = autoencoder.decoder(encoded_data).numpy()
+    #
+    # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
+    # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
+    # view_encoder(autoencoder.encoder, n)
+
     pass

+ 51 - 85
models/basic.py

@@ -4,78 +4,15 @@ import math
 import misc
 from scipy.spatial import cKDTree
 from os import path
+import tensorflow as tf
 
-ALPHABET_DIR = "./alphabets"
-# def _make_gray(n):
-#     if n <= 0:
-#         return []
-#     arr = ['0', '1']
-#     i = 2
-#     while True:
-#         if i >= 1 << n:
-#             break
-#         for j in range(i - 1, -1, -1):
-#             arr.append(arr[j])
-#         for j in range(i):
-#             arr[j] = "0" + arr[j]
-#         for j in range(i, 2 * i):
-#             arr[j] = "1" + arr[j]
-#         i = i << 1
-#     return list(map(lambda x: int(x, 2), arr))
-#
-#
-# def _gen_mary_alphabet(size, gray=True, polar=True):
-#     alphabet = np.zeros((size, 2))
-#     N = math.ceil(math.sqrt(size))
-#
-#     # if sqrt(size) != size^2 (not a perfect square),
-#     # skip defines how many corners to cut off.
-#     skip = 0
-#     if N ** 2 > size:
-#         skip = int(math.sqrt((N ** 2 - size) // 4))
-#
-#     step = 2 / (N - 1)
-#     skipped = 0
-#     for x in range(N):
-#         for y in range(N):
-#             i = x * N + y - skipped
-#             if i >= size:
-#                 break
-#             # Reverse y every odd column
-#             if x % 2 == 0 and N < 4:
-#                 y = N - y - 1
-#             if skip > 0:
-#                 if (x < skip or x + 1 > N - skip) and \
-#                         (y < skip or y + 1 > N - skip):
-#                     skipped += 1
-#                     continue
-#             # Exception for 3-ary alphabet, skip centre point
-#             if size == 8 and x == 1 and y == 1:
-#                 skipped += 1
-#                 continue
-#             alphabet[i, :] = [step * x - 1, step * y - 1]
-#     if gray:
-#         shape = alphabet.shape
-#         d1 = 4 if N > 4 else 2 ** N // 4
-#         g1 = np.array([0, 1, 3, 2])
-#         g2 = g1[:d1]
-#         hypershape = (d1, 4, 2)
-#         if N > 4:
-#             hypercube = alphabet.reshape(hypershape + (N-4, ))
-#             hypercube = hypercube[:, g1, :, :][g2, :, :, :]
-#         else:
-#             hypercube = alphabet.reshape(hypershape)
-#             hypercube = hypercube[:, g1, :][g2, :, :]
-#         alphabet = hypercube.reshape(shape)
-#     if polar:
-#         alphabet = misc.rect2polar(alphabet)
-#     return alphabet
+ALPHABET_DIR = path.join(path.dirname(__file__), "../alphabets")
 
 
 def load_alphabet(name, polar=True):
     apath = path.join(ALPHABET_DIR, name + '.a')
     if not path.exists(apath):
-        raise ValueError(f"Alphabet '{name}' does not exist")
+        raise ValueError(f"Alphabet '{name}' was not found in {path.abspath(apath)}")
     data = []
     indexes = []
     with open(apath, 'r') as f:
@@ -102,12 +39,12 @@ def load_alphabet(name, polar=True):
                 else:
                     raise ValueError()
                 if 'd' in header:
-                    p = y*math.pi/180
+                    p = y * math.pi / 180
                     y = math.sin(p) * x
                     x = math.cos(p) * x
                 data.append((x, y))
             except ValueError:
-                raise ValueError(f"Alphabet {name} line {i+1}: '{row}' has invalid values")
+                raise ValueError(f"Alphabet {name} line {i + 1}: '{row}' has invalid values")
 
     data2 = [None] * len(data)
     for i, d in enumerate(data):
@@ -118,6 +55,30 @@ def load_alphabet(name, polar=True):
     return arr
 
 
+class RFSignal(defs.Signal):
+    def __init__(self, array: np.ndarray):
+        self.amplitude = array[:, 0]
+        self.phase = array[:, 1]
+        self.frequency = array[:, 2]
+        self.symbols = array.shape[0]
+
+    @property
+    def rect(self) -> np.ndarray:
+        return misc.polar2rect(np.c_[self.amplitude, self.phase])
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        self.set_rect(np.c_[x_mat, y_mat])
+
+    def set_rect(self, mat: np.ndarray):
+        polar = misc.rect2polar(mat)
+        self.amplitude = polar[:, 0]
+        self.phase = polar[:, 1]
+
+    @property
+    def apf(self):
+        return np.c_[self.amplitude, self.phase, self.frequency]
+
+
 class BypassChannel(defs.Channel):
     def forward(self, values):
         return values
@@ -131,12 +92,17 @@ class AWGNChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-    def forward(self, values):
-        a = np.random.normal(0, 1, values.shape[0]) * self.noise
-        p = np.random.normal(0, 1, values.shape[0]) * self.noise
-        f = np.zeros(values.shape[0])
-        noise_mat = np.c_[a, p, f]
-        return values + noise_mat
+    def forward(self, values: RFSignal) -> RFSignal:
+        values.set_rect_xy(
+            values.rect_x + np.random.normal(0, 1, values.symbols) * self.noise,
+            values.rect_y + np.random.normal(0, 1, values.symbols) * self.noise,
+        )
+        return values
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
+        noise = tf.random.normal([2], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
+        tensor += noise * self.noise
+        return tensor
 
 
 class BPSKMod(defs.Modulator):
@@ -145,12 +111,12 @@ class BPSKMod(defs.Modulator):
         super().__init__(2, **kwargs)
         self.f = carrier_f
 
-    def forward(self, binary: np.ndarray):
+    def forward(self, binary):
         a = np.ones(binary.shape[0])
         p = np.zeros(binary.shape[0])
         p[binary == True] = np.pi
         f = np.zeros(binary.shape[0]) + self.f
-        return np.c_[a, p, f]
+        return RFSignal(np.c_[a, p, f])
 
 
 class BPSKDemod(defs.Demodulator):
@@ -167,11 +133,11 @@ class BPSKDemod(defs.Demodulator):
     def forward(self, values):
         # TODO: Channel noise simulator for frequency component?
         # for now we only care about amplitude and phase
-        ap = np.delete(values, 2, 1)
-        ap = misc.polar2rect(ap)
+        # ap = np.delete(values, 2, 1)
+        # ap = misc.polar2rect(ap)
 
-        result = np.ones(values.shape[0], dtype=bool)
-        result[ap[:, 0] > 0] = False
+        result = np.ones(values.symbols, dtype=bool)
+        result[values.rect_x[:, 0] > 0] = False
         return result
 
 
@@ -196,7 +162,7 @@ class AlphabetMod(defs.Modulator):
         a = values[:, 0]
         p = values[:, 1]
         f = np.zeros(reshaped.shape[0]) + self.f
-        return np.c_[a, p, f]  #, indices
+        return RFSignal(np.c_[a, p, f])  # , indices
 
 
 class AlphabetDemod(defs.Demodulator):
@@ -211,11 +177,11 @@ class AlphabetDemod(defs.Demodulator):
         self.ktree = cKDTree(self.alphabet)
 
     def forward(self, binary):
-        binary = binary[:, :2]  # ignore frequency
-        rbin = misc.polar2rect(binary)
-        indices = self.ktree.query(rbin)[1]
+        # binary = binary[:, :2]  # ignore frequency
+        # rbin = misc.polar2rect(binary)
+        indices = self.ktree.query(binary.rect)[1]
 
         # Converting indices to bite array
         # FIXME: unpackbits requires 8bit inputs, thus largest demodulation is 256-QAM
         values = np.unpackbits(np.array([indices], dtype=np.uint8).T, bitorder='little', axis=1)
-        return values[:, :self.N].reshape((-1,)).astype(bool)  #, indices
+        return values[:, :self.N].reshape((-1,)).astype(bool)  # , indices

+ 32 - 0
models/data.py

@@ -0,0 +1,32 @@
+import tensorflow as tf
+from tensorflow.keras.utils import Sequence
+
+# This creates pool of cpu resources
+# physical_devices = tf.config.experimental.list_physical_devices("CPU")
+# tf.config.experimental.set_virtual_device_configuration(
+#     physical_devices[0], [
+#         tf.config.experimental.VirtualDeviceConfiguration(),
+#         tf.config.experimental.VirtualDeviceConfiguration()
+#     ])
+import misc
+
+
class BinaryOneHotGenerator(Sequence):
    """Keras Sequence serving random bit-words as one-hot vectors.

    Each epoch, ``size`` random bits are drawn, grouped into words of
    ``shape`` bits, and one-hot encoded.  Inputs and targets are identical
    (autoencoder-style training), and the data is regenerated between
    epochs by ``on_epoch_end``.

    :param size: total number of random bits per epoch (rounded up to a
        multiple of ``shape``).
    :param shape: number of bits per word.
    """

    def __init__(self, size=1e5, shape=2):
        size = int(size)
        if size % shape:
            # Round up so the bit stream divides evenly into words.
            size += shape - (size % shape)
        self.size = size
        self.shape = shape
        self.x = None
        self.on_epoch_end()

    def on_epoch_end(self):
        # Draw a fresh random dataset for the next epoch.
        x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = misc.bit_matrix2one_hot(x_train)

    def __len__(self):
        # A Sequence's __len__ is the number of *batches* per epoch, not
        # the number of samples.  __getitem__ ignores ``idx`` and always
        # returns the full dataset, so exactly one batch exists; returning
        # ``self.size`` here made Keras request that same large batch
        # ``size`` times every epoch.
        return 1

    def __getitem__(self, idx):
        return self.x, self.x

+ 332 - 0
models/end_to_end.py

@@ -0,0 +1,332 @@
+import math
+
+from tensorflow import keras
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import collections as matcoll
+from sklearn.preprocessing import OneHotEncoder
+from tensorflow.keras import layers, losses
+
+
class ExtractCentralMessage(layers.Layer):
    """Keras layer selecting the samples of the central message.

    Multiplies the flattened block by a fixed 0/1 matrix whose identity
    sub-block picks out the ``samples_per_symbol`` samples belonging to the
    middle message of the transmission block.
    """

    def __init__(self, messages_per_block, samples_per_symbol):
        """
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()

        selector = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        first = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        last = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        selector[first:last, :] = np.identity(samples_per_symbol)

        self.w = tf.convert_to_tensor(selector, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
+
+
class DigitizationLayer(layers.Layer):
    """Models an ADC/DAC stage: brick-wall low-pass filter plus
    Gaussian quantization noise."""

    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()

        self.noise_layer = layers.GaussianNoise(q_stddev)
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        # Brick-wall mask: 1 inside the passband, 0 above the cutoff.
        passband = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(passband, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        spectrum = tf.signal.fft(tf.cast(inputs, dtype=tf.complex64))
        filtered = tf.signal.ifft(tf.math.multiply(self.lpf_multiplier, spectrum))
        # The cast back to float keeps only the (numerically tiny) real part.
        real_t = tf.cast(filtered, dtype=tf.float32)
        # training=True keeps the quantization noise active at inference too.
        return self.noise_layer.call(real_t, training=True)
+
+
class OpticalChannel(layers.Layer):
    """Differentiable optical-fibre channel model.

    Pipeline: DAC (LPF + quantization noise) -> chromatic dispersion applied
    in the frequency domain -> square-law photodiode detection -> receiver
    noise -> ADC (LPF + quantization noise).
    """

    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()

        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()

        self.fs = fs
        # Dispersion as a frequency-domain phase rotation:
        # H(f) = exp(j/2 * beta2 * L * (2*pi*f)^2)
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))

    def call(self, inputs, **kwargs):
        # DAC low-pass filtering and quantization noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic dispersion (frequency-domain multiply)
        spectrum = tf.signal.fft(tf.cast(dac_out, dtype=tf.complex128))
        dispersed = tf.signal.ifft(tf.math.multiply(spectrum, self.multiplier))

        # Square-law photodiode detection
        detected = tf.square(tf.abs(dispersed))

        # Back to the model dtype, then add photodiode/receiver noise
        real_val = tf.cast(detected, dtype=tf.float32)
        rx_signal = self.noise_layer.call(real_val, training=True)

        # ADC low-pass filtering and quantization noise
        return self.digitization_layer(rx_signal)
+
+
class EndToEndAutoencoder(tf.keras.Model):
    """End-to-end autoencoder for a communication link.

    An encoder network maps one-hot messages to ``samples_per_symbol``
    waveform samples, pushes them through a differentiable channel layer,
    and a decoder network recovers the *central* message of each block
    (the neighbouring messages model inter-symbol interference).
    """

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block
        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        """
        super(EndToEndAutoencoder, self).__init__()

        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper
        # Force an odd block length so a unique central message exists.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel Model Layer: flatten the block, run the channel, then
        # keep only the samples of the central message.
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.layer!")

        # Encoding Neural Network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            # Clips transmit samples into [0, 1].
            layers.ReLU(max_value=1.0)
        ])

        # Decoding Neural Network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
        consecutively to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If true, the raw decimal values of the input sequence will be returned
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))

        # One-hot encode every message with a fixed category set so the
        # width is always ``cardinality``.
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)

        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))

        mid_idx = int((self.messages_per_block-1)/2)

        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]

        # Label is the one-hot of the central message of each block.
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))

        opt = keras.optimizers.Adam(learning_rate=lr)

        # NOTE(review): softmax outputs with one-hot labels usually pair
        # with CategoricalCrossentropy; BinaryCrossentropy is kept as-is
        # here -- confirm against the reference paper.
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )

        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test)
                 )

    def view_encoder(self):
        """Plot the learned waveform for every possible central message."""
        # Generate inputs for encoder: one block per message, with that
        # message placed in the central position.
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))

        mid_idx = int((self.messages_per_block-1)/2)

        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]

        # Compute subplot grid layout (num_x is the smallest power of two
        # not below sqrt(cardinality))
        i = 0
        while 2**i < self.cardinality**0.5:
            i += 1

        num_x = int(2**i)
        num_y = int(self.cardinality / num_x)

        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))

        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            # Convert sample indices to seconds when the channel exposes fs.
            t = t/self.channel.layers[1].fs

        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1

        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))

        for ax in axs.flat:
            ax.label_outer()

        plt.show()
        pass

    def view_sample_block(self):
        """Plot one random encoded block before and after the DAC LPF."""
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)

        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)

        # Instantiate LPF layer (q_stddev=0: filtering only, no noise)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
                                q_stddev=0)

        # Apply LPF
        lpf_out = lpf(flat_enc)

        # Time axis
        t = np.arange(self.messages_per_block*self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2*self.messages_per_block, 6))

        # Vertical lines mark the symbol boundaries.
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i*self.samples_per_symbol], color='black')

        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()
        pass

    def call(self, inputs, training=None, mask=None):
        """Full forward pass: encode -> channel -> decode."""
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
+
+
if __name__ == '__main__':

    # Simulation constants: frequencies in Hz, fibre length in km,
    # dispersion factor in s^2/km (per OpticalChannel's docstring).
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50

    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)

    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel)

    # Train on 1e6 random blocks, then visualise the learned constellation
    # and a sample transmitted block.
    ae_model.train(num_of_blocks=1e6, batch_size=100)
    ae_model.view_encoder()
    ae_model.view_sample_block()

    pass

+ 63 - 0
models/layers.py

@@ -0,0 +1,63 @@
+"""
+Custom Keras Layers for general use
+"""
+import itertools
+
+from tensorflow.keras import layers
+import tensorflow as tf
+import numpy as np
+
+
class AwgnChannel(layers.Layer):
    """Additive white Gaussian noise channel (noise active even at inference)."""

    def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
        """
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param noise_dB: optional noise level in dB; when given it overrides rx_stddev.
        """
        super(AwgnChannel, self).__init__(**kwargs)
        if noise_dB is not None:
            # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
            # NOTE(review): 10**(dB/10) is a linear *power* ratio, not a
            # standard deviation -- confirm the intended conversion.
            rx_stddev = 10 ** (noise_dB / 10.0)
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True forces GaussianNoise on outside of training too.
        return self.noise_layer.call(inputs, training=True)
+
+
class ScaleAndOffset(layers.Layer):
    """Element-wise affine transform: y = scale * x + offset."""

    def __init__(self, scale=1, offset=0, **kwargs):
        super(ScaleAndOffset, self).__init__(**kwargs)
        self.offset = offset
        self.scale = scale

    def call(self, inputs, **kwargs):
        scaled = inputs * self.scale
        return scaled + self.offset
+
+
class BitsToSymbol(layers.Layer):
    """Converts groups of bits into one-hot symbol vectors.

    The incoming bit vector (length log2(cardinality)) is treated as a
    big-endian binary number and the resulting index is one-hot encoded.
    """

    def __init__(self, cardinality, **kwargs):
        super().__init__(**kwargs)
        self.cardinality = cardinality
        # np.log takes no base argument (its second positional parameter
        # is ``out``), so the previous np.log(x, 2) raised a TypeError;
        # np.log2 is the correct base-2 logarithm.
        n = int(np.log2(self.cardinality))
        # [2^(n-1), ..., 2, 1] as a column vector for the dot product.
        self.powers = tf.convert_to_tensor(
            np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
            dtype=tf.float32
        )

    def call(self, inputs, **kwargs):
        # NOTE(review): tensordot with an (n, 1) matrix leaves a trailing
        # singleton axis on ``idx`` (and hence on the one-hot output) --
        # confirm downstream shape expectations.
        idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
        return tf.one_hot(idx, self.cardinality)
+
+
class SymbolToBits(layers.Layer):
    """Converts one-hot symbol vectors back into their bit patterns.

    Column k of ``all_syms`` holds the n-bit binary representation of
    symbol index k, so a matmul with a one-hot column selects its bits.
    """

    def __init__(self, cardinality, **kwargs):
        super().__init__(**kwargs)
        # np.log takes no base argument (its second positional parameter
        # is ``out``), so the previous np.log(x, 2) raised a TypeError;
        # np.log2 is the correct base-2 logarithm.
        n = int(np.log2(cardinality))
        l = [list(i) for i in itertools.product([0, 1], repeat=n)]
        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32))

    def call(self, inputs, **kwargs):
        return tf.matmul(self.all_syms, inputs)

+ 65 - 15
models/optical_channel.py

@@ -3,17 +3,23 @@ import matplotlib.pyplot as plt
 import defs
 import numpy as np
 import math
-from scipy.fft import fft, ifft
+from numpy.fft import fft, fftfreq, ifft
+from commpy.filters import rrcosfilter, rcosfilter, rectfilter
+
+from models import basic
 
 
 class OpticalChannel(defs.Channel):
-    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, show_graphs=False, **kwargs):
+    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
+                 sqrt_out=False, show_graphs=False, **kwargs):
         """
         :param noise_level: Noise level in dB
         :param dispersion: dispersion coefficient is ps^2/km
         :param symbol_rate: Symbol rate of modulated signal in Hz
         :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
         :param length: fibre length in km
+        :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
+        :param sqrt_out: Take the root of the out to compensate for photodiode detection
         :param show_graphs: if graphs should be displayed or not
 
         Optical Channel class constructor
@@ -21,29 +27,45 @@ class OpticalChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-        self.dispersion = dispersion # * 1e-24  # Converting from ps^2/km to s^2/km
+        self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
         self.symbol_rate = symbol_rate
         self.symbol_period = 1 / self.symbol_rate
         self.sample_rate = sample_rate
         self.sample_period = 1 / self.sample_rate
         self.length = length
+        self.pulse_shape = pulse_shape.strip().lower()
+        self.sqrt_out = sqrt_out
         self.show_graphs = show_graphs
 
    def __get_time_domain(self, symbol_vals):
        """Expand symbol amplitudes into a pulse-shaped time-domain waveform.

        Builds an impulse train (one impulse per symbol period, amplitude
        taken from column 0 of ``symbol_vals``) and convolves it with the
        configured pulse-shaping filter.  The filter taps (``t_filter``,
        ``h_filter``) and length (``filter_samples``) are stored on ``self``
        so the receiver can time its symbol decisions.

        :param symbol_vals: (num_symbols, 3) APF matrix; only amplitude is used.
        :return: (time axis in seconds, waveform samples)
        """
        samples_per_symbol = int(self.sample_rate / self.symbol_rate)
        samples = int(symbol_vals.shape[0] * samples_per_symbol)

        symbol_impulse = np.zeros(samples)

        # TODO: Implement Frequency/Phase Modulation

        for i in range(symbol_vals.shape[0]):
            symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]

        if self.pulse_shape == 'rrcos':
            # Root-raised-cosine: 5 symbol periods long, rolloff 0.8.
            self.filter_samples = 5 * samples_per_symbol
            self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
        elif self.pulse_shape == 'rcos':
            # Raised-cosine: 5 symbol periods long, rolloff 0.8.
            self.filter_samples = 5 * samples_per_symbol
            self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
        else:
            # Default: rectangular (NRZ) pulse, one symbol period long.
            self.filter_samples = samples_per_symbol
            self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)

        # Full convolution lengthens the signal by filter_samples - 1.
        val_t = np.convolve(symbol_impulse, self.h_filter)
        t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])

        return t, val_t
 
     def __time_to_frequency(self, values):
         val_f = fft(values)
-        f = np.linspace(0.0, 1 / (2 * self.sample_period), (values.size // 2))
-        f_neg = -1 * np.flip(f)
-        f = np.concatenate((f, f_neg), axis=0)
+        f = fftfreq(values.shape[-1])*self.sample_rate
         return f, val_f
 
     def __frequency_to_time(self, values):
@@ -79,6 +101,8 @@ class OpticalChannel(defs.Channel):
         return t, val_t
 
     def forward(self, values):
+        if hasattr(values, 'apf'):
+            values = values.apf
         # Converting APF representation to time-series
         t, val_t = self.__get_time_domain(values)
 
@@ -106,22 +130,48 @@ class OpticalChannel(defs.Channel):
         # Photodiode Detection
         t, val_t = self.__photodiode_detection(val_t)
 
+        # Symbol Decisions
+        idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
+                        self.symbol_period/self.sample_period, dtype='int64')
+        t_descision = self.sample_period * idx
+
         if self.show_graphs:
             plt.plot(t, val_t)
             plt.title('time domain (post-detection)')
             plt.show()
 
-        return t, val_t
+            plt.plot(t, val_t)
+            for xc in t_descision:
+                plt.axvline(x=xc, color='r')
+            plt.title('time domain (post-detection with decision times)')
+            plt.show()
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        out = np.zeros(values.shape)
+
+        out[:, 0] = val_t[idx]
+        out[:, 1] = values[:, 1]
+        out[:, 2] = values[:, 2]
+
+        if self.sqrt_out:
+            out[:, 0] = np.sqrt(out[:, 0])
+
+        return basic.RFSignal(out)
 
 
if __name__ == '__main__':
    # Smoke test: OOK (on-off keying) over the dispersive optical channel.
    # Simple OOK modulation
    num_of_symbols = 100
    symbol_vals = np.zeros((num_of_symbols, 3))

    symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
    symbol_vals[:, 2] = 40e9

    channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
                             sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
    v = channel.forward(symbol_vals)

    # NOTE(review): forward() returns basic.RFSignal; the threshold below
    # assumes it supports ndarray-style `>` comparison -- confirm the
    # RFSignal operator overloads.
    rx = (v > 0.5).astype(int)
    tru = np.sum(rx == symbol_vals[:, 0].astype(int))
    print("Accuracy: {}".format(tru/num_of_symbols))

+ 299 - 0
models/quantized_net.py

@@ -0,0 +1,299 @@
+from numpy import (
+    array,
+    zeros,
+    dot,
+    median,
+    log2,
+    linspace,
+    argmin,
+    abs,
+)
+from scipy.linalg import norm
+from tensorflow.keras.backend import function as Kfunction
+from tensorflow.keras.models import Model, clone_model
+from collections import namedtuple
+from typing import List, Generator
+from time import time
+
+from models.autoencoder import Autoencoder
+
# Lightweight result containers for the quantization routines below.
QuantizedNeuron = namedtuple("QuantizedNeuron", ["layer_idx", "neuron_idx", "q"])
QuantizedFilter = namedtuple(
    "QuantizedFilter", ["layer_idx", "filter_idx", "channel_idx", "q_filtr"]
)
SegmentedData = namedtuple("SegmentedData", ["wX_seg", "qX_seg"])
+
+
class QuantizedNeuralNetwork:
    """Post-training weight quantizer for Dense layers.

    Walks the trained network layer by layer and greedily replaces each
    Dense weight with the nearest atom of a scaled uniform alphabet,
    steering a running residual of the layer output so the quantized
    network tracks the analog one.
    """

    def __init__(
            self,
            network: Model,
            batch_size: int,
            get_data: Generator[array, None, None],
            logger=None,
            ignore_layers=None,
            bits=log2(3),
            alphabet_scalar=1,
    ):
        """
        :param network: trained Keras model (or project Autoencoder).
        :param batch_size: number of samples drawn from get_data per layer.
        :param get_data: generator yielding one input sample at a time.
        :param logger: optional logger; falls back to print().
        :param ignore_layers: indices of Dense layers to leave unquantized.
        :param bits: bits per weight; the alphabet has 2**bits atoms.
        :param alphabet_scalar: scaling applied to the per-layer radius.
        """

        self.get_data = get_data

        # The pre-trained network.
        self.trained_net = network

        # This copies the network structure but not the weights.
        if isinstance(network, Autoencoder):
            # The project Autoencoder exposes its layer list as all_layers.
            self.trained_net_layers = network.all_layers
            self.quantized_net = Autoencoder(network.N, network.channel, bipolar=network.bipolar)
            self.quantized_net.set_weights(network.get_weights())
            self.quantized_net_layers = self.quantized_net.all_layers
        else:
            self.trained_net_layers = network.layers
            self.quantized_net = clone_model(network)
            # Set all the weights to be the same a priori.
            self.quantized_net.set_weights(network.get_weights())
            self.quantized_net_layers = self.quantized_net.layers

        self.batch_size = batch_size

        self.alphabet_scalar = alphabet_scalar

        # Create a dictionary encoding which layers are Dense, and what their dimensions are.
        self.layer_dims = {
            layer_idx: layer.get_weights()[0].shape
            for layer_idx, layer in enumerate(network.layers)
            if layer.__class__.__name__ == "Dense"
        }

        # This determines the alphabet. There will be 2**bits atoms in our alphabet.
        self.bits = bits

        # Construct the (unscaled) alphabet. Layers will scale this alphabet based on the
        # distribution of that layer's weights.
        self.alphabet = linspace(-1, 1, num=int(round(2 ** (bits))))

        self.logger = logger

        # A `ignore_layers=[]` default would be a shared mutable default
        # argument; use None as the sentinel instead.
        self.ignore_layers = [] if ignore_layers is None else ignore_layers

    def _log(self, msg: str):
        """Log via the configured logger, or print when none was given."""
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    def _bit_round(self, t: float, rad: float) -> float:
        """Rounds a quantity to the nearest atom in the (scaled) quantization alphabet.

        Parameters
        -----------
        t : float
            The value to quantize.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        bit : float
            The quantized value.
        """

        # Scale the alphabet appropriately.
        layer_alphabet = rad * self.alphabet
        return layer_alphabet[argmin(abs(layer_alphabet - t))]

    def _quantize_weight(
            self, w: float, u: array, X: array, X_tilde: array, rad: float
    ) -> float:
        """Quantizes a single weight of a neuron.

        Parameters
        -----------
        w : float
            The weight.
        u : array ,
            Residual vector.
        X : array
            Vector from the analog network's random walk.
        X_tilde : array
            Vector from the quantized network's random walk.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        bit : float
            The quantized value.
        """

        # Degenerate direction: nothing to project onto.
        if norm(X_tilde, 2) < 10 ** (-16):
            return 0

        # Residual is (numerically) orthogonal: plain rounding suffices.
        if abs(dot(X_tilde, u)) < 10 ** (-10):
            return self._bit_round(w, rad)

        return self._bit_round(dot(X_tilde, u + w * X) / (norm(X_tilde, 2) ** 2), rad)

    def _quantize_neuron(
            self,
            layer_idx: int,
            neuron_idx: int,
            wX: array,
            qX: array,
            rad=1,
    ) -> QuantizedNeuron:
        """Quantizes a single neuron in a Dense layer.

        Parameters
        -----------
        layer_idx : int
            Index of the Dense layer.
        neuron_idx : int,
            Index of the neuron in the Dense layer.
        wX : array
            Layer input for the analog network.
        qX : array
            Layer input for the quantized network.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        QuantizedNeuron: NamedTuple
            A tuple with the layer and neuron index, as well as the quantized neuron.
        """

        N_ell = wX.shape[1]
        u = zeros(self.batch_size)
        w = self.trained_net_layers[layer_idx].get_weights()[0][:, neuron_idx]
        q = zeros(N_ell)
        # Sequentially quantize each weight while tracking the residual u.
        for t in range(N_ell):
            q[t] = self._quantize_weight(w[t], u, wX[:, t], qX[:, t], rad)
            u += w[t] * wX[:, t] - q[t] * qX[:, t]

        return QuantizedNeuron(layer_idx=layer_idx, neuron_idx=neuron_idx, q=q)

    def _get_layer_data(self, layer_idx: int, hf=None):
        """Gets the input data for the layer at a given index.

        Parameters
        -----------
        layer_idx : int
            Index of the layer.
        hf: hdf5 File object in write mode.
            If provided, will write output to hdf5 file instead of returning directly.

        Returns
        -------
        tuple: (array, array)
            A tuple of arrays, with the first entry being the input for the analog network
            and the latter being the input for the quantized network.
        """

        layer = self.trained_net_layers[layer_idx]
        layer_data_shape = layer.input_shape[1:] if layer.input_shape[0] is None else layer.input_shape
        wX = zeros((self.batch_size, *layer_data_shape))
        qX = zeros((self.batch_size, *layer_data_shape))
        if layer_idx == 0:
            # First layer: both networks see the raw input data.
            for sample_idx in range(self.batch_size):
                try:
                    wX[sample_idx, :] = next(self.get_data)
                except StopIteration:
                    # No more samples!
                    break
            qX = wX
        else:
            # Define functions which will give you the output of the previous hidden layer
            # for both networks.
            prev_trained_output = Kfunction(
                [self.trained_net_layers[0].input],
                [self.trained_net_layers[layer_idx - 1].output],
            )
            prev_quant_output = Kfunction(
                [self.quantized_net_layers[0].input],
                [self.quantized_net_layers[layer_idx - 1].output],
            )
            input_layer = self.trained_net_layers[0]
            input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
            batch = zeros((self.batch_size, *input_shape))

            # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
            # to reconsider how much memory you preallocate for batch, wX, and qX.
            for sample_idx in range(self.batch_size):
                try:
                    batch[sample_idx, :] = next(self.get_data)
                except StopIteration:
                    # No more samples!
                    break

            wX = prev_trained_output([batch])[0]
            qX = prev_quant_output([batch])[0]

        return (wX, qX)

    def _update_weights(self, layer_idx: int, Q: array):
        """Updates the weights of the quantized neural network given a layer index and
        quantized weights.

        Parameters
        -----------
        layer_idx : int
            Index of the layer.
        Q : array
            The quantized weights.
        """

        # Update the quantized network. Use the same bias vector as in the analog network for now.
        if self.trained_net_layers[layer_idx].use_bias:
            bias = self.trained_net_layers[layer_idx].get_weights()[1]
            self.quantized_net_layers[layer_idx].set_weights([Q, bias])
        else:
            self.quantized_net_layers[layer_idx].set_weights([Q])

    def _quantize_layer(self, layer_idx: int):
        """Quantizes a Dense layer of a multi-layer perceptron.

        Parameters
        -----------
        layer_idx : int
            Index of the Dense layer.
        """

        W = self.trained_net_layers[layer_idx].get_weights()[0]
        N_ell, N_ell_plus_1 = W.shape
        # Placeholder for the weight matrix in the quantized network.
        Q = zeros(W.shape)
        wX, qX = self._get_layer_data(layer_idx)

        # Set the radius of the alphabet.
        rad = self.alphabet_scalar * median(abs(W.flatten()))

        for neuron_idx in range(N_ell_plus_1):
            self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
            tic = time()
            qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
            Q[:, neuron_idx] = qNeuron.q

            self._log(f"\tdone. {time() - tic :.2f} seconds.")

        # wX/qX were captured before the loop, so writing the quantized
        # weights once after all neurons are done is equivalent to the
        # previous per-neuron update -- and much cheaper.
        self._update_weights(layer_idx, Q)

    def quantize_network(self):
        """Quantizes all Dense layers that are not specified by the list of ignored layers."""

        # This must be done sequentially: later layers see the quantized
        # outputs of earlier ones.
        for layer_idx, layer in enumerate(self.trained_net_layers):
            if (
                    layer.__class__.__name__ == "Dense"
                    and layer_idx not in self.ignore_layers
            ):
                # Only quantize dense layers.
                self._log(f"Quantizing layer {layer_idx}...")
                self._quantize_layer(layer_idx)
                self._log(f"done. {layer_idx}...")

+ 523 - 0
tests/min_test.py

@@ -0,0 +1,523 @@
+"""
+These are some unstructured tests. Feel free to use this code for anything else
+"""
+
+import logging
+import pathlib
+from itertools import chain
+from sys import stdout
+
+from tensorflow.python.framework.errors_impl import NotFoundError
+
+import defs
+from graphs import get_SNR, get_AWGN_ber
+from models import basic
+from models.autoencoder import Autoencoder, view_encoder
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import misc
+import numpy as np
+
+from models.basic import AlphabetDemod, AlphabetMod
+from models.optical_channel import OpticalChannel
+from models.quantized_net import QuantizedNeuralNetwork
+
+
+
+
+
def _test_optics_autoencoder():
    """Train an autoencoder over an optical channel and compare its
    BER-vs-SNR curve against a fixed 4-PAM modulation."""
    channel = OpticalChannel(
        noise_level=-10,
        dispersion=-21.7,
        symbol_rate=10e9,
        sample_rate=400e9,
        length=10,
        pulse_shape='rcos',
        sqrt_out=True
    )

    tf.executing_eagerly()

    autoenc = Autoencoder(4, channel=channel)
    autoenc.train(samples=1e6)

    # Both curves share the same SNR sweep range.
    sweep = dict(steps=50, start=-5, stop=15)

    plt.plot(*get_SNR(
        autoenc.get_modulator(),
        autoenc.get_demodulator(),
        ber_func=get_AWGN_ber,
        samples=100000,
        **sweep
    ), '-', label='AE')

    plt.plot(*get_SNR(
        AlphabetMod('4pam', 10e6),
        AlphabetDemod('4pam', 10e6),
        samples=30000,
        length=1,
        pulse_shape='rcos',
        **sweep
    ), '-', label='4PAM')

    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("Autoencoder Performance")
    plt.legend()
    plt.savefig('optics_autoencoder.eps', format='eps')
    plt.show()
+
+
+
def _test_autoencoder_pretrain():
    """Compare a 16QAM-pretrained autoencoder before and after end-to-end
    training against the reference 16QAM modulation (AWGN channel)."""

    def plot_curve(mod, demod, label):
        # Every curve uses the same AWGN sweep configuration.
        curve = get_SNR(
            mod,
            demod,
            ber_func=get_AWGN_ber,
            samples=100000,
            steps=50,
            start=-5,
            stop=15
        )
        plt.plot(*curve, '-', label=label)

    autoenc = Autoencoder(4, -25)
    # autoenc.fit_encoder('16qam', 3e4)
    autoenc.fit_decoder('16qam', 1e5)
    plot_curve(autoenc.get_modulator(), autoenc.get_demodulator(),
               '16QAM Pre-trained AE')

    autoenc.train(samples=3e6)
    plot_curve(autoenc.get_modulator(), autoenc.get_demodulator(),
               '16QAM Post-trained AE')

    plot_curve(AlphabetMod('16qam', 10e6), AlphabetDemod('16qam', 10e6),
               '16QAM')

    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("4Bit Autoencoder Performance")
    plt.legend()
    plt.show()
+
+
class LiteTFMod(defs.Modulator):
    """Modulator backed by a TFLite-converted autoencoder encoder model."""

    def __init__(self, name, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        model_file = pathlib.Path("/tmp/tflite/") / (name + ".tflite")
        self.interpreter = tf.lite.Interpreter(model_path=str(model_file))
        self.interpreter.allocate_tensors()

    def forward(self, binary: np.ndarray):
        # Group the incoming bit stream into one-hot encoded symbols.
        width = self.N * self.autoencoder.parallel
        one_hot = misc.bit_matrix2one_hot(binary.reshape((-1, width)))

        in_detail = self.interpreter.get_input_details()[0]
        out_detail = self.interpreter.get_output_details()[0]

        x = np.zeros((len(one_hot), out_detail["shape"][1]))
        # The TFLite interpreter processes one sample per invoke().
        for i, sample in enumerate(one_hot):
            self.interpreter.set_tensor(
                in_detail["index"],
                sample.reshape(in_detail["shape"]).astype(in_detail["dtype"]))
            self.interpreter.invoke()
            x[i] = self.interpreter.get_tensor(out_detail["index"])

        if self.autoencoder.bipolar:
            x = x * 2 - 1

        if self.autoencoder.parallel > 1:
            x = x.reshape((-1, self.autoencoder.signal_dim))

        # Assemble (amplitude, phase, frequency) columns; phase only exists
        # when the encoder emits a signal with more than one dimension.
        freq = np.zeros(x.shape[0])
        if self.autoencoder.signal_dim > 1:
            phase = x[:, 1]
        else:
            phase = np.zeros(x.shape[0])
        polar = misc.rect2polar(np.c_[x[:, 0], phase, freq])
        return basic.RFSignal(polar)
+
+
class LiteTFDemod(defs.Demodulator):
    """Demodulator backed by a TFLite-converted autoencoder decoder model."""

    def __init__(self, name, autoencoder):
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        model_file = pathlib.Path("/tmp/tflite/") / (name + ".tflite")
        self.interpreter = tf.lite.Interpreter(model_path=str(model_file))
        self.interpreter.allocate_tensors()

    def forward(self, values: defs.Signal) -> np.ndarray:
        # A one-dimensional signal only carries the x component.
        if self.autoencoder.signal_dim <= 1:
            samples = values.rect_x
        else:
            samples = values.rect
        if self.autoencoder.parallel > 1:
            samples = samples.reshape((-1, self.autoencoder.parallel))

        in_detail = self.interpreter.get_input_details()[0]
        out_detail = self.interpreter.get_output_details()[0]

        decoded = np.zeros((len(samples), out_detail["shape"][1]))
        # The TFLite interpreter processes one sample per invoke().
        for i, sample in enumerate(samples):
            self.interpreter.set_tensor(
                in_detail["index"],
                sample.reshape(in_detail["shape"]).astype(in_detail["dtype"]))
            self.interpreter.invoke()
            decoded[i] = self.interpreter.get_tensor(out_detail["index"])

        # Most likely one-hot class maps back to the transmitted bits.
        bits = misc.int2bit_array(decoded.argmax(axis=1),
                                  self.N * self.autoencoder.parallel)
        return bits.reshape(-1, )
+
+
def _test_autoencoder_perf():
    """Convert a trained autoencoder encoder/decoder to TFLite at several
    numeric precisions and compare their BER-vs-SNR curves (AWGN channel).
    """
    # Compare (major, minor) numerically: the previous check
    # `float(tf.__version__[:3]) >= 2.3` mis-parses releases such as
    # "2.10.0" (slices to "2.1") and rejects valid versions.
    assert tuple(int(part) for part in tf.__version__.split(".")[:2]) >= (2, 3)

    aenc = Autoencoder(4, -25, bipolar=True)
    aenc.train(epoch_size=1e3, epochs=10)

    # Representative data used to calibrate post-training quantization.
    m = aenc.N * aenc.parallel
    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
    x_train_enc = aenc.encoder(x_train)
    x_train = tf.cast(x_train, tf.float32)

    def save_tflite(model, name, types=None, ops=None, io_types=None, train_x=None):
        """Convert `model` to TFLite and write /tmp/tflite/<name>.tflite.

        `types`/`ops` select the converter's supported type/op sets (either
        enables the default optimization pass), `io_types` forces the model
        input/output dtype, and `train_x` supplies a representative dataset
        for calibrating quantization.
        """
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        if types is not None:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = types
        if ops is not None:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_ops = ops
        if io_types is not None:
            converter.inference_input_type = io_types
            converter.inference_output_type = io_types
        if train_x is not None:
            def representative_data_gen():
                for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
                    yield [input_value]

            converter.representative_dataset = representative_data_gen
        tflite_model = converter.convert()
        tflite_models_dir = pathlib.Path("/tmp/tflite/")
        tflite_models_dir.mkdir(exist_ok=True, parents=True)
        (tflite_models_dir / (name + ".tflite")).write_bytes(tflite_model)

    print("Saving models")

    # Unquantized float32 baseline.
    save_tflite(aenc.encoder, "default_enc")
    save_tflite(aenc.decoder, "default_dec")

    # Other variants (float16, bfloat16, full int8 with uint8 I/O) were
    # previously explored here; re-add save_tflite(...) calls with the
    # matching `types`/`ops`/`io_types` to benchmark them again.

    # 16-bit activations with 8-bit weights.
    INT16X8 = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    save_tflite(aenc.encoder, "int16x8_enc", ops=[INT16X8], train_x=x_train)
    save_tflite(aenc.decoder, "int16x8_dec", ops=[INT16X8], train_x=x_train_enc)

    print("Testing BER vs SNR")
    sweep = dict(ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15)

    plt.plot(*get_SNR(
        LiteTFMod("default_enc", aenc),
        LiteTFDemod("default_dec", aenc),
        **sweep
    ), '-', label='4AE F32')

    plt.plot(*get_SNR(
        LiteTFMod("int16x8_enc", aenc),
        LiteTFDemod("int16x8_dec", aenc),
        **sweep
    ), '-', label='4AE I16x8')

    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.ylabel('BER')
    plt.title("Autoencoder with different precision data types")
    plt.legend()
    plt.savefig('autoencoder_compression.eps', format='eps')
    plt.show()

    view_encoder(aenc.encoder, 4)
+
+
def _test_autoencoder_perf2():
    """Compare autoencoders of 2-5 bits per symbol against classic
    modulation schemes on an AWGN channel."""
    # All curves share the same AWGN sweep configuration.
    sweep = dict(ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15)

    # One autoencoder per constellation size (previously four copy-pasted
    # train/plot stanzas).
    for n_bits in (2, 3, 4, 5):
        aenc = Autoencoder(n_bits, -20)
        aenc.train(samples=3e6)
        plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), **sweep),
                 '-', label=f'{n_bits}Bit AE')

    for scheme in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
        try:
            plt.plot(*get_SNR(AlphabetMod(scheme, 10e6), AlphabetDemod(scheme, 10e6), **sweep),
                     '-', label=scheme.upper())
        except KeyboardInterrupt:
            break
        except Exception as exc:
            # Best effort: keep plotting the remaining schemes, but report
            # the failure instead of silently swallowing it.
            print(f"Skipping {scheme}: {exc!r}")

    plt.yscale('log')
    plt.grid()
    plt.xlabel('SNR dB')
    plt.title("Autoencoder vs defined modulations")
    plt.legend()
    plt.savefig('autoencoder_mods.eps', format='eps')
    plt.show()

    # view_encoder(aenc.encoder, 2)
+
+
def _test_autoencoder_perf_qnn():
    """Quantize a trained autoencoder with the greedy QNN algorithm and
    display the resulting encoder alphabet."""
    # Log quantization progress both to a file and to stdout.
    fh = logging.FileHandler("model_quantizing.log", mode="w+")
    fh.setLevel(logging.INFO)
    sh = logging.StreamHandler(stream=stdout)
    sh.setLevel(logging.INFO)

    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)
    logger.addHandler(fh)
    logger.addHandler(sh)

    aenc = Autoencoder(4, -25, bipolar=True)
    try:
        # Reuse previously trained weights when a checkpoint exists.
        aenc.load_weights('autoencoder')
    except NotFoundError:
        aenc.train(epoch_size=1e3, epochs=10)
        aenc.save_weights('autoencoder')

    aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())

    m = aenc.N * aenc.parallel
    view_encoder(aenc.encoder, 4, title="FP32 Alphabet")

    batch_size = 25000
    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
    # Held-out data for evaluating the quantized network (see the
    # commented-out evaluate call at the bottom).
    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))
    alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    num_layers = sum(layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers)

    # The quantizer consumes one full pass of the data per Dense layer, so
    # replay x_train num_layers + 1 times (was a hand-rolled generator chain).
    get_data = chain.from_iterable(iter(x_train) for _ in range(num_layers + 1))

    qnn = QuantizedNeuralNetwork(
        network=aenc,
        batch_size=batch_size,
        get_data=get_data,
        logger=logger,
        bits=np.log2(16),
        alphabet_scalar=alphabet_scalars,
    )

    qnn.quantize_network()
    qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
    view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")

    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
+
if __name__ == '__main__':
    # Disabled experiment: benchmark the reference 16QAM modulation over
    # AWGN and the optical channel, e.g.
    #
    # plt.plot(*get_SNR(
    #     AlphabetMod('16qam', 10e6),
    #     AlphabetDemod('16qam', 10e6),
    #     ber_func=get_AWGN_ber,   # drop ber_func for the optical channel
    #     samples=100000,
    #     steps=50,
    #     start=-5,
    #     stop=15,
    # ), '-', label='16qam AWGN')
    # plt.yscale('log')
    # plt.grid()
    # plt.xlabel('SNR dB')
    # plt.show()

    # Exactly one experiment is enabled at a time.
    # _test_autoencoder_perf()
    _test_autoencoder_perf_qnn()
    # _test_autoencoder_perf2()
    # _test_autoencoder_pretrain()
    # _test_optics_autoencoder()