51 Commits 72c0733fcf ... 4c4172b689

Author SHA1 Message Date
  Min 4c4172b689 WIP results 4 years ago
  Min 33136feb09 Merge fix 4 years ago
  Min eb84764f2f Merge branch 'end-to-end' into Standalone_NN_devel 4 years ago
  Min 25c4fc2a23 Merge branch 'temp_e2e' into Standalone_NN_devel 4 years ago
  Tharmetharan Balendran 0b990b8b42 refactoring 4 years ago
  Min 63ed86d5e6 Merge remote-tracking branch 'origin/master' into Standalone_NN_devel 4 years ago
  Min 86474cdd4f Quantised net fix 4 years ago
  Min 82e1427e16 Binary networks WIP 4 years ago
  Min 328596b16f Autoencoder encoder graph fix 4 years ago
  Tharmetharan Balendran 0101829faf cleaned up code 4 years ago
  Tharmetharan Balendran 8224f88367 variable length training added 4 years ago
  Tharmetharan Balendran 16045ea3e0 ber vs length plot 4 years ago
  Min c2e69468bf Custom cost function test 4 years ago
  Min 91ef1d49d0 layer merge fix 4 years ago
  Min e6940a5e2c fixed merging end_to_end 4 years ago
  Min 4e4c333461 Added missing layers & fixed new_model 4 years ago
  Min fd19f379cb Merged layers file 4 years ago
  Min 58b7fda741 Merge branch 'master' into temp_e2e 4 years ago
  Min f29ac999c1 Separated custom layers to new file 4 years ago
  Min bbcf31634e fixed alphabet loading path 4 years ago
  Tharmetharan Balendran e04e6328a5 changed imports 4 years ago
  Tharmetharan Balendran 8f2de95b94 removed plt 4 years ago
  Tharmetharan Balendran 9f39857933 iterative model training 4 years ago
  Tharmetharan Balendran a75063d665 iterative learning 4 years ago
  Tharmetharan Balendran 7e2c83ee4c bit-symbol mapping attemts 4 years ago
  Min 630a5caf79 Excluding unwanted files 4 years ago
  Min 810c636673 Fixed keras and tf.keras import mix 4 years ago
  Min 7dcdd412e1 Merge remote-tracking branch 'origin/master' 4 years ago
  Min 195d651bf9 Added unstructured test file 4 years ago
  Min a49e1fc4c3 Added QNN 4 years ago
  Min 8215279c3f Updated basic autoencoder 4 years ago
  Tharmetharan Balendran 80a5361d42 refactoring and lstm layers added 5 years ago
  Tharmetharan Balendran f56129796c fixed broken keras import 5 years ago
  Tharmetharan Balendran 7e9ac81eb0 added comments 5 years ago
  Tharmetharan Balendran be6243754e Working implementation of end-to-end AE 5 years ago
  Tharmetharan Balendran 25393c2ae8 AE training and testing 5 years ago
  Tharmetharan Balendran 8873f6b053 initial implementation of end-to-end AE 5 years ago
  Min 6fa2f7d82c Merge branch 'photonics' 5 years ago
  Min fe3da6cd95 Quick fix signal model for autoencoder 5 years ago
  Min efbf971909 improved singal class compatability 5 years ago
  Min 2677c5e41f Merge remote-tracking branch 'origin/matched_filter' into photonics 5 years ago
  Min ba192057f0 Added signal definition 5 years ago
  Tharmetharan Balendran 2e157f7def optical channel bugfix 5 years ago
  Min 294eb2b46c Train decoder to match some modulation 5 years ago
  Tharmetharan Balendran 372278f5c7 encoder training 5 years ago
  Min 5d6c327cdf Multithreaded SNR calc + SNR for optics 5 years ago
  Min 7d5831b735 Merge branch 'matched_filter' 5 years ago
  Min b7fdda8b86 Merge remote-tracking branch 'origin/swarm_optimisation' 5 years ago
  Oliver Jaison 9955b46ead Function for SNR 5 years ago
  Tharmetharan Balendran 6d04f26a97 ber plot for optical channel 5 years ago
  Tharmetharan Balendran 24307ed05e pulse shaping achieved 5 years ago
20 files changed with 3414 additions and 241 deletions
  1. .gitignore (+8 -1)
  2. alphabets/4pam.a (+5 -0)
  3. defs.py (+33 -6)
  4. graphs.py (+109 -0)
  5. main.py (+51 -53)
  6. misc.py (+12 -1)
  7. models/autoencoder.py (+211 -78)
  8. models/basic.py (+51 -85)
  9. models/binary_net.py (+518 -0)
  10. models/data.py (+87 -0)
  11. models/end_to_end.py (+617 -0)
  12. models/gray_code.py (+40 -0)
  13. models/layers.py (+211 -0)
  14. models/new_model.py (+237 -0)
  15. models/optical_channel.py (+65 -15)
  16. models/plots.py (+91 -0)
  17. models/quantized_net.py (+301 -0)
  18. tests/min_test.py (+597 -0)
  19. tests/misc_test.py (+67 -2)
  20. tests/results.py (+103 -0)

+ 8 - 1
.gitignore

@@ -4,4 +4,11 @@ __pycache__
 *.pyo
 
 # Environments
-venv/
+venv/
+
+# Anything else
+*.log
+checkpoint
+*.index
+*.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]
+other/

+ 5 - 0
alphabets/4pam.a

@@ -0,0 +1,5 @@
+R
+00,0,0
+01,0.25,0
+10,0.5,0
+11,1,0
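
The new alphabet file is compact: a header line marking the coordinate system (presumably `R` for rectangular), then one `bits,x,y` row per symbol. A minimal standalone parsing sketch under that assumption (the repository's actual loader is `basic.load_alphabet`; this function is illustrative only):

```python
import numpy as np

def parse_alphabet(text):
    """Parse an alphabet file: a coordinate-type header line
    ('R' assumed to mean rectangular), then 'bits,x,y' rows."""
    lines = [ln.strip() for ln in text.strip().splitlines()]
    coord_type, rows = lines[0], lines[1:]
    labels = [row.split(',')[0] for row in rows]
    points = np.array([[float(v) for v in row.split(',')[1:]] for row in rows])
    return coord_type, labels, points

# The 4-PAM alphabet above: four amplitudes spaced along the real axis
coord, labels, points = parse_alphabet("R\n00,0,0\n01,0.25,0\n10,0.5,0\n11,1,0")
```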

+ 33 - 6
defs.py

@@ -1,5 +1,30 @@
 import math
 import numpy as np
+import tensorflow as tf
+
+class Signal:
+
+    @property
+    def rect_x(self) -> np.ndarray:
+        return self.rect[:, 0]
+
+    @property
+    def rect_y(self) -> np.ndarray:
+        return self.rect[:, 1]
+
+    @property
+    def rect(self) -> np.ndarray:
+        raise NotImplementedError("Not implemented")
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        raise NotImplementedError("Not implemented")
+
+    def set_rect(self, mat: np.ndarray):
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def apf(self):
+        raise NotImplementedError("Not implemented")
 
 
 class COMComponent:
@@ -13,12 +38,14 @@ class Channel(COMComponent):
     This model is just empty therefore just bypasses any input to output
     """
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> Signal:
+        raise NotImplementedError("Need to define forward function")
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
         """
-        :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
-        :return: affected tuple of (amplitude, phase, frequency)
+        Forward operation optimised for tensorflow tensors
         """
-        raise NotImplemented("Need to define forward function")
+        raise NotImplementedError("Need to define forward_tensor function")
 
 
 class ModComponent(COMComponent):
@@ -30,7 +57,7 @@ class ModComponent(COMComponent):
 
 class Modulator(ModComponent):
 
-    def forward(self, binary: np.ndarray) -> np.ndarray:
+    def forward(self, binary: np.ndarray) -> Signal:
         """
         :param binary: raw bytes as input (must be dtype=bool)
         :return: amplitude, phase, frequency
@@ -40,7 +67,7 @@ class Modulator(ModComponent):
 
 class Demodulator(ModComponent):
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> np.ndarray:
         """
         :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
         :return: binary resulting values (dtype=bool)
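
The `Signal` class above only defines the contract the `Modulator`, `Channel`, and `Demodulator` interfaces now pass around: an (n, 2) rectangular (I/Q) sample matrix exposed through `rect`, `rect_x`, and `rect_y`. A minimal concrete implementation for illustration (the repository's own concrete class appears to be `basic.RFSignal`, whose body is not shown in this diff; `RectSignal` here is a hypothetical name):

```python
import numpy as np

class RectSignal:
    """Illustrative concrete Signal: stores samples as an (n, 2)
    rectangular (I/Q) array and implements the accessors above."""

    def __init__(self, mat):
        self._rect = np.asarray(mat, dtype=float)

    @property
    def rect(self):
        return self._rect

    @property
    def rect_x(self):
        return self._rect[:, 0]  # in-phase component

    @property
    def rect_y(self):
        return self._rect[:, 1]  # quadrature component

    def set_rect(self, mat):
        self._rect = np.asarray(mat, dtype=float)

    def set_rect_xy(self, x_mat, y_mat):
        self._rect = np.column_stack((x_mat, y_mat))

sig = RectSignal([[1.0, 0.0], [0.0, 1.0]])
```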

+ 109 - 0
graphs.py

@@ -0,0 +1,109 @@
+import math
+import os
+from multiprocessing import Pool
+
+from sklearn.metrics import accuracy_score
+
+from defs import Modulator, Demodulator, Channel
+from models.basic import AWGNChannel
+from misc import generate_random_bit_array
+from models.optical_channel import OpticalChannel
+import matplotlib.pyplot as plt
+import numpy as np
+
+CPU_COUNT = int(os.environ.get("CPU_COUNT", os.cpu_count()))  # env values are strings
+
+
+def show_constellation(mod: Modulator, chan: Channel, demod: Demodulator, samples=1000):
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+
+    plt.plot(x_chan.rect_x[x], x_chan.rect_y[x], '+')
+    plt.plot(x_chan.rect_x[~x], x_chan.rect_y[~x], '+')
+    plt.plot(x_mod.rect_x, x_mod.rect_y, 'ro')
+    axes = plt.gca()
+    axes.set_xlim([-2, +2])
+    axes.set_ylim([-2, +2])
+    plt.grid()
+    plt.show()
+    print('Accuracy : ' + str(accuracy_score(x, x_demod)))
+
+
+def get_ber(mod, chan, demod, samples=1000):
+    if samples % mod.N:
+        samples += mod.N - (samples % mod.N)
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+    return 1 - accuracy_score(x, x_demod)
+
+
+def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    for noise in ber_x:
+        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
+    return ber_x, ber_y
+
+
+def __calc_ber(packed):
+    # This function has to be outside get_Optical_ber in order to be pickled by pool
+    mod, demod, noise, length, pulse_shape, samples = packed
+    tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                length=length, pulse_shape=pulse_shape, sqrt_out=True)
+    return get_ber(mod, tx_channel, demod, samples=samples)
+
+
+def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
+    with Pool(CPU_COUNT) as pool:
+        packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
+        for i, ber in enumerate(pool.imap(__calc_ber, packed_args), start=1):
+            ber_y.append(ber)
+            print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i * 100 / len(ber_x):6.2f}%)", end='')
+    print()
+    return ber_x, ber_y
+
+
+def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
+    """
+    SNR for optics and RF should be calculated the same, that is A^2
+    Because P∝V² and P∝I²
+    """
+    x_mod = mod.forward(generate_random_bit_array(samples * mod.N))
+    sig_power = [A ** 2 for A in x_mod.amplitude]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = 10 * math.log10(av_sig_pow)  # average signal power in dB
+
+    noise_start = -start + av_sig_pow
+    noise_stop = -stop + av_sig_pow
+    ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
+    SNR = -ber_x + av_sig_pow
+    return SNR, ber_y
+
+
+def show_train_history(history, title="", save=None):
+    from matplotlib import pyplot as plt
+
+    epochs = range(1, len(history.epoch) + 1)
+    if 'loss' in history.history:
+        plt.plot(epochs, history.history['loss'], label='Training Loss')
+    if 'accuracy' in history.history:
+        plt.plot(epochs, history.history['accuracy'], label='Training Accuracy')
+    if 'val_loss' in history.history:
+        plt.plot(epochs, history.history['val_loss'], label='Validation Loss')
+    if 'val_accuracy' in history.history:
+        plt.plot(epochs, history.history['val_accuracy'], label='Validation Accuracy')
+    plt.xlabel('Epochs')
+    plt.ylabel('Loss/Accuracy' if 'accuracy' in history.history else 'Loss')
+    plt.legend()
+    plt.title(title)
+    if save is not None:
+        plt.savefig(save)
+    plt.show()
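
The docstring of `get_SNR` notes that SNR is derived from amplitudes the same way for optics and RF, since P ∝ V² and P ∝ I². The underlying dB relationship can be sketched standalone (a simplified illustration, not the repository function):

```python
import numpy as np

def snr_db(amplitudes, noise_power_db):
    """SNR in dB: average signal power is mean(A**2), because
    power is proportional to A**2; then subtract noise power in dB."""
    sig_power_db = 10 * np.log10(np.mean(np.square(amplitudes)))
    return sig_power_db - noise_power_db

# Unit-amplitude signal against 0 dB noise gives 0 dB SNR
print(snr_db([1.0, 1.0, 1.0], 0.0))  # → 0.0
```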

+ 51 - 53
main.py

@@ -1,48 +1,7 @@
 import matplotlib.pyplot as plt
 
-import numpy as np
-from sklearn.metrics import accuracy_score
-from models import basic
+import graphs
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
-import misc
-from models.autoencoder import Autoencoder, view_encoder
-
-
-def show_constellation(mod, chan, demod, samples=1000):
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-
-    x_mod_rect = misc.polar2rect(x_mod)
-    x_chan_rect = misc.polar2rect(x_chan)
-    plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
-    plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
-    plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
-    axes = plt.gca()
-    axes.set_xlim([-2, +2])
-    axes.set_ylim([-2, +2])
-    plt.grid()
-    plt.show()
-    print('Accuracy : ' + str())
-
-
-def get_ber(mod, chan, demod, samples=1000):
-    if samples % mod.N:
-        samples += mod.N - (samples % mod.N)
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-    return 1 - accuracy_score(x, x_demod)
-
-
-def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
-    ber_x = np.linspace(start, stop, steps)
-    ber_y = []
-    for noise in ber_x:
-        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
-    return ber_x, ber_y
 
 
 if __name__ == '__main__':
@@ -110,21 +69,60 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=20e3,
-            steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    # for l in np.logspace(start=0, stop=3, num=6):
+    #     plt.plot(*misc.get_SNR(
+    #         AlphabetMod('4pam', 10e6),
+    #         AlphabetDemod('4pam', 10e6),
+    #         samples=2000,
+    #         steps=200,
+    #         start=-5,
+    #         stop=20,
+    #         length=l,
+    #         pulse_shape='rcos'
+    #     ), '-', label=(str(int(l))+'km'))
+    #
+    # plt.yscale('log')
+    # # plt.gca().invert_xaxis()
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # # plt.ylabel('BER')
+    # plt.title("BER against Fiber length")
+    # plt.legend()
+    # plt.show()
+
+    for ps in ['rect', 'rcos', 'rrcos']:
+        plt.plot(*graphs.get_SNR(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=30000,
+            steps=100,
+            start=-5,
+            stop=20,
+            length=1,
+            pulse_shape=ps
+        ), '-', label=ps)
 
     plt.yscale('log')
-    plt.gca().invert_xaxis()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.show()
 
-    pass
+    pass

+ 12 - 1
misc.py

@@ -1,7 +1,8 @@
+
 import numpy as np
 import math
 import matplotlib.pyplot as plt
-
+import pickle
 
 def display_alphabet(alphabet, values=None, a_vals=False, title="Alphabet constellation diagram"):
     rect = polar2rect(alphabet)
@@ -102,3 +103,13 @@ def generate_random_bit_array(size):
     arr = np.concatenate(p)
     np.random.shuffle(arr)
     return arr
+
+
+def pickle_save(obj, fname):
+    with open(fname, 'wb') as f:
+        pickle.dump(obj, f)
+
+
+def pickle_load(fname):
+    with open(fname, 'rb') as f:
+        return pickle.load(f)

+ 211 - 78
models/autoencoder.py

@@ -3,11 +3,19 @@ import numpy as np
 import tensorflow as tf
 
 from sklearn.metrics import accuracy_score
+from sklearn.model_selection import train_test_split
 from tensorflow.keras import layers, losses
 from tensorflow.keras.models import Model
-from functools import partial
+from tensorflow.keras.layers import LeakyReLU, ReLU
+
+# from functools import partial
 import misc
 import defs
+from models import basic
+import os
+# from tensorflow_model_optimization.python.core.quantization.keras import quantize, quantize_aware_activation
+from models.data import BinaryOneHotGenerator
+from models import layers as custom_layers
 
 latent_dim = 64
 
@@ -15,62 +23,115 @@ print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GP
 
 
 class AutoencoderMod(defs.Modulator):
-    def __init__(self, autoencoder):
-        super().__init__(2**autoencoder.N)
+    def __init__(self, autoencoder, encoder=None):
+        super().__init__(2 ** autoencoder.N)
         self.autoencoder = autoencoder
+        self.encoder = encoder or autoencoder.encoder
 
-    def forward(self, binary: np.ndarray) -> np.ndarray:
-        reshaped = binary.reshape((-1, self.N))
+    def forward(self, binary: np.ndarray):
+        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
         reshaped_ho = misc.bit_matrix2one_hot(reshaped)
-        encoded = self.autoencoder.encoder(reshaped_ho)
+        encoded = self.encoder(reshaped_ho)
         x = encoded.numpy()
-        x2 = x * 2 - 1
+        if self.autoencoder.bipolar:
+            x = x * 2 - 1
+
+        if self.autoencoder.parallel > 1:
+            x = x.reshape((-1, self.autoencoder.signal_dim))
 
-        f = np.zeros(x2.shape[0])
-        x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
-        return x3
+        f = np.zeros(x.shape[0])
+        if self.autoencoder.signal_dim <= 1:
+            p = np.zeros(x.shape[0])
+        else:
+            p = x[:, 1]
+        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
+        return basic.RFSignal(x3)
 
 
 class AutoencoderDemod(defs.Demodulator):
-    def __init__(self, autoencoder):
-        super().__init__(2**autoencoder.N)
+    def __init__(self, autoencoder, decoder=None):
+        super().__init__(2 ** autoencoder.N)
         self.autoencoder = autoencoder
-
-    def forward(self, values: np.ndarray) -> np.ndarray:
-        rect = misc.polar2rect(values[:, [0, 1]])
-        decoded = self.autoencoder.decoder(rect).numpy()
-        result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
+        self.decoder = decoder or autoencoder.decoder
+
+    def forward(self, values: defs.Signal) -> np.ndarray:
+        if self.autoencoder.signal_dim <= 1:
+            val = values.rect_x
+        else:
+            val = values.rect
+        if self.autoencoder.parallel > 1:
+            val = val.reshape((-1, self.autoencoder.parallel))
+        decoded = self.decoder(val).numpy()
+        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
         return result.reshape(-1, )
 
 
 class Autoencoder(Model):
-    def __init__(self, N, noise):
+    def __init__(self, N, channel,
+                 signal_dim=2,
+                 parallel=1,
+                 all_onehot=True,
+                 bipolar=True,
+                 encoder=None,
+                 decoder=None,
+                 data_generator=None,
+                 cost=None
+                 ):
         super(Autoencoder, self).__init__()
         self.N = N
-        self.encoder = tf.keras.Sequential()
-        self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        # self.encoder.add(layers.Dropout(0.2))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.encoder.add(layers.Dense(units=2, activation="sigmoid"))
-        # self.encoder.add(layers.ReLU(max_value=1.0))
-
-        self.decoder = tf.keras.Sequential()
-        self.decoder.add(tf.keras.Input(shape=(2,)))
-        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
-        # self.decoder.add(layers.Dropout(0.2))
-        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
-
+        self.parallel = parallel
+        self.signal_dim = signal_dim
+        self.bipolar = bipolar
+        self._input_shape = 2 ** (N * parallel) if all_onehot else (2 ** N) * parallel
+        if encoder is None:
+            self.encoder = tf.keras.Sequential()
+            self.encoder.add(layers.Input(shape=(self._input_shape,)))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.encoder.add(layers.Dropout(0.2))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            self.encoder.add(layers.Dense(units=signal_dim * parallel, activation="sigmoid"))
+            # self.encoder.add(layers.ReLU(max_value=1.0))
+            # self.encoder = quantize.quantize_model(self.encoder)
+        else:
+            self.encoder = encoder
+
+        if decoder is None:
+            self.decoder = tf.keras.Sequential()
+            self.decoder.add(tf.keras.Input(shape=(signal_dim * parallel,)))
+            self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # leaky relu with alpha=1 gives by far best results
+            self.decoder.add(LeakyReLU(alpha=1))
+            self.decoder.add(layers.Dense(units=self._input_shape, activation="softmax"))
+        else:
+            self.decoder = decoder
         # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
 
         self.mod = None
         self.demod = None
         self.compiled = False
 
-        # Divide by 2 because encoder outputs values between 0 and 1 instead of -1 and 1
-        self.noise = 10 ** (noise / 10)  # / 2
+        self.channel = tf.keras.Sequential()
+        if self.bipolar:
+            self.channel.add(custom_layers.ScaleAndOffset(2, -1, input_shape=(signal_dim * parallel,)))
 
+        if isinstance(channel, int) or isinstance(channel, float):
+            self.channel.add(custom_layers.AwgnChannel(noise_dB=channel, input_shape=(signal_dim * parallel,)))
+        else:
+            if not isinstance(channel, tf.keras.layers.Layer):
+                raise ValueError("Channel is not a keras layer")
+            self.channel.add(channel)
+
+        self.data_generator = data_generator
+        if data_generator is None:
+            self.data_generator = BinaryOneHotGenerator
+
+        self.cost = cost
+        if cost is None:
+            self.cost = losses.MeanSquaredError()
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
 
         # [
@@ -83,35 +144,87 @@ class Autoencoder(Model):
         #     layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
         #     layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
         # ])
+    @property
+    def all_layers(self):
+        return self.encoder.layers + self.decoder.layers  # channel layers excluded
 
     def call(self, x, **kwargs):
-        encoded = self.encoder(x)
-        encoded = encoded * 2 - 1
-        # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
-        # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
-        noise = np.random.normal(0, 1, (1, 2)) * self.noise
-        noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
-        decoded = self.decoder(encoded + noisy)
-        return decoded
-
-    def train(self, samples=1e6):
-        if samples % self.N:
-            samples += self.N - (samples % self.N)
-        x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
-        x_train_ho = misc.bit_matrix2one_hot(x_train)
-
-        test_samples = samples * 0.3
-        if test_samples % self.N:
-            test_samples += self.N - (test_samples % self.N)
-        x_test_array = misc.generate_random_bit_array(test_samples)
-        x_test = x_test_array.reshape((-1, self.N))
-        x_test_ho = misc.bit_matrix2one_hot(x_test)
+        y = self.encoder(x)
+        z = self.channel(y)
+        return self.decoder(z)
+
+    def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
+        alphabet = basic.load_alphabet(modulation, polar=False)
+
+        if not alphabet.shape[0] == 2 ** self.N:
+            raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
+
+        x_train = np.random.randint(2 ** self.N, size=int(sample_size * train_size))
+        y_train = alphabet[x_train]
+        x_train_ho = np.zeros((int(sample_size * train_size), 2 ** self.N))
+        for idx, x in np.ndenumerate(x_train):
+            x_train_ho[idx, x] = 1
+
+        x_test = np.random.randint(2 ** self.N, size=int(sample_size * (1 - train_size)))
+        y_test = alphabet[x_test]
+        x_test_ho = np.zeros((int(sample_size * (1 - train_size)), 2 ** self.N))
+        for idx, x in np.ndenumerate(x_test):
+            x_test_ho[idx, x] = 1
+
+        self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.encoder.fit(x_train_ho, y_train,
+                         epochs=epochs,
+                         batch_size=batch_size,
+                         shuffle=shuffle,
+                         validation_data=(x_test_ho, y_test))
+
+    def fit_decoder(self, modulation, samples):
+        samples = int(samples * 1.3)
+        demod = basic.AlphabetDemod(modulation, 0)
+        x = np.random.rand(samples, 2) * 2 - 1
+        x = x.reshape((-1, 2))
+        f = np.zeros(x.shape[0])
+        xf = np.c_[x[:, 0], x[:, 1], f]
+        y = demod.forward(basic.RFSignal(misc.rect2polar(xf)))
+        y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
+
+        X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
+        self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
+        y_pred = self.decoder(X_test).numpy()
+        y_pred2 = np.zeros(y_test.shape, dtype=bool)
+        y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
+
+        print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
+
+    def train(self, epoch_size=3e3, epochs=5, callbacks=None, optimizer='adam', metrics=None):
+        m = self.N * self.parallel
+        x_train = self.data_generator(size=epoch_size, shape=m)
+        x_test = self.data_generator(size=epoch_size*.3, shape=m)
+
+        # test_samples = epoch_size
+        # if test_samples % m:
+        #     test_samples += m - (test_samples % m)
+        # x_test_array = misc.generate_random_bit_array(test_samples)
+        # x_test = x_test_array.reshape((-1, m))
+        # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
         if not self.compiled:
-            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
+            self.compile(
+                optimizer=optimizer,
+                loss=self.cost,
+                metrics=metrics
+            )
             self.compiled = True
-
-        self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
+            # self.build((self._input_shape, -1))
+            # self.summary()
+
+        history = self.fit(
+            x_train, shuffle=False,
+            validation_data=x_test, epochs=epochs,
+            callbacks=callbacks,
+        )
+        return history
         # encoded_data = self.encoder(x_test_ho)
         # decoded_data = self.decoder(encoded_data).numpy()
 
@@ -126,12 +239,14 @@ class Autoencoder(Model):
         return self.demod
 
 
-def view_encoder(encoder, N, samples=1000):
-    test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
+def view_encoder(encoder, N, samples=1000, title="Autoencoder generated alphabet"):
+    test_values = misc.generate_random_bit_array(samples*N).reshape((-1, N))
     test_values_ho = misc.bit_matrix2one_hot(test_values)
     mvector = np.array([2 ** i for i in range(N)], dtype=int)
     symbols = (test_values * mvector).sum(axis=1)
     encoded = encoder(test_values_ho).numpy()
+    if encoded.shape[1] == 1:
+        encoded = np.c_[encoded, np.zeros(encoded.shape[0])]
     # encoded = misc.polar2rect(encoded)
     for i in range(2 ** N):
         xy = encoded[symbols == i]
@@ -139,7 +254,7 @@ def view_encoder(encoder, N, samples=1000):
         plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
     plt.xlabel('Real')
     plt.ylabel('Imaginary')
-    plt.title("Autoencoder generated alphabet")
+    plt.title(title)
     # plt.legend()
     plt.show()
 
@@ -157,26 +272,44 @@ if __name__ == '__main__':
 
     n = 4
 
-    samples = 1e6
-    x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
-    x_train_ho = misc.bit_matrix2one_hot(x_train)
-    x_test_array = misc.generate_random_bit_array(samples * 0.3)
-    x_test = x_test_array.reshape((-1, n))
-    x_test_ho = misc.bit_matrix2one_hot(x_test)
+    # samples = 1e6
+    # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
+    # x_train_ho = misc.bit_matrix2one_hot(x_train)
+    # x_test_array = misc.generate_random_bit_array(samples * 0.3)
+    # x_test = x_test_array.reshape((-1, n))
+    # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
-    autoencoder = Autoencoder(n, -8)
+    autoencoder = Autoencoder(n, -15)
     autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
 
-    autoencoder.fit(x_train_ho, x_train_ho,
-                    epochs=1,
-                    shuffle=False,
-                    validation_data=(x_test_ho, x_test_ho))
+    # autoencoder.fit_encoder(modulation='16qam',
+    #                         sample_size=2e6,
+    #                         train_size=0.8,
+    #                         epochs=1,
+    #                         batch_size=256,
+    #                         shuffle=True)
 
-    encoded_data = autoencoder.encoder(x_test_ho)
-    decoded_data = autoencoder.decoder(encoded_data).numpy()
-
-    result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
-    print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
+    # view_encoder(autoencoder.encoder, n)
+    # autoencoder.fit_decoder(modulation='16qam', samples=2e6)
+    autoencoder.train()
     view_encoder(autoencoder.encoder, n)
 
+    # view_encoder(autoencoder.encoder, n)
+    # view_encoder(autoencoder.encoder, n)
+
+
+    # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
+    #
+    # autoencoder.fit(x_train_ho, x_train_ho,
+    #                 epochs=1,
+    #                 shuffle=False,
+    #                 validation_data=(x_test_ho, x_test_ho))
+    #
+    # encoded_data = autoencoder.encoder(x_test_ho)
+    # decoded_data = autoencoder.decoder(encoded_data).numpy()
+    #
+    # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
+    # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
+    # view_encoder(autoencoder.encoder, n)
+
     pass
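
The reworked `Autoencoder` wires up the end-to-end pipeline this diff describes: one-hot symbols into a Dense encoder with LeakyReLU, a channel stage that rescales the sigmoid output to bipolar and adds AWGN, then a Dense decoder ending in a softmax over the symbol set. A condensed functional-API sketch of that topology (layer sizes follow the diff, but `build_autoencoder`, the `Lambda` scaling, and `GaussianNoise` as the AWGN stand-in are illustrative choices, not the repository's `custom_layers`):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

def build_autoencoder(N=2, signal_dim=2, noise_db=-8.0):
    M = 2 ** N  # symbol-set cardinality
    encoder = tf.keras.Sequential([
        layers.Input(shape=(M,)),
        layers.Dense(2 ** (N + 1)),
        layers.LeakyReLU(alpha=0.001),
        layers.Dense(signal_dim, activation="sigmoid"),
    ])
    # Channel: rescale [0, 1] encoder output to bipolar, then add noise.
    # GaussianNoise is only active during training, like an AWGN layer.
    sigma = 10.0 ** (noise_db / 10.0)
    channel = tf.keras.Sequential([
        layers.Lambda(lambda t: t * 2.0 - 1.0),
        layers.GaussianNoise(sigma),
    ])
    decoder = tf.keras.Sequential([
        layers.Input(shape=(signal_dim,)),
        layers.Dense(2 ** (N + 1)),
        layers.LeakyReLU(alpha=1.0),
        layers.Dense(M, activation="softmax"),
    ])
    inp = layers.Input(shape=(M,))
    return tf.keras.Model(inp, decoder(channel(encoder(inp))))

model = build_autoencoder()
x = np.eye(4, dtype=np.float32)       # every one-hot symbol once
y = model(x, training=False).numpy()  # noise layer inactive at inference
```

Training such a model against its own input (as `Autoencoder.train` does via the data generator) pushes the encoder toward a noise-robust constellation, which `view_encoder` then plots.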

+ 51 - 85
models/basic.py

@@ -4,78 +4,15 @@ import math
 import misc
 from scipy.spatial import cKDTree
 from os import path
+import tensorflow as tf
 
-ALPHABET_DIR = "./alphabets"
-# def _make_gray(n):
-#     if n <= 0:
-#         return []
-#     arr = ['0', '1']
-#     i = 2
-#     while True:
-#         if i >= 1 << n:
-#             break
-#         for j in range(i - 1, -1, -1):
-#             arr.append(arr[j])
-#         for j in range(i):
-#             arr[j] = "0" + arr[j]
-#         for j in range(i, 2 * i):
-#             arr[j] = "1" + arr[j]
-#         i = i << 1
-#     return list(map(lambda x: int(x, 2), arr))
-#
-#
-# def _gen_mary_alphabet(size, gray=True, polar=True):
-#     alphabet = np.zeros((size, 2))
-#     N = math.ceil(math.sqrt(size))
-#
-#     # if sqrt(size) != size^2 (not a perfect square),
-#     # skip defines how many corners to cut off.
-#     skip = 0
-#     if N ** 2 > size:
-#         skip = int(math.sqrt((N ** 2 - size) // 4))
-#
-#     step = 2 / (N - 1)
-#     skipped = 0
-#     for x in range(N):
-#         for y in range(N):
-#             i = x * N + y - skipped
-#             if i >= size:
-#                 break
-#             # Reverse y every odd column
-#             if x % 2 == 0 and N < 4:
-#                 y = N - y - 1
-#             if skip > 0:
-#                 if (x < skip or x + 1 > N - skip) and \
-#                         (y < skip or y + 1 > N - skip):
-#                     skipped += 1
-#                     continue
-#             # Exception for 3-ary alphabet, skip centre point
-#             if size == 8 and x == 1 and y == 1:
-#                 skipped += 1
-#                 continue
-#             alphabet[i, :] = [step * x - 1, step * y - 1]
-#     if gray:
-#         shape = alphabet.shape
-#         d1 = 4 if N > 4 else 2 ** N // 4
-#         g1 = np.array([0, 1, 3, 2])
-#         g2 = g1[:d1]
-#         hypershape = (d1, 4, 2)
-#         if N > 4:
-#             hypercube = alphabet.reshape(hypershape + (N-4, ))
-#             hypercube = hypercube[:, g1, :, :][g2, :, :, :]
-#         else:
-#             hypercube = alphabet.reshape(hypershape)
-#             hypercube = hypercube[:, g1, :][g2, :, :]
-#         alphabet = hypercube.reshape(shape)
-#     if polar:
-#         alphabet = misc.rect2polar(alphabet)
-#     return alphabet
+ALPHABET_DIR = path.join(path.dirname(__file__), "../alphabets")
 
 
 def load_alphabet(name, polar=True):
     apath = path.join(ALPHABET_DIR, name + '.a')
     if not path.exists(apath):
-        raise ValueError(f"Alphabet '{name}' does not exist")
+        raise ValueError(f"Alphabet '{name}' was not found at {path.abspath(apath)}")
     data = []
     indexes = []
     with open(apath, 'r') as f:
@@ -102,12 +39,12 @@ def load_alphabet(name, polar=True):
                 else:
                     raise ValueError()
                 if 'd' in header:
-                    p = y*math.pi/180
+                    p = y * math.pi / 180
                     y = math.sin(p) * x
                     x = math.cos(p) * x
                 data.append((x, y))
             except ValueError:
-                raise ValueError(f"Alphabet {name} line {i+1}: '{row}' has invalid values")
+                raise ValueError(f"Alphabet {name} line {i + 1}: '{row}' has invalid values")
 
     data2 = [None] * len(data)
     for i, d in enumerate(data):
@@ -118,6 +55,30 @@ def load_alphabet(name, polar=True):
     return arr
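The `'d'` header branch in `load_alphabet` converts polar entries given in degrees into rectangular coordinates before storing them. A standalone sketch of that conversion (the `misc` helpers themselves are not shown in this diff):

```python
import math

def polar_deg_to_rect(r, angle_deg):
    # mirrors the 'd' branch above: convert degrees to radians,
    # then x = r*cos(p), y = r*sin(p)
    p = angle_deg * math.pi / 180
    return math.cos(p) * r, math.sin(p) * r
```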
 
 
+class RFSignal(defs.Signal):
+    def __init__(self, array: np.ndarray):
+        self.amplitude = array[:, 0]
+        self.phase = array[:, 1]
+        self.frequency = array[:, 2]
+        self.symbols = array.shape[0]
+
+    @property
+    def rect(self) -> np.ndarray:
+        return misc.polar2rect(np.c_[self.amplitude, self.phase])
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        self.set_rect(np.c_[x_mat, y_mat])
+
+    def set_rect(self, mat: np.ndarray):
+        polar = misc.rect2polar(mat)
+        self.amplitude = polar[:, 0]
+        self.phase = polar[:, 1]
+
+    @property
+    def rect_x(self) -> np.ndarray:
+        # x component of the rectangular form; AWGNChannel.forward reads this
+        return self.rect[:, 0]
+
+    @property
+    def rect_y(self) -> np.ndarray:
+        # y component of the rectangular form
+        return self.rect[:, 1]
+
+    @property
+    def apf(self):
+        return np.c_[self.amplitude, self.phase, self.frequency]
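`RFSignal` converts between the stored polar representation (amplitude, phase) and the rectangular form via `misc.polar2rect` / `misc.rect2polar`, which are not shown in this diff. A plausible NumPy sketch of that round trip, under the assumption that the columns are (amplitude, phase) and (x, y) respectively:

```python
import numpy as np

def polar2rect(ap):
    # assumed behaviour of misc.polar2rect: (amplitude, phase) -> (x, y)
    return np.c_[ap[:, 0] * np.cos(ap[:, 1]), ap[:, 0] * np.sin(ap[:, 1])]

def rect2polar(xy):
    # assumed inverse, as RFSignal.set_rect expects
    return np.c_[np.hypot(xy[:, 0], xy[:, 1]), np.arctan2(xy[:, 1], xy[:, 0])]
```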
+
+
 class BypassChannel(defs.Channel):
     def forward(self, values):
         return values
@@ -131,12 +92,17 @@ class AWGNChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-    def forward(self, values):
-        a = np.random.normal(0, 1, values.shape[0]) * self.noise
-        p = np.random.normal(0, 1, values.shape[0]) * self.noise
-        f = np.zeros(values.shape[0])
-        noise_mat = np.c_[a, p, f]
-        return values + noise_mat
+    def forward(self, values: RFSignal) -> RFSignal:
+        values.set_rect_xy(
+            values.rect_x + np.random.normal(0, 1, values.symbols) * self.noise,
+            values.rect_y + np.random.normal(0, 1, values.symbols) * self.noise,
+        )
+        return values
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
+        noise = tf.random.normal([2], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
+        tensor += noise * self.noise
+        return tensor
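`AWGNChannel.__init__` interprets `noise_level` in dB and stores a linear scale factor via `10 ** (noise_level / 10)`. A minimal sketch of that mapping as written (note this is the power-ratio convention; whether an amplitude-ratio `/ 20` was intended is not clear from the diff):

```python
def db_to_linear(noise_level_db):
    # same conversion as AWGNChannel.__init__: self.noise = 10 ** (noise_level / 10)
    return 10 ** (noise_level_db / 10)
```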
 
 
 class BPSKMod(defs.Modulator):
@@ -145,12 +111,12 @@ class BPSKMod(defs.Modulator):
         super().__init__(2, **kwargs)
         self.f = carrier_f
 
-    def forward(self, binary: np.ndarray):
+    def forward(self, binary):
         a = np.ones(binary.shape[0])
         p = np.zeros(binary.shape[0])
         p[binary == True] = np.pi
         f = np.zeros(binary.shape[0]) + self.f
-        return np.c_[a, p, f]
+        return RFSignal(np.c_[a, p, f])
 
 
 class BPSKDemod(defs.Demodulator):
@@ -167,11 +133,11 @@ class BPSKDemod(defs.Demodulator):
     def forward(self, values):
         # TODO: Channel noise simulator for frequency component?
         # for now we only care about amplitude and phase
-        ap = np.delete(values, 2, 1)
-        ap = misc.polar2rect(ap)
+        # ap = np.delete(values, 2, 1)
+        # ap = misc.polar2rect(ap)
 
-        result = np.ones(values.shape[0], dtype=bool)
-        result[ap[:, 0] > 0] = False
+        result = np.ones(values.symbols, dtype=bool)
+        result[values.rect[:, 0] > 0] = False
         return result
 
 
@@ -196,7 +162,7 @@ class AlphabetMod(defs.Modulator):
         a = values[:, 0]
         p = values[:, 1]
         f = np.zeros(reshaped.shape[0]) + self.f
-        return np.c_[a, p, f]  #, indices
+        return RFSignal(np.c_[a, p, f])  # , indices
 
 
 class AlphabetDemod(defs.Demodulator):
@@ -211,11 +177,11 @@ class AlphabetDemod(defs.Demodulator):
         self.ktree = cKDTree(self.alphabet)
 
     def forward(self, binary):
-        binary = binary[:, :2]  # ignore frequency
-        rbin = misc.polar2rect(binary)
-        indices = self.ktree.query(rbin)[1]
+        # binary = binary[:, :2]  # ignore frequency
+        # rbin = misc.polar2rect(binary)
+        indices = self.ktree.query(binary.rect)[1]
 
         # Converting indices to bit array
         # FIXME: unpackbits requires 8bit inputs, thus largest demodulation is 256-QAM
         values = np.unpackbits(np.array([indices], dtype=np.uint8).T, bitorder='little', axis=1)
-        return values[:, :self.N].reshape((-1,)).astype(bool)  #, indices
+        return values[:, :self.N].reshape((-1,)).astype(bool)  # , indices

+ 518 - 0
models/binary_net.py

@@ -0,0 +1,518 @@
+"""
+Adapted from https://github.com/uranusx86/BinaryNet-on-tensorflow
+
+"""
+
+# coding=UTF-8
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.python.framework import tensor_shape, ops
+from tensorflow.python.ops import standard_ops, nn, variable_scope, math_ops, control_flow_ops, state_ops, resource_variable_ops
+from tensorflow.python.eager import context
+from tensorflow.python.training import optimizer, training_ops
+import numpy as np
+
+# Warning: if a class defines a @property getter/setter, it must inherit from object
+
+all_layers = []
+
+
+def hard_sigmoid(x):
+    return tf.clip_by_value((x + 1.) / 2., 0, 1)
+
+
+def round_through(x):
+    """
+    Element-wise rounding to the closest integer with full gradient propagation.
+    A trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182)
+    an op that behaves as f(x) in forward mode,
+    but as g(x) in backward mode.
+    """
+    rounded = tf.round(x)
+    return x + tf.stop_gradient(rounded - x)
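The stop-gradient trick above builds an op whose forward value is `round(x)` but whose gradient is that of the identity (the straight-through estimator). A framework-free sketch of the resulting forward/backward pair (hypothetical helper names, no autodiff involved):

```python
import numpy as np

def round_ste_forward(x):
    # value of x + stop_gradient(round(x) - x) is simply round(x)
    return np.round(x)

def round_ste_backward(grad_out):
    # the stop_gradient term contributes zero gradient, so d(out)/dx = 1
    # and the incoming gradient passes straight through
    return grad_out
```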
+
+
+# The neurons' activations binarization function
+# It behaves like the sign function during forward propagation
+# And like:
+#   hard_tanh(x) = 2*hard_sigmoid(x)-1
+# during back propagation
+def binary_tanh_unit(x):
+    return 2. * round_through(hard_sigmoid(x)) - 1.
+
+
+def binary_sigmoid_unit(x):
+    return round_through(hard_sigmoid(x))
+
+
+# The weights' binarization function,
+# taken directly from the BinaryConnect github repository
+# (which was made available by its authors)
+def binarization(W, H, binary=True, deterministic=False, stochastic=False, srng=None):
+    dim = W.get_shape().as_list()
+
+    # (deterministic == True) <-> test-time <-> inference-time
+    if not binary or (deterministic and stochastic):
+        # print("not binary")
+        Wb = W
+
+    else:
+        # [-1,1] -> [0,1]
+        # Wb = hard_sigmoid(W/H)
+        # Wb = T.clip(W/H,-1,1)
+
+        # Stochastic BinaryConnect
+        '''
+        if stochastic:
+            # print("stoch")
+            Wb = tf.cast(srng.binomial(n=1, p=Wb, size=tf.shape(Wb)), tf.float32)
+        '''
+
+        # Deterministic BinaryConnect (round to nearest)
+        # else:
+        # print("det")
+        # Wb = tf.round(Wb)
+
+        # 0 or 1 -> -1 or 1
+        # Wb = tf.where(tf.equal(Wb, 1.0), tf.ones_like(W), -tf.ones_like(W))  # not differentiable
+        Wb = H * binary_tanh_unit(W / H)
+
+    return Wb
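In the deterministic branch, the forward value of `H * binary_tanh_unit(W / H)` snaps every weight to ±H. A NumPy sketch of just that forward value (the straight-through gradient behaviour is omitted):

```python
import numpy as np

def hard_sigmoid_np(x):
    # clip((x + 1) / 2, 0, 1), as in hard_sigmoid above
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)

def binarize_deterministic(W, H=1.0):
    # forward value of H * binary_tanh_unit(W / H): every entry becomes -H or +H
    return H * (2.0 * np.round(hard_sigmoid_np(W / H)) - 1.0)
```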
+
+
+class DenseBinaryLayer(keras.layers.Dense):
+    def __init__(self, output_dim,
+                 activation=None,
+                 use_bias=True,
+                 binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
+                 kernel_initializer=tf.glorot_normal_initializer(),
+                 bias_initializer=tf.zeros_initializer(),
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 activity_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 trainable=True,
+                 name=None,
+                 **kwargs):
+        super(DenseBinaryLayer, self).__init__(
+            units=output_dim,
+            activation=activation,
+            use_bias=use_bias,
+            kernel_initializer=kernel_initializer,
+            bias_initializer=bias_initializer,
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer,
+            activity_regularizer=activity_regularizer,
+            kernel_constraint=kernel_constraint,
+            bias_constraint=bias_constraint,
+            trainable=trainable,
+            name=name,
+            **kwargs
+        )
+
+        self.binary = binary
+        self.stochastic = stochastic
+
+        self.H = H
+        self.W_LR_scale = W_LR_scale
+
+        all_layers.append(self)
+
+    def build(self, input_shape):
+        num_inputs = tensor_shape.TensorShape(input_shape).as_list()[-1]
+        num_units = self.units
+        print(num_units)
+
+        if self.H == "Glorot":
+            self.H = np.float32(np.sqrt(1.5 / (num_inputs + num_units)))  # weight init method
+        self.W_LR_scale = np.float32(1. / np.sqrt(1.5 / (num_inputs + num_units)))  # each layer learning rate
+        print("H = ", self.H)
+        print("LR scale = ", self.W_LR_scale)
+
+        self.kernel_initializer = tf.random_uniform_initializer(-self.H, self.H)
+        self.kernel_constraint = lambda w: tf.clip_by_value(w, -self.H, self.H)
+
+        '''
+        self.b_kernel = self.add_variable('binary_weight',
+                                    shape=[input_shape[-1], self.units],
+                                    initializer=self.kernel_initializer,
+                                    regularizer=None,
+                                    constraint=None,
+                                    dtype=self.dtype,
+                                    trainable=False)  # add_variable must execute before call build()
+        '''
+        self.b_kernel = self.add_variable('binary_weight',
+                                          shape=[input_shape[-1], self.units],
+                                          initializer=tf.random_uniform_initializer(-self.H, self.H),
+                                          regularizer=None,
+                                          constraint=None,
+                                          dtype=self.dtype,
+                                          trainable=False)
+
+        super(DenseBinaryLayer, self).build(input_shape)
+
+        # tf.add_to_collection('real', self.trainable_variables)
+        # tf.add_to_collection(self.name + '_binary', self.kernel)  # layer-wise group
+        # tf.add_to_collection('binary', self.kernel)  # global group
+
+    def call(self, inputs):
+        inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
+        shape = inputs.get_shape().as_list()
+
+        # binarization weight
+        self.b_kernel = binarization(self.kernel, self.H)
+        # r_kernel = self.kernel
+        # self.kernel = self.b_kernel
+
+        print("shape: ", len(shape))
+        if len(shape) > 2:
+            # Broadcasting is required for the inputs.
+            outputs = standard_ops.tensordot(inputs, self.b_kernel, [[len(shape) - 1], [0]])
+            # Reshape the output back to the original ndim of the input.
+            if context.in_graph_mode():
+                output_shape = shape[:-1] + [self.units]
+                outputs.set_shape(output_shape)
+        else:
+            outputs = standard_ops.matmul(inputs, self.b_kernel)
+
+        # restore weight
+        # self.kernel = r_kernel
+
+        if self.use_bias:
+            outputs = nn.bias_add(outputs, self.bias)
+        if self.activation is not None:
+            return self.activation(outputs)
+        return outputs
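The `len(shape) > 2` branch above uses `tensordot` to contract only the last axis of the input with the kernel, so time-distributed inputs keep their leading dimensions. A small NumPy illustration of the same contraction:

```python
import numpy as np

x = np.random.randn(4, 3, 5)               # (batch, time, features)
w = np.random.randn(5, 2)                  # (features, units)
out = np.tensordot(x, w, axes=[[2], [0]])  # contract last axis of x with first of w
assert out.shape == (4, 3, 2)              # leading (batch, time) axes preserved
```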
+
+
+# Functional interface for the Dense_BinaryLayer class.
+def dense_binary(
+        inputs, units,
+        activation=None,
+        use_bias=True,
+        binary=True, stochastic=True, H=1., W_LR_scale="Glorot",
+        kernel_initializer=tf.glorot_normal_initializer(),
+        bias_initializer=tf.zeros_initializer(),
+        kernel_regularizer=None,
+        bias_regularizer=None,
+        activity_regularizer=None,
+        kernel_constraint=None,
+        bias_constraint=None,
+        trainable=True,
+        name=None,
+        reuse=None):
+    layer = DenseBinaryLayer(units,
+                             activation=activation,
+                             use_bias=use_bias,
+                             binary=binary, stochastic=stochastic, H=H, W_LR_scale=W_LR_scale,
+                             kernel_initializer=kernel_initializer,
+                             bias_initializer=bias_initializer,
+                             kernel_regularizer=kernel_regularizer,
+                             bias_regularizer=bias_regularizer,
+                             activity_regularizer=activity_regularizer,
+                             kernel_constraint=kernel_constraint,
+                             bias_constraint=bias_constraint,
+                             trainable=trainable,
+                             name=name,
+                             dtype=inputs.dtype.base_dtype,
+                             _scope=name,
+                             _reuse=reuse)
+    return layer.apply(inputs)
+
+
+# Not yet binarized
+class BatchNormalization(keras.layers.BatchNormalization):
+    def __init__(self,
+                 axis=-1,
+                 momentum=0.99,
+                 epsilon=1e-3,
+                 center=True,
+                 scale=True,
+                 beta_initializer=tf.zeros_initializer(),
+                 gamma_initializer=tf.ones_initializer(),
+                 moving_mean_initializer=tf.zeros_initializer(),
+                 moving_variance_initializer=tf.ones_initializer(),
+                 beta_regularizer=None,
+                 gamma_regularizer=None,
+                 beta_constraint=None,
+                 gamma_constraint=None,
+                 renorm=False,
+                 renorm_clipping=None,
+                 renorm_momentum=0.99,
+                 fused=None,
+                 trainable=True,
+                 name=None,
+                 **kwargs):
+        super(BatchNormalization, self).__init__(
+            axis=axis,
+            momentum=momentum,
+            epsilon=epsilon,
+            center=center,
+            scale=scale,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            moving_mean_initializer=moving_mean_initializer,
+            moving_variance_initializer=moving_variance_initializer,
+            beta_regularizer=beta_regularizer,
+            gamma_regularizer=gamma_regularizer,
+            beta_constraint=beta_constraint,
+            gamma_constraint=gamma_constraint,
+            renorm=renorm,
+            renorm_clipping=renorm_clipping,
+            renorm_momentum=renorm_momentum,
+            fused=fused,
+            trainable=trainable,
+            name=name,
+            **kwargs)
+        # all_layers.append(self)
+
+    def build(self, input_shape):
+        super(BatchNormalization, self).build(input_shape)
+        self.W_LR_scale = np.float32(1.)
+
+
+# Functional interface for the batch normalization layer.
+def batch_normalization(
+        inputs,
+        axis=-1,
+        momentum=0.99,
+        epsilon=1e-3,
+        center=True,
+        scale=True,
+        beta_initializer=tf.zeros_initializer(),
+        gamma_initializer=tf.ones_initializer(),
+        moving_mean_initializer=tf.zeros_initializer(),
+        moving_variance_initializer=tf.ones_initializer(),
+        beta_regularizer=None,
+        gamma_regularizer=None,
+        beta_constraint=None,
+        gamma_constraint=None,
+        training=False,
+        trainable=True,
+        name=None,
+        reuse=None,
+        renorm=False,
+        renorm_clipping=None,
+        renorm_momentum=0.99,
+        fused=None):
+    layer = BatchNormalization(
+        axis=axis,
+        momentum=momentum,
+        epsilon=epsilon,
+        center=center,
+        scale=scale,
+        beta_initializer=beta_initializer,
+        gamma_initializer=gamma_initializer,
+        moving_mean_initializer=moving_mean_initializer,
+        moving_variance_initializer=moving_variance_initializer,
+        beta_regularizer=beta_regularizer,
+        gamma_regularizer=gamma_regularizer,
+        beta_constraint=beta_constraint,
+        gamma_constraint=gamma_constraint,
+        renorm=renorm,
+        renorm_clipping=renorm_clipping,
+        renorm_momentum=renorm_momentum,
+        fused=fused,
+        trainable=trainable,
+        name=name,
+        dtype=inputs.dtype.base_dtype,
+        _reuse=reuse,
+        _scope=name
+    )
+    return layer.apply(inputs, training=training)
+
+
+class AdamOptimizer(optimizer.Optimizer):
+    """Optimizer that implements the Adam algorithm.
+    See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+    ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+    """
+
+    def __init__(self, weight_scale, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+                 use_locking=False, name="Adam"):
+        super(AdamOptimizer, self).__init__(use_locking, name)
+        self._lr = learning_rate
+        self._beta1 = beta1
+        self._beta2 = beta2
+        self._epsilon = epsilon
+
+        # BNN weight scale factor
+        self._weight_scale = weight_scale
+
+        # Tensor versions of the constructor arguments, created in _prepare().
+        self._lr_t = None
+        self._beta1_t = None
+        self._beta2_t = None
+        self._epsilon_t = None
+
+        # Variables to accumulate the powers of the beta parameters.
+        # Created in _create_slots when we know the variables to optimize.
+        self._beta1_power = None
+        self._beta2_power = None
+
+        # Created in SparseApply if needed.
+        self._updated_lr = None
+
+    def _get_beta_accumulators(self):
+        return self._beta1_power, self._beta2_power
+
+    def _non_slot_variables(self):
+        return self._get_beta_accumulators()
+
+    def _create_slots(self, var_list):
+        first_var = min(var_list, key=lambda x: x.name)
+
+        create_new = self._beta1_power is None
+        if not create_new and context.in_graph_mode():
+            create_new = (self._beta1_power.graph is not first_var.graph)
+
+        if create_new:
+            with ops.colocate_with(first_var):
+                self._beta1_power = variable_scope.variable(self._beta1,
+                                                            name="beta1_power",
+                                                            trainable=False)
+                self._beta2_power = variable_scope.variable(self._beta2,
+                                                            name="beta2_power",
+                                                            trainable=False)
+        # Create slots for the first and second moments.
+        for v in var_list:
+            self._zeros_slot(v, "m", self._name)
+            self._zeros_slot(v, "v", self._name)
+
+    def _prepare(self):
+        self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
+        self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
+        self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
+        self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+
+    def _apply_dense(self, grad, var):
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+
+        # for BNN kernel
+        # the original version's weight clipping is new_w = old_w + scale*(new_w - old_w)
+        # and the Adam update is new_w = old_w - lr_t * m_t / (sqrt(v_t) + epsilon)
+        # so substituting the Adam update into the weight clipping gives
+        # new_w = old_w - (scale * lr_t * m_t) / (sqrt(v_t) + epsilon)
+        scale = self._weight_scale[var.name] / 4
+
+        return training_ops.apply_adam(
+            var, m, v,
+            math_ops.cast(self._beta1_power, var.dtype.base_dtype),
+            math_ops.cast(self._beta2_power, var.dtype.base_dtype),
+            math_ops.cast(self._lr_t * scale, var.dtype.base_dtype),
+            math_ops.cast(self._beta1_t, var.dtype.base_dtype),
+            math_ops.cast(self._beta2_t, var.dtype.base_dtype),
+            math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
+            grad, use_locking=self._use_locking).op
+
+    def _resource_apply_dense(self, grad, var):
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+
+        return training_ops.resource_apply_adam(
+            var.handle, m.handle, v.handle,
+            math_ops.cast(self._beta1_power, grad.dtype.base_dtype),
+            math_ops.cast(self._beta2_power, grad.dtype.base_dtype),
+            math_ops.cast(self._lr_t, grad.dtype.base_dtype),
+            math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
+            math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
+            math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
+            grad, use_locking=self._use_locking)
+
+    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
+        beta1_power = math_ops.cast(self._beta1_power, var.dtype.base_dtype)
+        beta2_power = math_ops.cast(self._beta2_power, var.dtype.base_dtype)
+        lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+        beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+        beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+        epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+        lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+        # m_t = beta1 * m + (1 - beta1) * g_t
+        m = self.get_slot(var, "m")
+        m_scaled_g_values = grad * (1 - beta1_t)
+        m_t = state_ops.assign(m, m * beta1_t,
+                               use_locking=self._use_locking)
+        with ops.control_dependencies([m_t]):
+            m_t = scatter_add(m, indices, m_scaled_g_values)
+        # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+        v = self.get_slot(var, "v")
+        v_scaled_g_values = (grad * grad) * (1 - beta2_t)
+        v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
+        with ops.control_dependencies([v_t]):
+            v_t = scatter_add(v, indices, v_scaled_g_values)
+        v_sqrt = math_ops.sqrt(v_t)
+        var_update = state_ops.assign_sub(var,
+                                          lr * m_t / (v_sqrt + epsilon_t),
+                                          use_locking=self._use_locking)
+        return control_flow_ops.group(*[var_update, m_t, v_t])
+
+    def _apply_sparse(self, grad, var):
+        return self._apply_sparse_shared(
+            grad.values, var, grad.indices,
+            lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
+                x, i, v, use_locking=self._use_locking))
+
+    def _resource_scatter_add(self, x, i, v):
+        with ops.control_dependencies(
+                [resource_variable_ops.resource_scatter_add(
+                    x.handle, i, v)]):
+            return x.value()
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        return self._apply_sparse_shared(
+            grad, var, indices, self._resource_scatter_add)
+
+    def _finish(self, update_ops, name_scope):
+        # Update the power accumulators.
+        with ops.control_dependencies(update_ops):
+            with ops.colocate_with(self._beta1_power):
+                update_beta1 = self._beta1_power.assign(
+                    self._beta1_power * self._beta1_t,
+                    use_locking=self._use_locking)
+                update_beta2 = self._beta2_power.assign(
+                    self._beta2_power * self._beta2_t,
+                    use_locking=self._use_locking)
+        return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
+                                      name=name_scope)
+
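The sparse path in `_apply_sparse_shared` implements the standard bias-corrected Adam recurrence. A dense NumPy sketch of one step of the same update, with the bias correction folded into the learning rate as in the code above:

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    # first and second moment estimates
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    # bias-corrected learning rate, matching lr_t in _apply_sparse_shared
    lr_t = lr * np.sqrt(1 - b2 ** t) / (1 - b1 ** t)
    w = w - lr_t * m / (np.sqrt(v) + eps)
    return w, m, v
```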
+
+def get_all_layers():
+    return all_layers
+
+
+def get_all_LR_scale():
+    return {layer.kernel.name: layer.W_LR_scale for layer in get_all_layers()}
+
+
+# This function computes the gradient of the binary weights
+def compute_grads(loss, opt):
+    layers = get_all_layers()
+    grads_list = []
+    update_weights = []
+
+    for layer in layers:
+
+        # refer to self.params[self.W]=set(['binary'])
+        # The list can optionally be filtered by specifying tags as keyword arguments.
+        # For example,
+        # ``trainable=True`` will only return trainable parameters, and
+        # ``regularizable=True`` will only return parameters that can be regularized
+        # function return, e.g. [W, b] for dense layer
+        params = tf.get_collection(layer.name + "_binary")
+        if params:
+            # print(params[0].name)
+            # theano.grad(cost, wrt) -> d(cost)/d(wrt)
+            # wrt – with respect to which we want gradients
+            # http://blog.csdn.net/shouhuxianjian/article/details/46517143
+            # http://blog.csdn.net/qq_33232071/article/details/52806630
+            # grad = opt.compute_gradients(loss, layer.b_kernel)  # origin version
+            grad = opt.compute_gradients(loss, params[0])  # modify
+            print("grad: ", grad)
+            grads_list.append(grad[0][0])
+            update_weights.extend(params)
+
+    print(grads_list)
+    print(update_weights)
+    return zip(grads_list, update_weights)

+ 87 - 0
models/data.py

@@ -0,0 +1,87 @@
+import tensorflow as tf
+from tensorflow.keras.utils import Sequence
+from sklearn.preprocessing import OneHotEncoder
+
+import numpy as np
+
+# This creates pool of cpu resources
+# physical_devices = tf.config.experimental.list_physical_devices("CPU")
+# tf.config.experimental.set_virtual_device_configuration(
+#     physical_devices[0], [
+#         tf.config.experimental.VirtualDeviceConfiguration(),
+#         tf.config.experimental.VirtualDeviceConfiguration()
+#     ])
+import misc
+
+
+class BinaryOneHotGenerator(Sequence):
+    def __init__(self, size=1e5, shape=2):
+        size = int(size)
+        if size % shape:
+            size += shape - (size % shape)
+        self.size = size
+        self.shape = shape
+        self.x = None
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
+        self.x = misc.bit_matrix2one_hot(x_train)
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        return self.x, self.x
+
+
+class BinaryTimeDistributedOneHotGenerator(Sequence):
+    def __init__(self, size=1e5, cardinality=32, blocks=9):
+        self.size = int(size)
+        self.cardinality = cardinality
+        self.x = None
+        self.encoder = OneHotEncoder(
+            handle_unknown='ignore',
+            sparse=False,
+            categories=[np.arange(self.cardinality)]
+        )
+        self.middle = int((blocks - 1) / 2)
+        self.blocks = blocks
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        rand_int = np.random.randint(self.cardinality, size=(self.size * self.blocks, 1))
+        out = self.encoder.fit_transform(rand_int)
+        self.x = np.reshape(out, (self.size, self.blocks, self.cardinality))
+
+    def __len__(self):
+        return self.size
+
+    @property
+    def y(self):
+        return self.x[:, self.middle, :]
+
+    def __getitem__(self, idx):
+        return self.x, self.y
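`BinaryTimeDistributedOneHotGenerator.on_epoch_end` draws random symbol indices, one-hot encodes them, and reshapes to `(size, blocks, cardinality)`. An equivalent NumPy-only sketch of that step (avoiding the sklearn encoder; `one_hot_blocks` is a hypothetical name):

```python
import numpy as np

def one_hot_blocks(size, blocks, cardinality, seed=0):
    # random symbols -> one-hot rows -> (size, blocks, cardinality),
    # mirroring on_epoch_end above
    rng = np.random.default_rng(seed)
    idx = rng.integers(cardinality, size=size * blocks)
    return np.eye(cardinality)[idx].reshape(size, blocks, cardinality)
```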
+
+
+class BinaryGenerator(Sequence):
+    def __init__(self, size=1e5, shape=2, dtype=tf.bool):
+        size = int(size)
+        if size % shape:
+            size += shape - (size % shape)
+        self.size = size
+        self.shape = shape
+        self.x = None
+        self.dtype = dtype
+        self.on_epoch_end()
+
+    def on_epoch_end(self):
+        x = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
+        self.x = tf.convert_to_tensor(x, dtype=self.dtype)
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        return self.x, self.x

+ 617 - 0
models/end_to_end.py

@@ -0,0 +1,617 @@
+import json
+import math
+import os
+from datetime import datetime as dt
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
+from sklearn.preprocessing import OneHotEncoder
+from tensorflow.keras import layers, losses
+
+from models.data import BinaryTimeDistributedOneHotGenerator
+from models.layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
+import tensorflow_model_optimization as tfmot
+
+import graphs
+
+
+class EndToEndAutoencoder(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel,
+                 custom_loss_fn=False,
+                 quantize=False,
+                 alpha=1):
+        """
+        The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block" consists
+        of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
+        interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
+
+        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        :param messages_per_block: Total number of messages in transmission block
+        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
+        :param alpha: Alpha value for in loss function
+        """
+        super(EndToEndAutoencoder, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+        self.alpha = alpha
+
+        # Labelled N in paper - conditional +=1 to ensure odd value
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        # Channel Model Layer
+        if isinstance(channel, layers.Layer):
+            self.channel = tf.keras.Sequential([
+                layers.Flatten(),
+                channel,
+                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+            ], name="channel_model")
+        else:
+            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
+
+        # Boolean identifying if bit mapping is to be learnt
+        self.custom_loss_fn = custom_loss_fn
+
+        # other parameters/metrics
+        self.symbol_error_rate = None
+        self.bit_error_rate = None
+        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
+
+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
+
+        # Encoding Neural Network
+        self.encoder = tf.keras.Sequential([
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+        ], name="encoding_model")
+
+        # Decoding Neural Network
+        self.decoder = tf.keras.Sequential([
+            layers.Dense(2 * self.cardinality),
+            layers.ReLU(),
+            layers.Dense(2 * self.cardinality),
+            layers.ReLU(),
+            layers.Dense(self.cardinality, activation='softmax')
+        ], name="decoding_model")
+        self.decoder.build((1, self.samples_per_symbol))
+
+    def save_end_to_end(self, name):
+        # extract all params and save
+
+        params = {"fs": self.channel.layers[1].fs,
+                  "cardinality": self.cardinality,
+                  "samples_per_symbol": self.samples_per_symbol,
+                  "messages_per_block": self.messages_per_block,
+                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
+                  "fiber_length": float(self.channel.layers[1].fiber_length),
+                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
+                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
+                  "rx_stddev": self.channel.layers[1].rx_stddev,
+                  "sig_avg": self.channel.layers[1].sig_avg,
+                  "enob": self.channel.layers[1].enob,
+                  "custom_loss_fn": self.custom_loss_fn
+                  }
+
+        if not name:
+            name = dt.utcnow().strftime("%Y%m%d-%H%M%S")
+        dir_str = os.path.join("exports", name)
+
+        if not os.path.exists(dir_str):
+            os.makedirs(dir_str)
+
+        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
+            json.dump(params, outfile)
+
+        ################################################################################################################
+        # This section exports the weights of the encoder formatted using python variable instantiation syntax
+        ################################################################################################################
+
+        enc_weights, dec_weights = self.extract_weights()
+
+        enc_weights = [x.tolist() for x in enc_weights]
+        dec_weights = [x.tolist() for x in dec_weights]
+
+        enc_w = enc_weights[::2]
+        enc_b = enc_weights[1::2]
+
+        dec_w = dec_weights[::2]
+        dec_b = dec_weights[1::2]
+
+        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
+            outfile.write("enc_weights = ")
+            outfile.write(str(enc_w))
+            outfile.write("\n\nenc_bias = ")
+            outfile.write(str(enc_b))
+
+        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
+            outfile.write("dec_weights = ")
+            outfile.write(str(dec_w))
+            outfile.write("\n\ndec_bias = ")
+            outfile.write(str(dec_b))
+
+        ################################################################################################################
+
+        self.encoder.save(os.path.join(dir_str, 'encoder'))
+        self.decoder.save(os.path.join(dir_str, 'decoder'))
+
+    def extract_weights(self):
+        enc_weights = self.encoder.get_weights()
+        dec_weights = self.decoder.get_weights()
+
+        return enc_weights, dec_weights
+
+    def encode_stream(self, x):
+        enc_weights, dec_weights = self.extract_weights()
+
+        for i in range(len(enc_weights) // 2):
+            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
+
+            if i == len(enc_weights) // 2 - 1:
+                x = tf.keras.activations.sigmoid(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
+    def cost(self, y_true, y_pred):
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+        return symbol_cost + self.alpha * bit_cost
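The `cost` method above combines categorical cross-entropy over the one-hot symbols with binary cross-entropy over the corresponding bit labels, weighted by `alpha`. A minimal numpy sketch of the same weighted combination (the helper name is illustrative, not part of the codebase):

```python
import numpy as np

def combined_cost(sym_true, sym_pred, bits_true, bits_pred, alpha=1.0, eps=1e-12):
    # Categorical cross-entropy over one-hot symbol labels
    symbol_cost = -np.mean(np.sum(sym_true * np.log(sym_pred + eps), axis=-1))
    # Binary cross-entropy over per-bit soft estimates
    bit_cost = -np.mean(bits_true * np.log(bits_pred + eps)
                        + (1 - bits_true) * np.log(1 - bits_pred + eps))
    # Weighted sum, as in EndToEndAutoencoder.cost
    return symbol_cost + alpha * bit_cost
```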
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+        A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
+
+        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted
+        consecutively to model ISI. The central message in a block is returned as the label for training.
+        :param return_vals: If true, the raw decimal values of the input sequence will be returned
+        """
+
+        cat = [np.arange(self.cardinality)]
+        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
+        out = enc.fit_transform(rand_int)
+
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
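The same input generation can be sketched without sklearn by using an identity-matrix lookup for the one-hot encoding (names here are illustrative, not from the codebase):

```python
import numpy as np

def random_onehot_blocks(num_blocks, messages_per_block, cardinality, rng=None):
    # Draw random message indices and one-hot encode them by indexing
    # into an identity matrix (equivalent to the OneHotEncoder above).
    rng = rng or np.random.default_rng(0)
    idx = rng.integers(cardinality, size=(num_blocks, messages_per_block))
    blocks = np.eye(cardinality)[idx]        # (blocks, messages, cardinality)
    mid = (messages_per_block - 1) // 2
    return blocks, blocks[:, mid, :]         # inputs and central-message labels
```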
+
+    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3, **kwargs):
+        """
+        Method to train the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
+
+        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
+        :param epochs: Number of complete passes over the training data
+        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
+        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
+        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
+        """
+        # X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        train_data = BinaryTimeDistributedOneHotGenerator(
+            num_of_blocks, cardinality=self.cardinality, blocks=self.messages_per_block)
+        test_data = BinaryTimeDistributedOneHotGenerator(
+            num_of_blocks * .3, cardinality=self.cardinality, blocks=self.messages_per_block)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        if self.custom_loss_fn:
+            loss_fn = self.cost
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
+        self.compile(optimizer=opt,
+                     loss=loss_fn,
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        return self.fit(
+            train_data,
+            epochs=epochs,
+            shuffle=True,
+            validation_data=test_data,
+            **kwargs
+        )
+
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True, distance=None):
+        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+        test_data = BinaryTimeDistributedOneHotGenerator(
+            1000, cardinality=self.cardinality, blocks=self.messages_per_block)
+
+        num_of_blocks = int(num_of_blocks / 1000)
+        if num_of_blocks <= 0:
+            num_of_blocks = 1
+
+        ber = []
+        ser = []
+
+        for i in range(num_of_blocks):
+            y_out = self.call(test_data.x)
+
+            y_pred = tf.argmax(y_out, axis=1)
+            y_true = tf.argmax(test_data.y, axis=1)
+            ser.append(1 - accuracy_score(y_true, y_pred))
+
+            bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+            bits_true = SymbolsToBits(self.cardinality)(test_data.y).numpy().flatten()
+            ber.append(1 - accuracy_score(bits_true, bits_pred))
+            test_data.on_epoch_end()
+            print(f"\rTested {i + 1} of {num_of_blocks} blocks", end="")
+
+        print(f"\rTested all {num_of_blocks} blocks")
+        self.symbol_error_rate = sum(ser) / len(ser)
+        self.bit_error_rate = sum(ber) / len(ber)
+
+        if length_plot:
+
+            lengths = np.linspace(0, 70, 50)
+
+            ber_l = []
+
+            for l in lengths:
+                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
+                                            num_of_samples=self.channel.layers[1].num_of_samples,
+                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
+                                            fiber_length=l,
+                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
+                                            rx_stddev=self.channel.layers[1].rx_stddev,
+                                            sig_avg=self.channel.layers[1].sig_avg,
+                                            enob=self.channel.layers[1].enob)
+
+                test_channel = tf.keras.Sequential([
+                    layers.Flatten(),
+                    tx_channel,
+                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+                ], name="test_channel_variable_length")
+
+                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
+                encoded = self.encoder(X_test_l)
+                after_ch = test_channel(encoded)
+                y_out_l = self.decoder(after_ch)
+
+                y_pred_l = tf.argmax(y_out_l, axis=1)
+                # y_true_l = tf.argmax(y_test_l, axis=1)
+
+                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
+                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
+
+                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
+                ber_l.append(bit_error_rate_l)
+
+            plt.plot(lengths, ber_l)
+            plt.yscale('log')
+            if plt_show:
+                plt.show()
+
+        print("SYMBOL ERROR RATE: {:e}".format(self.symbol_error_rate))
+        print("BIT ERROR RATE: {:e}".format(self.bit_error_rate))
+        return self.symbol_error_rate, self.bit_error_rate
+
+    def view_encoder(self):
+        '''
+        A method that plots the learnt encoding for each distinct message. This is displayed as a figure with one
+        subplot per message/symbol.
+        '''
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1
+
+        # Pass input through encoder and select middle messages
+        encoded = self.encoder(messages)
+        enc_messages = encoded[:, mid_idx, :]
+
+        # Compute subplot grid layout
+        i = 0
+        while 2 ** i < self.cardinality ** 0.5:
+            i += 1
+
+        num_x = int(2 ** i)
+        num_y = int(self.cardinality / num_x)
+
+        # Plot all symbols
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
+
+        t = np.arange(self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t / self.channel.layers[1].fs
+
+        sym_idx = 0
+        for y in range(num_y):
+            for x in range(num_x):
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
+                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
+                sym_idx += 1
+
+        for ax in axs.flat:
+            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
+
+        for ax in axs.flat:
+            ax.label_outer()
+
+        plt.show()
+        pass
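The grid-layout loop in `view_encoder` picks the smallest power of two that covers the square root of the cardinality as the number of columns. Extracted as a standalone helper for clarity (hypothetical, not in the module):

```python
import math

def subplot_grid(cardinality):
    # Smallest power of two >= sqrt(cardinality) gives the column count;
    # rows then fill out the remaining symbols.
    i = 0
    while 2 ** i < math.sqrt(cardinality):
        i += 1
    num_x = 2 ** i
    num_y = cardinality // num_x
    return num_x, num_y
```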
+
+    def view_sample_block(self):
+        '''
+        Generates a random string of input messages and encodes them. In addition, the encoded output is passed
+        through a digitization layer with the quantization noise disabled, to visualize the effect of the low-pass
+        filtering.
+        '''
+        # Generate a random block of messages
+        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+        # Encode and flatten the messages
+        enc = self.encoder(inp)
+        flat_enc = layers.Flatten()(enc)
+        chan_out = self.channel.layers[1](flat_enc)
+
+        # Instantiate LPF layer
+        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
+                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
+                                sig_avg=0)
+
+        # Apply LPF
+        lpf_out = lpf(flat_enc)
+
+        a = np.fft.fft(lpf_out.numpy()).flatten()
+        f = np.fft.fftfreq(a.shape[-1]).flatten()
+
+        plt.plot(f, a)
+        plt.show()
+
+        # Time axis
+        t = np.arange(self.messages_per_block * self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t / self.channel.layers[1].fs
+
+        # Plot the concatenated symbols before and after LPF
+        plt.figure(figsize=(2 * self.messages_per_block, 6))
+
+        for i in range(1, self.messages_per_block):
+            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
+
+        plt.plot(t, flat_enc.numpy().T, 'x')
+        plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
+        plt.ylim((0, 1))
+        plt.xlim((t.min(), t.max()))
+        plt.title(str(val[0, :, 0]))
+        plt.show()
+
+    def call(self, inputs, training=None, mask=None):
+        tx = self.encoder(inputs)
+        rx = self.channel(tx)
+        outputs = self.decoder(rx)
+        return outputs
+
+
+def load_model(model_name=None):
+    if model_name is None:
+        models = os.listdir("exports")
+        if not models:
+            raise Exception("Unable to find a trained model. Please first train and save a model.")
+        model_name = models[-1]
+
+    param_file_path = os.path.join("exports", model_name, "params.json")
+
+    if not os.path.isfile(param_file_path):
+        raise Exception("Invalid File Name/Directory")
+    else:
+        with open(param_file_path, 'r') as param_file:
+            params = json.load(param_file)
+
+    optical_channel = OpticalChannel(fs=params["fs"],
+                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                     dispersion_factor=params["dispersion_factor"],
+                                     fiber_length=params["fiber_length"],
+                                     fiber_length_stddev=params["fiber_length_stddev"],
+                                     lpf_cutoff=params["lpf_cutoff"],
+                                     rx_stddev=params["rx_stddev"],
+                                     sig_avg=params["sig_avg"],
+                                     enob=params["enob"])
+
+    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                   samples_per_symbol=params["samples_per_symbol"],
+                                   messages_per_block=params["messages_per_block"],
+                                   channel=optical_channel,
+                                   custom_loss_fn=params["custom_loss_fn"])
+
+    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
+    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
+
+    return ae_model, params
+
+
+def run_tests(distance=50):
+    params = {
+        "fs": 336e9,
+        "cardinality": 64,
+        "samples_per_symbol": 48,
+        "messages_per_block": 9,
+        "dispersion_factor": (-21.7 * 1e-24),
+        "fiber_length": 50,
+        "fiber_length_stddev": 1,
+        "lpf_cutoff": 32e9,
+        "rx_stddev": 0.01,
+        "sig_avg": 0.5,
+        "enob": 6,
+        "custom_loss_fn": True
+    }
+
+    force_training = True
+
+    model_save_name = f'{params["fiber_length"]}km-{params["cardinality"]}'  # "50km-64"  # "20210401-145416"
+    param_file_path = os.path.join("exports", model_save_name, "params.json")
+
+    if os.path.isfile(param_file_path) and not force_training:
+        print("Importing model {}".format(model_save_name))
+        with open(param_file_path, 'r') as file:
+            params = json.load(file)
+
+    optical_channel = OpticalChannel(
+        fs=params["fs"],
+        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+        dispersion_factor=params["dispersion_factor"],
+        fiber_length=params["fiber_length"],
+        fiber_length_stddev=params["fiber_length_stddev"],
+        lpf_cutoff=params["lpf_cutoff"],
+        rx_stddev=params["rx_stddev"],
+        sig_avg=params["sig_avg"],
+        enob=params["enob"],
+    )
+
+    ae_model = EndToEndAutoencoder(
+        cardinality=params["cardinality"],
+        samples_per_symbol=params["samples_per_symbol"],
+        messages_per_block=params["messages_per_block"],
+        channel=optical_channel,
+        custom_loss_fn=params["custom_loss_fn"],
+        alpha=5,
+    )
+
+    checkpoint_name = f'/tmp/checkpoint/normal_{params["fiber_length"]}km'
+    model_checkpoint_callback0 = tf.keras.callbacks.ModelCheckpoint(
+        filepath=checkpoint_name,
+        save_weights_only=True,
+        monitor='val_accuracy',
+        mode='max',
+        save_best_only=True
+    )
+
+    early_stop = tf.keras.callbacks.EarlyStopping(
+        monitor='val_loss', min_delta=1e-2, patience=3, verbose=0,
+        mode='auto', baseline=None, restore_best_weights=True
+    )
+
+
+    # model_checkpoint_callback1 = tf.keras.callbacks.ModelCheckpoint(
+    #     filepath='/tmp/checkpoint/quantised',
+    #     save_weights_only=True,
+    #     monitor='val_accuracy',
+    #     mode='max',
+    #     save_best_only=True
+    # )
+
+    # if os.path.isfile(param_file_path) and not force_training:
+    #     ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
+    #     ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
+    #     print("Loaded existing model from " + model_save_name)
+    # else:
+    if not os.path.isfile(checkpoint_name + '.index'):
+        history = ae_model.train(num_of_blocks=1e3, epochs=30, callbacks=[model_checkpoint_callback0, early_stop])
+        graphs.show_train_history(history, f"Autoencoder training at {params['fiber_length']}km")
+        ae_model.save_end_to_end(model_save_name)
+
+    ae_model.load_weights(checkpoint_name)
+    ser, ber = ae_model.test(num_of_blocks=3e6)
+    data = [(params["fiber_length"], ser, ber)]
+    for l in np.linspace(params["fiber_length"] - 2.5, params["fiber_length"] + 2.5, 6):
+        optical_channel = OpticalChannel(
+            fs=params["fs"],
+            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+            dispersion_factor=params["dispersion_factor"],
+            fiber_length=l,
+            fiber_length_stddev=params["fiber_length_stddev"],
+            lpf_cutoff=params["lpf_cutoff"],
+            rx_stddev=params["rx_stddev"],
+            sig_avg=params["sig_avg"],
+            enob=params["enob"],
+        )
+        ae_model = EndToEndAutoencoder(
+            cardinality=params["cardinality"],
+            samples_per_symbol=params["samples_per_symbol"],
+            messages_per_block=params["messages_per_block"],
+            channel=optical_channel,
+            custom_loss_fn=params["custom_loss_fn"],
+            alpha=5,
+        )
+        ae_model.load_weights(checkpoint_name)
+        print(f"Testing {l}km")
+        ser, ber = ae_model.test(num_of_blocks=3e6)
+        data.append((l, ser, ber))
+    return data
+
+
+if __name__ == '__main__':
+    data0 = run_tests(90)
+    # data1 = run_tests(70)
+    # data2 = run_tests(80)
+    # print('Results 60: ', data0)
+    # print('Results 70: ', data1)
+    print('Results 90: ', data0)
+
+    # ae_model.test(num_of_blocks=3e6)
+    # ae_model.load_weights('/tmp/checkpoint/normal')
+
+    #
+    # quantize_model = tfmot.quantization.keras.quantize_model
+    # ae_model.decoder = quantize_model(ae_model.decoder)
+    #
+    # # ae_model.load_weights('/tmp/checkpoint/quantised')
+    #
+    # history = ae_model.train(num_of_blocks=1e3, epochs=20, callbacks=[model_checkpoint_callback1])
+    # graphs.show_train_history(history, f"Autoencoder quantised finetune at {params['fiber_length']}km")
+
+    # SYMBOL ERROR RATE: 2.039667e-03
+    #                    2.358000e-03
+    # BIT ERROR RATE: 4.646000e-04
+    #                 6.916000e-04
+
+    # SYMBOL ERROR RATE: 4.146667e-04
+    # BIT ERROR RATE: 1.642667e-04
+    # ae_model.save_end_to_end("50km-q3+")
+    # ae_model.test(num_of_blocks=3e6)
+
+    # Fibre, SER, BER
+    # 50, 2.233333e-05, 5.000000e-06
+    # 60, 6.556667e-04, 1.343333e-04
+    # 75, 1.570333e-03, 3.144667e-04
+    ## 80, 8.061667e-03, 1.612333e-03
+    # 85, 7.811333e-03, 1.601600e-03
+    # 90, 1.121933e-02, 2.255200e-03
+    ## 90, 1.266433e-02, 2.767467e-03
+
+    # 64 cardinality
+    # 50, 5.488000e-03, 1.089000e-03
+    pass

+ 40 - 0
models/gray_code.py

@@ -0,0 +1,40 @@
+from scipy.spatial import Voronoi, voronoi_plot_2d
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def get_gray_code(n: int):
+    return n ^ (n >> 1)
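The XOR construction above is the binary-reflected Gray code: consecutive integers map to codewords differing in exactly one bit, which can be checked directly (a standalone sketch):

```python
def gray(n):
    # Binary-reflected Gray code, as in get_gray_code above
    return n ^ (n >> 1)

codes = [gray(n) for n in range(16)]
# Every consecutive pair of codes differs in exactly one bit position,
# and the mapping is a permutation of 0..15.
```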
+
+
+def difference(sym0: int, sym1: int):
+    return bit_count(sym0 ^ sym1)
+
+
+def bit_count(i: int):
+    """
+    Hamming weight algorithm, just counts number of 1s
+    """
+    assert 0 <= i < 0x100000000
+    i = i - ((i >> 1) & 0x55555555)
+    i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
+    return (((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) & 0xffffffff) >> 24
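The SWAR popcount above can be cross-checked against a straightforward reference implementation (Python 3.10+ also offers `int.bit_count()`); `difference` is then just the popcount of the XOR, i.e. the Hamming distance:

```python
def popcount(i):
    # Simple reference implementation of Hamming weight
    return bin(i).count('1')

def hamming_distance(a, b):
    # Number of bit positions in which a and b differ
    return popcount(a ^ b)
```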
+
+
+def compute_optimal(points, show_graph=False):
+    available = set(range(len(points)))
+    map = {}
+
+    vor = Voronoi(points)
+
+    if show_graph:
+        voronoi_plot_2d(vor)
+        plt.show()
+    pass
+
+
+if __name__ == '__main__':
+    a = np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]])
+    # a = basic.load_alphabet('16qam', polar=False)
+    compute_optimal(a, show_graph=True)
+    pass

+ 211 - 0
models/layers.py

@@ -0,0 +1,211 @@
+"""
+Custom Keras Layers for general use
+"""
+import itertools
+
+from tensorflow.keras import layers
+import tensorflow as tf
+import numpy as np
+import math
+
+
+class AwgnChannel(layers.Layer):
+    def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
+        """
+        An additive white gaussian noise channel model. The GaussianNoise layer is utilized to prevent identical noise
+        being applied every time the call function is invoked.
+
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param noise_dB: Optional noise level in dB; if provided, it overrides rx_stddev
+        """
+        super(AwgnChannel, self).__init__(**kwargs)
+        if noise_dB is not None:
+            # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
+            rx_stddev = 10 ** (noise_dB / 10.0)
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
+
+
+class BitsToSymbols(layers.Layer):
+    def __init__(self, cardinality, messages_per_block):
+        super(BitsToSymbols, self).__init__()
+
+        self.cardinality = cardinality
+        self.messages_per_block = messages_per_block
+
+        n = int(math.log(self.cardinality, 2))
+        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
+        out = tf.one_hot(idx, self.cardinality)
+        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
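The layer's tensordot with descending powers of two converts an MSB-first bit vector into its integer symbol index. A minimal numpy sketch of that index computation (helper name is illustrative):

```python
import numpy as np

def bits_to_index(bits):
    # Dot product with [2^(n-1), ..., 2, 1] turns an MSB-first bit
    # vector into its integer value, as BitsToSymbols does above.
    n = len(bits)
    pows = 2 ** np.arange(n - 1, -1, -1)
    return int(np.dot(bits, pows))
```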
+
+
+class SymbolsToBits(layers.Layer):
+    def __init__(self, cardinality):
+        super(SymbolsToBits, self).__init__()
+
+        n = int(math.log(cardinality, 2))
+        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
+
+        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.all_syms)
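Multiplying a one-hot (or soft) symbol vector by the table of all n-bit patterns recovers the corresponding bit vector. A numpy sketch of the same lookup (illustrative helper):

```python
import itertools
import numpy as np

def symbols_to_bits(one_hot, n_bits):
    # Table of all n-bit patterns in lexicographic order; a one-hot row
    # times this table selects the bit vector for that symbol index.
    table = np.array(list(itertools.product([0, 1], repeat=n_bits)), dtype=float)
    return one_hot @ table
```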
+
+
+class ScaleAndOffset(layers.Layer):
+    """
+    Scales and offsets a tensor
+    """
+
+    def __init__(self, scale=1, offset=0, **kwargs):
+        super(ScaleAndOffset, self).__init__(**kwargs)
+        self.offset = offset
+        self.scale = scale
+
+    def call(self, inputs, **kwargs):
+        return inputs * self.scale + self.offset
+
+
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A keras layer that extracts the central message(symbol) in a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.samples_per_symbol = samples_per_symbol
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.w)
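The selector matrix built in the constructor is mostly zeros with an identity window over the central symbol's samples, so the matmul in `call` slices out that symbol. A numpy sketch of the matrix construction (hypothetical helper):

```python
import numpy as np

def central_selector(messages_per_block, samples_per_symbol):
    # (block_len, symbol_len) matrix whose identity window picks out
    # the central symbol, mirroring ExtractCentralMessage above.
    w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
    begin = samples_per_symbol * (messages_per_block - 1) // 2
    w[begin:begin + samples_per_symbol, :] = np.eye(samples_per_symbol)
    return w
```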
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        This layer simulates the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
+        artefacts caused by quantization are modelled by the addition of white gaussian noise of a stddev derived from
+        the effective number of bits (ENOB).
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param sig_avg: Average signal amplitude, used to scale the quantization noise power
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
+
+        self.noise_layer = layers.GaussianNoise(stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
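The filtering in `call` is an ideal (brick-wall) low-pass: FFT bins above the cutoff are zeroed before transforming back. A self-contained numpy version of that step (sketch, not the layer itself):

```python
import numpy as np

def brickwall_lpf(signal, fs, cutoff):
    # Zero all FFT bins above the cutoff frequency and transform back,
    # the same ideal low-pass DigitizationLayer applies above.
    spectrum = np.fft.fft(signal)
    freq = np.fft.fftfreq(len(signal), d=1 / fs)
    spectrum[np.abs(freq) > cutoff] = 0
    return np.fft.ifft(spectrum).real
```

With bin-aligned test tones the high-frequency component is removed exactly, which makes the behaviour easy to verify.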
+
+
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 fiber_length_stddev=0,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 sig_avg=0.5,
+                 enob=10):
+        """
+        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
+        ADC/DAC as well as additive white gaussian noise in optical communication channels.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param fiber_length_stddev: Standard deviation of the fiber length in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param sig_avg: Average signal amplitude
+        :param enob: Effective number of bits of the ADC/DAC
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.fs = fs
+        self.num_of_samples = num_of_samples
+        self.dispersion_factor = dispersion_factor
+        self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
+        self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
+        self.lpf_cutoff = lpf_cutoff
+        self.rx_stddev = rx_stddev
+        self.sig_avg = sig_avg
+        self.enob = enob
+
+        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
+        self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
+        self.digitization_layer = DigitizationLayer(fs=self.fs,
+                                                    num_of_samples=self.num_of_samples,
+                                                    lpf_cutoff=self.lpf_cutoff,
+                                                    sig_avg=self.sig_avg,
+                                                    enob=self.enob)
+        self.flatten_layer = layers.Flatten()
+
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
+        # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
+
+    def call(self, inputs, **kwargs):
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_val)
+
+        # Avoid shadowing the built-in `len`
+        fiber_len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
+
+        multiplier = tf.math.exp(0.5j*self.dispersion_factor*fiber_len*tf.math.square(2*math.pi*self.freq))
+
+        disp_f = tf.math.multiply(val_f, multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out
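The frequency-domain dispersion step in `call` can be sketched outside TensorFlow with plain NumPy (a minimal sketch; the signal length and parameter values are illustrative, not tied to this repo):

```python
import numpy as np

def apply_dispersion(signal, fs, beta2, length):
    """Chromatic dispersion as an all-pass frequency-domain filter:
    a quadratic phase shift exp(j/2 * beta2 * L * (2*pi*f)^2)."""
    freq = np.fft.fftfreq(signal.size, d=1 / fs)
    transfer = np.exp(0.5j * beta2 * length * (2 * np.pi * freq) ** 2)
    return np.fft.ifft(np.fft.fft(signal) * transfer)

rng = np.random.default_rng(0)
x = rng.normal(size=528)                      # e.g. 11 messages x 48 samples
y = apply_dispersion(x, fs=336e9, beta2=-21.7e-24, length=50)
# The filter only rotates phases, so the signal energy is preserved
print(np.allclose(np.sum(np.abs(y) ** 2), np.sum(x ** 2)))  # True
```

Because the transfer function has unit magnitude everywhere, dispersion spreads the pulse in time without adding or removing energy; the distortion only becomes irreversible after the square-law detector discards the phase.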

+ 237 - 0
models/new_model.py

@@ -0,0 +1,237 @@
+import tensorflow as tf
+from tensorflow.keras import layers, losses
+from models.layers import ExtractCentralMessage, OpticalChannel
+from models.end_to_end import EndToEndAutoencoder
+from models.layers import BitsToSymbols, SymbolsToBits
+import numpy as np
+import math
+
+from matplotlib import pyplot as plt
+
+
+class BitMappingModel(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        super(BitMappingModel, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        self.bits_per_symbol = int(math.log(self.cardinality, 2))
+
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+
+        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
+                                             samples_per_symbol=self.samples_per_symbol,
+                                             messages_per_block=self.messages_per_block,
+                                             channel=channel,
+                                             custom_loss_fn=False)
+
+        self.bit_error_rate = []
+        self.symbol_error_rate = []
+
+    def call(self, inputs, training=None, mask=None):
+        x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
+        x2 = self.e2e_model(x1)
+        out = SymbolsToBits(self.cardinality)(x2)
+        return out
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+        Generates random bit blocks for training.
+
+        :param num_of_blocks: Number of blocks to generate
+        :param return_vals: If True, additionally return the raw block values (here identical to the inputs)
+        :return: (inputs, target) where inputs has shape (num_of_blocks, messages_per_block, bits_per_symbol)
+                 and target holds the bits of each block's central message
+        """
+
+        mid_idx = int((self.messages_per_block - 1) / 2)
+
+        rand_bits = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
+
+        out_arr = np.reshape(rand_bits, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
+
+        if return_vals:
+            return out_arr, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
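The block/target layout produced by `generate_random_inputs` can be reproduced with plain NumPy (a sketch; the sizes are illustrative): each training example is a whole block of messages, but the model is only scored on the central message.

```python
import numpy as np

bits_per_symbol, messages_per_block, num_of_blocks = 6, 11, 4
rng = np.random.default_rng(1)
bits = rng.integers(0, 2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))

# One training example is a whole block; the target is only the central message
blocks = bits.reshape(num_of_blocks, messages_per_block, bits_per_symbol)
mid_idx = (messages_per_block - 1) // 2
targets = blocks[:, mid_idx, :]

print(blocks.shape, targets.shape)  # (4, 11, 6) (4, 6)
```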
+
+    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.MeanSquaredError(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        for i in range(int(iters)):
+            print("Loop {}/{}".format(i + 1, iters))
+
+            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
+            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
+
+            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
+            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
+
+            opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+            self.compile(optimizer=opt,
+                         loss=losses.BinaryCrossentropy(),
+                         metrics=['accuracy'],
+                         loss_weights=None,
+                         weighted_metrics=None,
+                         run_eagerly=False
+                         )
+
+            self.fit(x=X_train,
+                     y=y_train,
+                     batch_size=batch_size,
+                     epochs=epochs,
+                     shuffle=True,
+                     validation_data=(X_test, y_test)
+                     )
+
+            self.e2e_model.test()
+            self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
+            self.bit_error_rate.append(self.e2e_model.bit_error_rate)
+
+
+SAMPLING_FREQUENCY = 336e9
+CARDINALITY = 64
+SAMPLES_PER_SYMBOL = 48
+MESSAGES_PER_BLOCK = 11
+DISPERSION_FACTOR = -21.7 * 1e-24
+FIBER_LENGTH = 50
+ENOB = 6
+
+if __name__ == 'disabled':  # intentionally never true; change to '__main__' to run this example
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     sig_avg=0.5,
+                                     enob=ENOB)
+
+    model = BitMappingModel(cardinality=CARDINALITY,
+                            samples_per_symbol=SAMPLES_PER_SYMBOL,
+                            messages_per_block=MESSAGES_PER_BLOCK,
+                            channel=optical_channel)
+
+    model.train()
+
+if __name__ == '__main__':
+
+    distances = [50]
+    ser = []
+    ber = []
+
+    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)  # in GBaud
+    bit_rate = math.log(CARDINALITY, 2) * baud_rate  # in Gbps
+    snr = None
+
+    for d in distances:
+
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=d,
+                                         sig_avg=0.5,
+                                         enob=ENOB)
+
+        model = BitMappingModel(cardinality=CARDINALITY,
+                                samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                messages_per_block=MESSAGES_PER_BLOCK,
+                                channel=optical_channel)
+
+        if snr is None:
+            snr = model.e2e_model.snr
+        elif snr != model.e2e_model.snr:
+            print("Warning: SNR mismatch between runs ({:.2f} vs {:.2f})".format(snr, model.e2e_model.snr))
+
+        print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+
+        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
+
+        model.e2e_model.test(length_plot=True)
+
+        ber.append(model.bit_error_rate[-1])
+        ser.append(model.symbol_error_rate[-1])
+
+        e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                        messages_per_block=MESSAGES_PER_BLOCK,
+                                        channel=optical_channel,
+                                        custom_loss_fn=False)
+
+        ber1 = []
+        ser1 = []
+
+        for _ in range(len(model.bit_error_rate)):
+            e2e_model.train(num_of_blocks=1e3, epochs=5)
+            e2e_model.test()
+
+            ber1.append(e2e_model.bit_error_rate)
+            ser1.append(e2e_model.symbol_error_rate)
+
+        # model2 = BitMappingModel(cardinality=CARDINALITY,
+        #                          samples_per_symbol=SAMPLES_PER_SYMBOL,
+        #                          messages_per_block=MESSAGES_PER_BLOCK,
+        #                          channel=optical_channel)
+        #
+        # ber2 = []
+        # ser2 = []
+        #
+        # for i in range(int(len(model.bit_error_rate) / 2)):
+        #     model2.train(num_of_blocks=1e3, epochs=5)
+        #     model2.e2e_model.test()
+        #
+        #     ber2.append(model2.e2e_model.bit_error_rate)
+        #     ser2.append(model2.e2e_model.symbol_error_rate)
+
+        plt.plot(ber1, label='BER (1)')
+        # plt.plot(ser1, label='SER (1)')
+        # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
+        # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
+        plt.plot(model.bit_error_rate, label='BER (3)')
+        # plt.plot(model.symbol_error_rate, label='SER (3)')
+
+        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
+        plt.yscale('log')
+        plt.legend()
+        plt.show()
+        # model.summary()
+
+    # plt.plot(ber, label='BER')
+    # plt.plot(ser, label='SER')
+    # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
+    # plt.legend(ber)
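The `BitsToSymbols`/`SymbolsToBits` layers used above are defined elsewhere and not shown in this diff; a plausible integer mapping they could implement (an assumption for illustration, not the repo's actual layer logic) is a plain MSB-first binary encoding:

```python
import numpy as np

def bits_to_symbols(bits):
    """Map rows of bits to integer symbol indices (MSB first)."""
    weights = 2 ** np.arange(bits.shape[-1] - 1, -1, -1)
    return bits @ weights

def symbols_to_bits(symbols, bits_per_symbol):
    """Inverse mapping: integer symbols back to bit rows."""
    shifts = np.arange(bits_per_symbol - 1, -1, -1)
    return (symbols[:, None] >> shifts) & 1

b = np.array([[1, 0, 1], [0, 1, 1]])
s = bits_to_symbols(b)
print(s)                        # [5 3]
print(symbols_to_bits(s, 3))    # recovers b
```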

+ 65 - 15
models/optical_channel.py

@@ -3,17 +3,23 @@ import matplotlib.pyplot as plt
 import defs
 import numpy as np
 import math
-from scipy.fft import fft, ifft
+from numpy.fft import fft, fftfreq, ifft
+from commpy.filters import rrcosfilter, rcosfilter, rectfilter
+
+from models import basic
 
 
 class OpticalChannel(defs.Channel):
-    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, show_graphs=False, **kwargs):
+    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
+                 sqrt_out=False, show_graphs=False, **kwargs):
         """
         :param noise_level: Noise level in dB
         :param dispersion: dispersion coefficient is ps^2/km
         :param symbol_rate: Symbol rate of modulated signal in Hz
         :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
         :param length: fibre length in km
+        :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
+        :param sqrt_out: Take the root of the out to compensate for photodiode detection
         :param show_graphs: if graphs should be displayed or not
 
         Optical Channel class constructor
@@ -21,29 +27,45 @@ class OpticalChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-        self.dispersion = dispersion # * 1e-24  # Converting from ps^2/km to s^2/km
+        self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
         self.symbol_rate = symbol_rate
         self.symbol_period = 1 / self.symbol_rate
         self.sample_rate = sample_rate
         self.sample_period = 1 / self.sample_rate
         self.length = length
+        self.pulse_shape = pulse_shape.strip().lower()
+        self.sqrt_out = sqrt_out
         self.show_graphs = show_graphs
 
     def __get_time_domain(self, symbol_vals):
         samples_per_symbol = int(self.sample_rate / self.symbol_rate)
         samples = int(symbol_vals.shape[0] * samples_per_symbol)
 
-        symbol_vals_a = np.repeat(symbol_vals, repeats=samples_per_symbol, axis=0)
-        t = np.linspace(start=0, stop=samples * self.sample_period, num=samples)
-        val_t = symbol_vals_a[:, 0] * np.cos(2 * math.pi * symbol_vals_a[:, 2] * t + symbol_vals_a[:, 1])
+        symbol_impulse = np.zeros(samples)
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        for i in range(symbol_vals.shape[0]):
+            symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
+
+        if self.pulse_shape == 'rrcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        elif self.pulse_shape == 'rcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        else:
+            self.filter_samples = samples_per_symbol
+            self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
+
+        val_t = np.convolve(symbol_impulse, self.h_filter)
+        t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])
 
         return t, val_t
 
     def __time_to_frequency(self, values):
         val_f = fft(values)
-        f = np.linspace(0.0, 1 / (2 * self.sample_period), (values.size // 2))
-        f_neg = -1 * np.flip(f)
-        f = np.concatenate((f, f_neg), axis=0)
+        f = fftfreq(values.shape[-1])*self.sample_rate
         return f, val_f
 
     def __frequency_to_time(self, values):
@@ -79,6 +101,8 @@ class OpticalChannel(defs.Channel):
         return t, val_t
 
     def forward(self, values):
+        if hasattr(values, 'apf'):
+            values = values.apf
         # Converting APF representation to time-series
         t, val_t = self.__get_time_domain(values)
 
@@ -106,22 +130,48 @@ class OpticalChannel(defs.Channel):
         # Photodiode Detection
         t, val_t = self.__photodiode_detection(val_t)
 
+        # Symbol Decisions
+        idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
+                        self.symbol_period/self.sample_period, dtype='int64')
+        t_decision = self.sample_period * idx
+
         if self.show_graphs:
             plt.plot(t, val_t)
             plt.title('time domain (post-detection)')
             plt.show()
 
-        return t, val_t
+            plt.plot(t, val_t)
+            for xc in t_decision:
+                plt.axvline(x=xc, color='r')
+            plt.title('time domain (post-detection with decision times)')
+            plt.show()
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        out = np.zeros(values.shape)
+
+        out[:, 0] = val_t[idx]
+        out[:, 1] = values[:, 1]
+        out[:, 2] = values[:, 2]
+
+        if self.sqrt_out:
+            out[:, 0] = np.sqrt(out[:, 0])
+
+        return basic.RFSignal(out)
 
 
 if __name__ == '__main__':
     # Simple OOK modulation
-    num_of_symbols = 10
+    num_of_symbols = 100
     symbol_vals = np.zeros((num_of_symbols, 3))
 
     symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
-    symbol_vals[:, 2] = 10e6
+    symbol_vals[:, 2] = 40e9
+
+    channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
+                             sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
+    v = channel.forward(symbol_vals)
 
-    channel = OpticalChannel(noise_level=-20, dispersion=-21.7, symbol_rate=100e3,
-                             sample_rate=500e6, length=100, show_graphs=True)
-    time, v = channel.forward(symbol_vals)
+    # forward() now returns an RFSignal; decide on the amplitude column only
+    # (comparing the full (N, 3) array against the N transmitted bits would broadcast incorrectly)
+    rx = (v.apf[:, 0] > 0.5).astype(int)
+    tru = np.sum(rx == symbol_vals[:, 0].astype(int))
+    print("Accuracy: {}".format(tru / num_of_symbols))
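The rewritten `__get_time_domain` builds the waveform by placing symbol values on an impulse train and convolving with a pulse-shaping filter. The idea in isolation, using a rectangular pulse so `commpy` is not needed (a sketch with illustrative sizes):

```python
import numpy as np

def shape_pulses(symbols, samples_per_symbol, h):
    """Upsample symbols to an impulse train, then convolve with a pulse filter."""
    impulses = np.zeros(len(symbols) * samples_per_symbol)
    impulses[::samples_per_symbol] = symbols
    return np.convolve(impulses, h)

symbols = np.array([1.0, 0.0, 1.0, 1.0])   # OOK amplitudes
h_rect = np.ones(8)                        # rectangular pulse, 8 samples/symbol
val_t = shape_pulses(symbols, 8, h_rect)
print(val_t.shape)  # (39,)  i.e. 32 + 8 - 1 from the full convolution
```

Swapping `h_rect` for a raised-cosine or root-raised-cosine tap vector reproduces the `'rcos'`/`'rrcos'` branches; only the filter changes, not the structure.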

+ 91 - 0
models/plots.py

@@ -0,0 +1,91 @@
+from sklearn.preprocessing import OneHotEncoder
+import numpy as np
+from tensorflow.keras import layers
+
+from end_to_end import load_model
+from models.layers import DigitizationLayer, OpticalChannel
+from matplotlib import pyplot as plt
+import math
+
+# Plot the frequency spectrum of the end-to-end model
+def plot_e2e_spectrum(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a list of random symbols (one hot encoded)
+    cat = [np.arange(params["cardinality"])]
+    enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
+    out = enc.fit_transform(rand_int)
+
+    # Encode the list of symbols using the trained encoder
+    a = ae_model.encode_stream(out).flatten()
+
+    # Pass the output of the encoder through LPF
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=320000,
+                            sig_avg=0)(a).numpy()
+
+    # Plot the frequency spectrum of the signal
+    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
+    mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
+
+    a = np.fft.ifft(mul)
+    a2 = np.power(a, 2)
+    b = np.abs(np.fft.fft(a2))
+
+    # Plot the magnitude of the spectrum (the raw FFT output is complex)
+    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+    # plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.plot(freq, b)
+    plt.ylim((-500, 500))
+    plt.xlim((-5e10, 5e10))
+    plt.show()
+
+def plot_e2e_encoded_output(model_name=None):
+    # Load pre-trained model
+    ae_model, params = load_model(model_name=model_name)
+
+    # Generate a random block of messages
+    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+    # Encode and flatten the messages
+    enc = ae_model.encoder(inp)
+    flat_enc = layers.Flatten()(enc)
+    chan_out = ae_model.channel.layers[1](flat_enc)
+
+    # Instantiate LPF layer
+    lpf = DigitizationLayer(fs=params["fs"],
+                            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                            sig_avg=0)
+
+    # Apply LPF
+    lpf_out = lpf(flat_enc)
+
+    # Time axis
+    t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
+    if isinstance(ae_model.channel.layers[1], OpticalChannel):
+        t = t / params["fs"]
+
+    # Plot the concatenated symbols before and after LPF
+    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+
+    for i in range(1, params["messages_per_block"]):
+        plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
+    plt.axhline(y=0, color='black')
+    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.ylim((-0.1, 0.1))
+    plt.xlim((t.min(), t.max()))
+    plt.title(str(val[0, :, 0]))
+    plt.legend(loc='upper right')
+    plt.show()
+
+if __name__ == '__main__':
+    # plot_e2e_spectrum()
+    plot_e2e_encoded_output()
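`plot_e2e_spectrum` pulls in scikit-learn's `OneHotEncoder` only for a simple integer-to-one-hot mapping; the same result with NumPy alone (a sketch with illustrative sizes):

```python
import numpy as np

cardinality = 8
rng = np.random.default_rng(2)
rand_int = rng.integers(cardinality, size=(5, 1))

# Index rows of an identity matrix to one-hot encode the symbols
one_hot = np.eye(cardinality)[rand_int[:, 0]]
print(one_hot.shape)  # (5, 8)
```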

+ 301 - 0
models/quantized_net.py

@@ -0,0 +1,301 @@
+from numpy import (
+    array,
+    zeros,
+    dot,
+    median,
+    log2,
+    linspace,
+    argmin,
+    abs,
+)
+from scipy.linalg import norm
+from tensorflow.keras.backend import function as Kfunction
+from tensorflow.keras.models import Model, clone_model
+from collections import namedtuple
+from typing import List, Generator
+from time import time
+
+from models.autoencoder import Autoencoder
+
+QuantizedNeuron = namedtuple("QuantizedNeuron", ["layer_idx", "neuron_idx", "q"])
+QuantizedFilter = namedtuple(
+    "QuantizedFilter", ["layer_idx", "filter_idx", "channel_idx", "q_filter"]
+)
+SegmentedData = namedtuple("SegmentedData", ["wX_seg", "qX_seg"])
+
+
+class QuantizedNeuralNetwork:
+    def __init__(
+            self,
+            network: Model,
+            batch_size: int,
+            get_data: Generator[array, None, None],
+            logger=None,
+            ignore_layers=[],
+            bits=log2(3),
+            alphabet_scalar=1,
+    ):
+
+        self.get_data = get_data
+
+        # The pre-trained network.
+        self.trained_net = network
+
+        # This copies the network structure but not the weights.
+        if isinstance(network, Autoencoder):
+            # The pre-trained network.
+            self.trained_net_layers = network.all_layers
+            self.quantized_net = Autoencoder(network.N, network.channel, bipolar=network.bipolar)
+            self.quantized_net.set_weights(network.get_weights())
+            self.quantized_net_layers = self.quantized_net.all_layers
+            # self.quantized_net.layers = self.quantized_net.layers
+        else:
+            # The pre-trained network.
+            self.trained_net_layers = network.layers
+            self.quantized_net = clone_model(network)
+            # Set all the weights to be the same a priori.
+            self.quantized_net.set_weights(network.get_weights())
+            self.quantized_net_layers = self.quantized_net.layers
+
+        self.batch_size = batch_size
+
+        self.alphabet_scalar = alphabet_scalar
+
+        # Create a dictionary encoding which layers are Dense, and what their dimensions are.
+        self.layer_dims = {
+            layer_idx: layer.get_weights()[0].shape
+            for layer_idx, layer in enumerate(network.layers)
+            if layer.__class__.__name__ == "Dense"
+        }
+
+        # This determines the alphabet. There will be 2**bits atoms in our alphabet.
+        self.bits = bits
+
+        # Construct the (unscaled) alphabet. Layers will scale this alphabet based on the
+        # distribution of that layer's weights.
+        self.alphabet = linspace(-1, 1, num=int(round(2 ** (bits))))
+
+        self.logger = logger
+
+        self.ignore_layers = ignore_layers
+
+    def _log(self, msg: str):
+        if self.logger:
+            self.logger.info(msg)
+        else:
+            print(msg)
+
+    def _bit_round(self, t: float, rad: float) -> float:
+        """Rounds a quantity to the nearest atom in the (scaled) quantization alphabet.
+
+        Parameters
+        -----------
+        t : float
+            The value to quantize.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        bit : float
+            The quantized value.
+        """
+
+        # Scale the alphabet appropriately.
+        layer_alphabet = rad * self.alphabet
+        return layer_alphabet[argmin(abs(layer_alphabet - t))]
+
+    def _quantize_weight(
+            self, w: float, u: array, X: array, X_tilde: array, rad: float
+    ) -> float:
+        """Quantizes a single weight of a neuron.
+
+        Parameters
+        -----------
+        w : float
+            The weight.
+        u : array ,
+            Residual vector.
+        X : array
+            Vector from the analog network's random walk.
+        X_tilde : array
+            Vector from the quantized network's random walk.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        bit : float
+            The quantized value.
+        """
+
+        if norm(X_tilde, 2) < 10 ** (-16):
+            return 0
+
+        if abs(dot(X_tilde, u)) < 10 ** (-10):
+            return self._bit_round(w, rad)
+
+        return self._bit_round(dot(X_tilde, u + w * X) / (norm(X_tilde, 2) ** 2), rad)
+
+    def _quantize_neuron(
+            self,
+            layer_idx: int,
+            neuron_idx: int,
+            wX: array,
+            qX: array,
+            rad=1,
+    ) -> QuantizedNeuron:
+        """Quantizes a single neuron in a Dense layer.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Dense layer.
+        neuron_idx : int,
+            Index of the neuron in the Dense layer.
+        wX : array
+            Layer input for the analog convolutional neural network.
+        qX : array
+            Layer input for the quantized convolutional neural network.
+        rad : float
+            Scaling factor for the quantization alphabet.
+
+        Returns
+        -------
+        QuantizedNeuron: NamedTuple
+            A tuple with the layer and neuron index, as well as the quantized neuron.
+        """
+
+        N_ell = wX.shape[1]
+        u = zeros(self.batch_size)
+        w = self.trained_net_layers[layer_idx].get_weights()[0][:, neuron_idx]
+        q = zeros(N_ell)
+        for t in range(N_ell):
+            q[t] = self._quantize_weight(w[t], u, wX[:, t], qX[:, t], rad)
+            u += w[t] * wX[:, t] - q[t] * qX[:, t]
+
+        return QuantizedNeuron(layer_idx=layer_idx, neuron_idx=neuron_idx, q=q)
+
+    def _get_layer_data(self, layer_idx: int, hf=None):
+        """Gets the input data for the layer at a given index.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the layer.
+        hf: hdf5 File object in write mode.
+            If provided, will write output to hdf5 file instead of returning directly.
+
+        Returns
+        -------
+        tuple: (array, array)
+            A tuple of arrays, with the first entry being the input for the analog network
+            and the latter being the input for the quantized network.
+        """
+
+        layer = self.trained_net_layers[layer_idx]
+        layer_data_shape = layer.input_shape[1:] if layer.input_shape[0] is None else layer.input_shape
+        wX = zeros((self.batch_size, *layer_data_shape))
+        qX = zeros((self.batch_size, *layer_data_shape))
+        if layer_idx == 0:
+            for sample_idx in range(self.batch_size):
+                try:
+                    wX[sample_idx, :] = next(self.get_data)
+                except StopIteration:
+                    # No more samples!
+                    break
+            qX = wX
+        else:
+            # Define functions which will give you the output of the previous hidden layer
+            # for both networks.
+            # try:
+            prev_trained_output = Kfunction(
+                [self.trained_net_layers[0].input],
+                [self.trained_net_layers[layer_idx - 1].output],
+            )
+            prev_quant_output = Kfunction(
+                [self.quantized_net_layers[0].input],
+                [self.quantized_net_layers[layer_idx - 1].output],
+            )
+            input_layer = self.trained_net_layers[0]
+            input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
+            batch = zeros((self.batch_size, *input_shape))
+            # except Exception:
+            #     pass
+            # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
+            # to reconsider how much memory you preallocate for batch, wX, and qX.
+            for sample_idx in range(self.batch_size):
+                try:
+                    batch[sample_idx, :] = next(self.get_data)
+                except StopIteration:
+                    # No more samples!
+                    break
+
+            wX = prev_trained_output([batch])[0]
+            qX = prev_quant_output([batch])[0]
+
+        return (wX, qX)
+
+    def _update_weights(self, layer_idx: int, Q: array):
+        """Updates the weights of the quantized neural network given a layer index and
+        quantized weights.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Conv2D layer.
+        Q : array
+            The quantized weights.
+        """
+
+        # Update the quantized network. Use the same bias vector as in the analog network for now.
+        if self.trained_net_layers[layer_idx].use_bias:
+            bias = self.trained_net_layers[layer_idx].get_weights()[1]
+            self.quantized_net_layers[layer_idx].set_weights([Q, bias])
+        else:
+            self.quantized_net_layers[layer_idx].set_weights([Q])
+
+    def _quantize_layer(self, layer_idx: int):
+        """Quantizes a Dense layer of a multi-layer perceptron.
+
+        Parameters
+        -----------
+        layer_idx : int
+            Index of the Dense layer.
+        """
+
+        W = self.trained_net_layers[layer_idx].get_weights()[0]
+        N_ell, N_ell_plus_1 = W.shape
+        # Placeholder for the weight matrix in the quantized network.
+        Q = zeros(W.shape)
+        wX, qX = self._get_layer_data(layer_idx)
+
+        # Set the radius of the alphabet.
+        rad = self.alphabet_scalar * median(abs(W.flatten()))
+
+        for neuron_idx in range(N_ell_plus_1):
+            # self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
+            tic = time()
+            qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
+            Q[:, neuron_idx] = qNeuron.q
+
+            # self._log(f"\tdone. {time() - tic :.2f} seconds.")
+
+        # Writing the quantized weights once per layer is enough: they are only
+        # read again when quantizing the next layer.
+        self._update_weights(layer_idx, Q)
+
+    def quantize_network(self):
+        """Quantizes all Dense layers that are not specified by the list of ignored layers."""
+
+        # This must be done sequentially.
+        for layer_idx, layer in enumerate(self.trained_net_layers):
+            if (
+                    layer.__class__.__name__ == "Dense"
+                    and layer_idx not in self.ignore_layers
+            ):
+                # Only quantize dense layers.
+                self._log(f"Quantizing layer {layer_idx}...")
+                self._quantize_layer(layer_idx)
+                self._log(f"done #{layer_idx} {layer}...")
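The core of `_quantize_neuron`/`_quantize_weight` is a greedy, data-driven rounding that carries a residual `u` between the analog and quantized running sums. A self-contained sketch of that loop for the first-layer case, where the analog and quantized layer inputs coincide (deeper layers would pass separate `wX`/`qX`; sizes and the ternary alphabet are illustrative):

```python
import numpy as np

def quantize_neuron(w, X, alphabet):
    """Greedy data-driven quantization of one neuron's weights:
    the residual u tracks the gap between the analog and quantized partial sums."""
    u = np.zeros(X.shape[0])
    q = np.zeros_like(w)
    for t in range(len(w)):
        x_t = X[:, t]
        # Least-squares optimal value for this step, then round to the alphabet
        target = x_t @ (u + w[t] * x_t) / (x_t @ x_t)
        q[t] = alphabet[np.argmin(np.abs(alphabet - target))]
        u += w[t] * x_t - q[t] * x_t
    return q

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 10))        # batch of inputs to the layer
w = rng.normal(scale=0.5, size=10)   # analog weights of one neuron
ternary = np.linspace(-1, 1, 3)      # {-1, 0, 1}, as with bits = log2(3)
q = quantize_neuron(w, X, ternary)
# Relative error of the quantized outputs on the calibration batch
err = np.linalg.norm(X @ w - X @ q) / np.linalg.norm(X @ w)
print(q)
```

Unlike naive nearest-alphabet rounding of each weight, the residual lets later weights compensate for earlier rounding errors, which is why the batch of layer inputs is needed at quantization time.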

+ 597 - 0
tests/min_test.py

@@ -0,0 +1,597 @@
+"""
+These are some unstructured, ad-hoc tests. Feel free to reuse this code elsewhere.
+"""
+
+import logging
+import pathlib
+from itertools import chain
+from sys import stdout
+
+from tensorflow.python.framework.errors_impl import NotFoundError
+from tensorflow.keras import backend as K
+
+import defs
+from graphs import get_SNR, get_AWGN_ber, show_train_history
+from models import basic
+from models.autoencoder import Autoencoder, view_encoder
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import misc
+import numpy as np
+
+from models.basic import AlphabetDemod, AlphabetMod
+from models.data import BinaryGenerator
+from models.layers import BitsToSymbols, SymbolsToBits
+from models.optical_channel import OpticalChannel
+from models.quantized_net import QuantizedNeuralNetwork
+
+
+def _test_optics_autoencoder():
+    ch = OpticalChannel(
+        noise_level=-10,
+        dispersion=-21.7,
+        symbol_rate=10e9,
+        sample_rate=400e9,
+        length=10,
+        pulse_shape='rcos',
+        sqrt_out=True
+    )
+
+    tf.executing_eagerly()
+
+    aenc = Autoencoder(4, channel=ch)
+    aenc.train(samples=1e6)
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='AE')
+
+    plt.plot(*get_SNR(
+        AlphabetMod('4pam', 10e6),
+        AlphabetDemod('4pam', 10e6),
+        samples=30000,
+        steps=50,
+        start=-5,
+        stop=15,
+        length=1,
+        pulse_shape='rcos'
+    ), '-', label='4PAM')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("Autoencoder Performance")
+    plt.legend()
+    plt.savefig('optics_autoencoder.eps', format='eps')
+    plt.show()
+
+
+def _test_autoencoder_pretrain():
+    # aenc = Autoencoder(4, -25)
+    # aenc.train(samples=1e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='Random AE')
+
+    aenc = Autoencoder(4, -25)
+    # aenc.fit_encoder('16qam', 3e4)
+    aenc.fit_decoder('16qam', 1e5)
+
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM Pre-trained AE')
+
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM Post-trained AE')
+
+    plt.plot(*get_SNR(
+        AlphabetMod('16qam', 10e6),
+        AlphabetDemod('16qam', 10e6),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("4Bit Autoencoder Performance")
+    plt.legend()
+    plt.show()
+
+
+class LiteTFMod(defs.Modulator):
+    def __init__(self, name, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
+        self.interpreter.allocate_tensors()
+
+    def forward(self, binary: np.ndarray):
+        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
+        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
+
+        input_index = self.interpreter.get_input_details()[0]["index"]
+        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
+        input_shape = self.interpreter.get_input_details()[0]["shape"]
+        output_index = self.interpreter.get_output_details()[0]["index"]
+        output_shape = self.interpreter.get_output_details()[0]["shape"]
+
+        x = np.zeros((len(reshaped_ho), output_shape[1]))
+        for i, ho in enumerate(reshaped_ho):
+            self.interpreter.set_tensor(input_index, ho.reshape(input_shape).astype(input_dtype))
+            self.interpreter.invoke()
+            x[i] = self.interpreter.get_tensor(output_index)
+
+        if self.autoencoder.bipolar:
+            x = x * 2 - 1
+
+        if self.autoencoder.parallel > 1:
+            x = x.reshape((-1, self.autoencoder.signal_dim))
+
+        f = np.zeros(x.shape[0])
+        if self.autoencoder.signal_dim <= 1:
+            p = np.zeros(x.shape[0])
+        else:
+            p = x[:, 1]
+        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
+        return basic.RFSignal(x3)
+
+
+class LiteTFDemod(defs.Demodulator):
+    def __init__(self, name, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
+        self.interpreter.allocate_tensors()
+
+    def forward(self, values: defs.Signal) -> np.ndarray:
+        if self.autoencoder.signal_dim <= 1:
+            val = values.rect_x
+        else:
+            val = values.rect
+        if self.autoencoder.parallel > 1:
+            val = val.reshape((-1, self.autoencoder.parallel))
+
+        input_index = self.interpreter.get_input_details()[0]["index"]
+        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
+        input_shape = self.interpreter.get_input_details()[0]["shape"]
+        output_index = self.interpreter.get_output_details()[0]["index"]
+        output_shape = self.interpreter.get_output_details()[0]["shape"]
+
+        decoded = np.zeros((len(val), output_shape[1]))
+        for i, v in enumerate(val):
+            self.interpreter.set_tensor(input_index, v.reshape(input_shape).astype(input_dtype))
+            self.interpreter.invoke()
+            decoded[i] = self.interpreter.get_tensor(output_index)
+        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
+        return result.reshape(-1, )
+
+
+def _test_autoencoder_perf():
+    # Compare (major, minor) tuples; float(tf.__version__[:3]) breaks for 2.10+.
+    assert tuple(int(v) for v in tf.__version__.split('.')[:2]) >= (2, 3)
+
+    # aenc = Autoencoder(3, -15)
+    # aenc.train(samples=1e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='3Bit AE')
+
+    # aenc = Autoencoder(4, -25, bipolar=True, dtype=tf.float64)
+    # aenc.train(samples=5e5)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4Bit AE F64')
+
+    aenc = Autoencoder(4, -25, bipolar=True)
+    aenc.train(epoch_size=1e3, epochs=10)
+    # #
+    m = aenc.N * aenc.parallel
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100 * m).reshape((-1, m)))
+    x_train_enc = aenc.encoder(x_train)
+    x_train = tf.cast(x_train, tf.float32)
+
+    #
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE F32')
+    # # #
+    def save_tfline(model, name, types=None, ops=None, io_types=None, train_x=None):
+        converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        if types is not None:
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.target_spec.supported_types = types
+        if ops is not None:
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.target_spec.supported_ops = ops
+        if io_types is not None:
+            converter.inference_input_type = io_types
+            converter.inference_output_type = io_types
+        if train_x is not None:
+            def representative_data_gen():
+                for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
+                    yield [input_value]
+
+            converter.representative_dataset = representative_data_gen
+        tflite_model = converter.convert()
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_models_dir.mkdir(exist_ok=True, parents=True)
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        tflite_model_file.write_bytes(tflite_model)
+
+    print("Saving models")
+
+    save_tfline(aenc.encoder, "default_enc")
+    save_tfline(aenc.decoder, "default_dec")
+    #
+    # save_tfline(aenc.encoder, "float16_enc", [tf.float16])
+    # save_tfline(aenc.decoder, "float16_dec", [tf.float16])
+    #
+    # save_tfline(aenc.encoder, "bfloat16_enc", [tf.bfloat16])
+    # save_tfline(aenc.decoder, "bfloat16_dec", [tf.bfloat16])
+
+    INT16X8 = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+    save_tfline(aenc.encoder, "int16x8_enc", ops=[INT16X8], train_x=x_train)
+    save_tfline(aenc.decoder, "int16x8_dec", ops=[INT16X8], train_x=x_train_enc)
+
+    # save_tfline(aenc.encoder, "int8_enc", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train)
+    # save_tfline(aenc.decoder, "int8_dec", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train_enc)
+
+    print("Testing BER vs SNR")
+    plt.plot(*get_SNR(
+        LiteTFMod("default_enc", aenc),
+        LiteTFDemod("default_dec", aenc),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='4AE F32')
+
+    # plt.plot(*get_SNR(
+    #     LiteTFMod("float16_enc", aenc),
+    #     LiteTFDemod("float16_dec", aenc),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE F16')
+    # #
+    # plt.plot(*get_SNR(
+    #     LiteTFMod("bfloat16_enc", aenc),
+    #     LiteTFDemod("bfloat16_dec", aenc),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE BF16')
+    #
+    plt.plot(*get_SNR(
+        LiteTFMod("int16x8_enc", aenc),
+        LiteTFDemod("int16x8_dec", aenc),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='4AE I16x8')
+
+    # plt.plot(*get_SNR(
+    #     AlphabetMod('16qam', 10e6),
+    #     AlphabetDemod('16qam', 10e6),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15,
+    # ), '-', label='16qam')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.ylabel('BER')
+    plt.title("Autoencoder with different precision data types")
+    plt.legend()
+    plt.savefig('autoencoder_compression.eps', format='eps')
+    plt.show()
+
+    view_encoder(aenc.encoder, 4)
+
+    # aenc = Autoencoder(5, -25)
+    # aenc.train(samples=2e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='5Bit AE')
+    #
+    # aenc = Autoencoder(6, -25)
+    # aenc.train(samples=2e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='6Bit AE')
+    #
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         ber_func=get_AWGN_ber,
+    #         samples=100000,
+    #         steps=50,
+    #         start=-5,
+    #         stop=15,
+    #     ), '-', label=scheme.upper())
+    #
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.title("Autoencoder vs defined modulations")
+    # plt.legend()
+    # plt.show()
+
+
+def _test_autoencoder_perf2():
+    aenc = Autoencoder(2, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='2Bit AE')
+
+    aenc = Autoencoder(3, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='3Bit AE')
+
+    aenc = Autoencoder(4, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='4Bit AE')
+
+    aenc = Autoencoder(5, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                      start=-5, stop=15), '-', label='5Bit AE')
+
+    for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
+        try:
+            plt.plot(
+                *get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50,
+                         start=-5, stop=15, ), '-', label=a.upper())
+        except KeyboardInterrupt:
+            break
+        except Exception:
+            pass
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("Autoencoder vs defined modulations")
+    plt.legend()
+    plt.savefig('autoencoder_mods.eps', format='eps')
+    plt.show()
+
+    # view_encoder(aenc.encoder, 2)
+
+
+def _test_autoencoder_perf_qnn():
+    fh = logging.FileHandler("model_quantizing.log", mode="w+")
+    fh.setLevel(logging.INFO)
+    sh = logging.StreamHandler(stream=stdout)
+    sh.setLevel(logging.INFO)
+
+    logger = logging.getLogger(__name__)
+    logger.setLevel(level=logging.INFO)
+    logger.addHandler(fh)
+    logger.addHandler(sh)
+
+    aenc = Autoencoder(4, -25, bipolar=True)
+    # aenc.encoder.save_weights('ae_enc.bin')
+    # aenc.decoder.save_weights('ae_dec.bin')
+    # aenc.encoder.load_weights('ae_enc.bin')
+    # aenc.decoder.load_weights('ae_dec.bin')
+    try:
+        aenc.load_weights('autoencoder')
+    except NotFoundError:
+        aenc.train(epoch_size=1e3, epochs=10)
+        aenc.save_weights('autoencoder')
+
+    aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
+
+    m = aenc.N * aenc.parallel
+    view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
+
+    batch_size = 25000
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size * m).reshape((-1, m)))
+    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000 * m).reshape((-1, m)))
+    bits = np.log2(32)  # bit budget for the quantised alphabet (unused below)
+    alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
+
+    # for b in (3, 6, 8, 12, 16, 24, 32, 48, 64):
+    get_data = (sample for sample in x_train)
+    for i in range(num_layers):
+        get_data = chain(get_data, (sample for sample in x_train))
+
+    qnn = QuantizedNeuralNetwork(
+        network=aenc,
+        batch_size=batch_size,
+        get_data=get_data,
+        logger=logger,
+        bits=np.log2(16),
+        alphabet_scalar=alphabet_scalars,
+    )
+
+    qnn.quantize_network()
+    qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+    view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
+
+    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
+    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
+    pass
+
+
+class BitAwareAutoencoder(Autoencoder):
+    def __init__(self, N, channel, **kwargs):
+        super().__init__(N, channel, **kwargs, cost=self.cost)
+        self.BITS = 2 ** N - 1
+        #  data_generator=BinaryGenerator,
+        # self.b2s_layer = BitsToSymbols(2**N)
+        # self.s2b_layer = SymbolsToBits(2**N)
+
+    def cost(self, y_true, y_pred):
+        y = tf.cast(y_true, dtype=tf.float32)
+        # Per-sample argmax over the class axis, normalised to [0, 1].
+        z0 = tf.math.argmax(y, axis=-1) / self.BITS
+        z1 = tf.math.argmax(y_pred, axis=-1) / self.BITS
+        error0 = y - y_pred
+        sqr_error0 = K.square(error0)  # square of the error
+        mean_sqr_error0 = K.mean(sqr_error0)  # mean of the squared error
+        sme0 = K.sqrt(mean_sqr_error0)  # root-mean-square error
+
+        error1 = z0 - z1
+        sqr_error1 = K.square(error1)
+        mean_sqr_error1 = K.mean(sqr_error1)
+        sme1 = K.sqrt(mean_sqr_error1)
+        return sme0 + tf.cast(sme1 * 300, dtype=tf.float32)
+
+    # def call(self, x, **kwargs):
+    #     x1 = self.b2s_layer(x)
+    #     y = self.encoder(x1)
+    #     z = self.channel(y)
+    #     z1 = self.decoder(z)
+    #     return self.s2b_layer(z1)
+
+
+def _bit_aware_test():
+    aenc = BitAwareAutoencoder(6, -50, bipolar=True)
+
+    try:
+        aenc.load_weights('ae_bitaware')
+    except NotFoundError:
+        pass
+
+    # try:
+    #     hist = aenc.train(
+    #         epochs=70,
+    #         epoch_size=1e3,
+    #         optimizer='adam',
+    #         # metrics=[tf.keras.metrics.Accuracy()],
+    #         # callbacks=[tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, min_delta=0.001)]
+    #     )
+    #     show_train_history(hist, "Autoencoder training history")
+    # except KeyboardInterrupt:
+    #     aenc.save_weights('ae_bitaware')
+    #     exit(0)
+    #
+    # aenc.save_weights('ae_bitaware')
+
+    view_encoder(aenc.encoder, 6, title="6-bit autoencoder alphabet")
+
+    print("Computing BER/SNR for autoencoder")
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=1000000, steps=40,
+        start=-5, stop=15), '-', label='6Bit AE')
+
+    print("Computing BER/SNR for 64QAM")
+    plt.plot(*get_SNR(
+        AlphabetMod('64qam', 10e6),
+        AlphabetDemod('64qam', 10e6),
+        ber_func=get_AWGN_ber,
+        samples=1000000,
+        steps=40,
+        start=-5,
+        stop=15,
+    ), '-', label='64QAM AWGN')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.ylabel('BER')
+    plt.title("64QAM vs autoencoder")
+    plt.show()
+
+
+def _graphs():
+
+    y = [5.000000e-06, 1.343333e-04, 3.144667e-04, 1.612333e-03, 1.601600e-03, 2.255200e-03, 2.767467e-03]
+    x = [50, 60, 75, 80, 85, 90, 90]
+    plt.plot(x, y, 'x')
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('Fibre length (km)')
+    plt.ylabel('BER')
+    plt.title("Autoencoder performance")
+    plt.show()
+
+if __name__ == '__main__':
+    _graphs()
+    # _bit_aware_test()
+
+    # _test_autoencoder_perf()
+    # _test_autoencoder_perf_qnn()
+    # _test_autoencoder_perf2()
+    # _test_autoencoder_pretrain()
+    # _test_optics_autoencoder()

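The bits → one-hot → argmax → bits round trip that `LiteTFMod`/`LiteTFDemod` perform around the TFLite interpreter can be sketched with plain NumPy. The helpers below are hypothetical stand-ins for `misc.bit_matrix2one_hot` and `misc.int2bit_array`, whose exact signatures are not shown in this diff:

```python
import numpy as np

def bits_to_one_hot(bits, n):
    # Group the flat bit stream into rows of n bits, read each row as an
    # integer (MSB first), and one-hot encode it over 2**n classes.
    rows = bits.reshape(-1, n)
    idx = rows.dot(1 << np.arange(n - 1, -1, -1))
    one_hot = np.zeros((len(idx), 2 ** n))
    one_hot[np.arange(len(idx)), idx] = 1.0
    return one_hot

def one_hot_to_bits(scores, n):
    # Invert: take the argmax class index and expand it back into n bits.
    idx = scores.argmax(axis=1)
    return ((idx[:, None] >> np.arange(n - 1, -1, -1)) & 1).reshape(-1)

bits = np.array([1, 0, 1, 1, 0, 0, 0, 1])
oh = bits_to_one_hot(bits, 4)  # two one-hot rows over 16 classes
assert np.array_equal(one_hot_to_bits(oh, 4), bits)
```

Because the decoder only has to argmax over the class scores, this round trip is exact even when the interpreter output is a soft (non-one-hot) score vector.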
+ 67 - 2
tests/misc_test.py

@@ -1,5 +1,10 @@
 import misc
 import numpy as np
+import math
+import itertools
+import tensorflow as tf
+from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
+from matplotlib import pyplot as plt
 
 
 def test_bit_matrix_one_hot():
@@ -11,5 +16,65 @@ def test_bit_matrix_one_hot():
 
 
 if __name__ == "__main__":
-    test_bit_matrix_one_hot()
-    print("Everything passed")
+
+    # cardinality = 8
+    # messages_per_block = 3
+    # num_of_blocks = 10
+    # bits_per_symbol = 3
+    #
+    # #-----------------------------------
+    #
+    # mid_idx = int((messages_per_block - 1) / 2)
+    #
+    # ################################################################################################################
+    #
+    # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+    # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
+    #
+    # # out = enc.fit_transform(rand_int)
+    # out = rand_int
+    #
+    # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+    # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
+    #
+    # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
+    #
+    #
+    # n = int(math.log(cardinality, 2))
+    # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
+    #
+    # pows_np = pows.numpy()
+    #
+    # a = np.asarray([0, 1, 1]).reshape(1, -1)
+    #
+    # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 100
+    NUM_OF_SYMBOLS = 10
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH,
+                                     rx_stddev=0,
+                                     q_stddev=0)
+
+    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
+
+    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
+
+    plt.plot(inp_t.flatten())
+
+    out_tf = optical_channel(inp_t)
+
+    out_np = out_tf.numpy()
+
+    plt.plot(out_np.flatten())
+    plt.show()
+
+
+    pass

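The commented-out experiment above (the `tensordot` with descending powers of two) reduces each block of `bits_per_symbol` bits to a symbol index, which is the core of a `BitsToSymbols`-style layer. A minimal NumPy sketch of that reduction, using the same hypothetical block shapes:

```python
import numpy as np

num_of_blocks, messages_per_block, bits_per_symbol = 10, 3, 3

# Random bit blocks shaped (blocks, messages, bits), as in the test above.
out_arr = np.random.randint(2, size=(num_of_blocks, messages_per_block, bits_per_symbol))

# Descending powers of two, MSB first: [4, 2, 1] for 3 bits per symbol.
pows = 2 ** np.arange(bits_per_symbol - 1, -1, -1)

# Contracting the last axis maps each 3-bit message to an integer in [0, 7],
# e.g. [0, 1, 1] -> 3.
symbols = out_arr.dot(pows)
assert np.asarray([0, 1, 1]).dot(pows) == 3
assert symbols.shape == (num_of_blocks, messages_per_block)
```

`tf.tensordot(out_arr_tf, pows, axes=1)` performs the same contraction on the TensorFlow side, with `pows` shaped `(-1, 1)` so the result keeps a trailing unit axis.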
The file diff was suppressed because it is too large
+ 103 - 0
tests/results.py