Bladeren bron

Merge branch 'master' into temp_e2e

# Conflicts:
#	models/end_to_end.py
Min 4 jaren geleden
bovenliggende
commit
58b7fda741
8 gewijzigde bestanden met toevoegingen van 1136 en 70 verwijderingen
  1. 7 1
      .gitignore
  2. 99 67
      models/autoencoder.py
  3. 2 2
      models/basic.py
  4. 32 0
      models/data.py
  5. 111 0
      models/end_to_end.py
  6. 63 0
      models/layers.py
  7. 299 0
      models/quantized_net.py
  8. 523 0
      tests/min_test.py

+ 7 - 1
.gitignore

@@ -5,4 +5,10 @@ __pycache__
 
 # Environments
 venv/
-tests/local_test.py
+
+# Anything else
+*.log
+checkpoint
+*.index
+*.data-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]
+other/

+ 99 - 67
models/autoencoder.py

@@ -8,11 +8,14 @@ from tensorflow.keras import layers, losses
 from tensorflow.keras.models import Model
 from tensorflow.python.keras.layers import LeakyReLU, ReLU
 
-from functools import partial
+# from functools import partial
 import misc
 import defs
 from models import basic
 import os
+# from tensorflow_model_optimization.python.core.quantization.keras import quantize, quantize_aware_activation
+from models.data import BinaryOneHotGenerator
+from models import layers as custom_layers
 
 latent_dim = 64
 
@@ -20,68 +23,98 @@ print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GP
 
 
class AutoencoderMod(defs.Modulator):
    """Modulator wrapper that drives the encoder half of a trained Autoencoder."""

    def __init__(self, autoencoder, encoder=None):
        """
        :param autoencoder: Trained Autoencoder supplying N, parallel, bipolar
            and signal_dim configuration.
        :param encoder: Optional replacement encoder network (e.g. a quantized
            copy); defaults to ``autoencoder.encoder``.
        """
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        # Allow substituting a different (e.g. quantized) encoder.
        self.encoder = encoder or autoencoder.encoder

    def forward(self, binary: np.ndarray):
        # Group the bit stream into messages of N * parallel bits, one-hot
        # encode them, and push them through the encoder network.
        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
        encoded = self.encoder(reshaped_ho)
        x = encoded.numpy()
        if self.autoencoder.bipolar:
            # Map the encoder's [0, 1] output range onto [-1, 1].
            x = x * 2 - 1

        if self.autoencoder.parallel > 1:
            # Split parallel-encoded symbols back into individual signal vectors.
            x = x.reshape((-1, self.autoencoder.signal_dim))

        f = np.zeros(x.shape[0])  # zero "frequency"/third component per symbol
        if self.autoencoder.signal_dim <= 1:
            # 1-D signals carry no quadrature component.
            p = np.zeros(x.shape[0])
        else:
            p = x[:, 1]
        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
        return basic.RFSignal(x3)
 
 
class AutoencoderDemod(defs.Demodulator):
    """Demodulator wrapper that drives the decoder half of a trained Autoencoder."""

    def __init__(self, autoencoder, decoder=None):
        """
        :param autoencoder: Trained Autoencoder supplying N, parallel and
            signal_dim configuration.
        :param decoder: Optional replacement decoder network (e.g. a quantized
            copy); defaults to ``autoencoder.decoder``.
        """
        super().__init__(2 ** autoencoder.N)
        self.autoencoder = autoencoder
        # Allow substituting a different (e.g. quantized) decoder.
        self.decoder = decoder or autoencoder.decoder

    def forward(self, values: defs.Signal) -> np.ndarray:
        if self.autoencoder.signal_dim <= 1:
            # 1-D signal: only the in-phase component is decoded.
            val = values.rect_x
        else:
            val = values.rect
        if self.autoencoder.parallel > 1:
            # Regroup consecutive samples into parallel decoder inputs.
            val = val.reshape((-1, self.autoencoder.parallel))
        decoded = self.decoder(val).numpy()
        # argmax over the softmax output selects the most likely message,
        # which is then expanded back into its N * parallel bits.
        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
        return result.reshape(-1, )
 
 
 class Autoencoder(Model):
-    def __init__(self, N, channel, signal_dim=2):
+    def __init__(self, N, channel, signal_dim=2, parallel=1, all_onehot=True, bipolar=True, encoder=None, decoder=None):
         super(Autoencoder, self).__init__()
         self.N = N
-        self.encoder = tf.keras.Sequential()
-        self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.encoder.add(LeakyReLU(alpha=0.001))
-        # self.encoder.add(layers.Dropout(0.2))
-        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
-        self.encoder.add(LeakyReLU(alpha=0.001))
-        self.encoder.add(layers.Dense(units=signal_dim, activation="tanh"))
-        # self.encoder.add(layers.ReLU(max_value=1.0))
-
-        self.decoder = tf.keras.Sequential()
-        self.decoder.add(tf.keras.Input(shape=(signal_dim,)))
-        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
-        # leaky relu with alpha=1 gives by far best results
-        self.decoder.add(LeakyReLU(alpha=1))
-        self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
-
+        self.parallel = parallel
+        self.signal_dim = signal_dim
+        self.bipolar = bipolar
+        self._input_shape = 2 ** (N * parallel) if all_onehot else (2 ** N) * parallel
+        if encoder is None:
+            self.encoder = tf.keras.Sequential()
+            self.encoder.add(layers.Input(shape=(self._input_shape,)))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.encoder.add(layers.Dropout(0.2))
+            self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+            self.encoder.add(LeakyReLU(alpha=0.001))
+            self.encoder.add(layers.Dense(units=signal_dim * parallel, activation="sigmoid"))
+            # self.encoder.add(layers.ReLU(max_value=1.0))
+            # self.encoder = quantize.quantize_model(self.encoder)
+        else:
+            self.encoder = encoder
+
+        if decoder is None:
+            self.decoder = tf.keras.Sequential()
+            self.decoder.add(tf.keras.Input(shape=(signal_dim * parallel,)))
+            self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # self.encoder.add(LeakyReLU(alpha=0.001))
+            # self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+            # leaky relu with alpha=1 gives by far best results
+            self.decoder.add(LeakyReLU(alpha=1))
+            self.decoder.add(layers.Dense(units=self._input_shape, activation="softmax"))
+        else:
+            self.decoder = decoder
         # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
 
         self.mod = None
         self.demod = None
         self.compiled = False
 
+        self.channel = tf.keras.Sequential()
+        if self.bipolar:
+            self.channel.add(custom_layers.ScaleAndOffset(2, -1, input_shape=(signal_dim * parallel,)))
+
         if isinstance(channel, int) or isinstance(channel, float):
-            self.channel = basic.AWGNChannel(channel)
+            self.channel.add(custom_layers.AwgnChannel(noise_dB=channel, input_shape=(signal_dim * parallel,)))
         else:
-            if not hasattr(channel, 'forward_tensor'):
-                raise ValueError("Channel has no forward_tensor function")
-            if not callable(channel.forward_tensor):
-                raise ValueError("Channel.forward_tensor is not callable")
-            self.channel = channel
+            if not isinstance(channel, tf.keras.layers.Layer):
+                raise ValueError("Channel is not a keras layer")
+            self.channel.add(channel)
 
         # self.decoder.add(layers.Softmax(units=4, dtype=bool))
 
@@ -95,18 +128,14 @@ class Autoencoder(Model):
         #     layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
         #     layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
         # ])
    @property
    def all_layers(self):
        # Flattened list of encoder + decoder layers (used e.g. by the
        # post-training quantizer). Channel layers are left out — see the
        # commented fragment; presumably intentional, confirm if they should
        # ever be included.
        return self.encoder.layers + self.decoder.layers #self.channel.layers +
 
    def call(self, x, **kwargs):
        # Full end-to-end forward pass: encode, pass through the channel
        # model (scale/offset + noise layers), then decode.
        y = self.encoder(x)
        z = self.channel(y)
        return self.decoder(z)
 
     def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
         alphabet = basic.load_alphabet(modulation, polar=False)
@@ -152,24 +181,25 @@ class Autoencoder(Model):
 
         print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
 
    def train(self, epoch_size=3e3, epochs=5):
        """Trains the end-to-end autoencoder on randomly generated bit blocks.

        :param epoch_size: Approximate number of bits generated per epoch
            (the generator rounds up to a whole number of messages).
        :param epochs: Number of training epochs; data is regenerated each epoch.
        """
        m = self.N * self.parallel  # bits per message across parallel streams
        x_train = BinaryOneHotGenerator(size=epoch_size, shape=m)
        # Validation set is 30% the size of the training set.
        x_test = BinaryOneHotGenerator(size=epoch_size*.3, shape=m)

        # Compile lazily so repeated train() calls keep optimizer state.
        if not self.compiled:
            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
            self.compiled = True

        self.fit(x_train, shuffle=False, validation_data=x_test, epochs=epochs)
 
@@ -184,12 +214,14 @@ class Autoencoder(Model):
         return self.demod
 
 
-def view_encoder(encoder, N, samples=1000):
+def view_encoder(encoder, N, samples=1000, title="Autoencoder generated alphabet"):
     test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
     test_values_ho = misc.bit_matrix2one_hot(test_values)
     mvector = np.array([2 ** i for i in range(N)], dtype=int)
     symbols = (test_values * mvector).sum(axis=1)
     encoded = encoder(test_values_ho).numpy()
+    if encoded.shape[1] == 1:
+        encoded = np.c_[encoded, np.zeros(encoded.shape[0])]
     # encoded = misc.polar2rect(encoded)
     for i in range(2 ** N):
         xy = encoded[symbols == i]
@@ -197,7 +229,7 @@ def view_encoder(encoder, N, samples=1000):
         plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
     plt.xlabel('Real')
     plt.ylabel('Imaginary')
-    plt.title("Autoencoder generated alphabet")
+    plt.title(title)
     # plt.legend()
     plt.show()
 
@@ -222,18 +254,18 @@ if __name__ == '__main__':
     # x_test = x_test_array.reshape((-1, n))
     # x_test_ho = misc.bit_matrix2one_hot(x_test)
 
-    autoencoder = Autoencoder(n, -8)
+    autoencoder = Autoencoder(n, -15)
     autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
 
-    autoencoder.fit_encoder(modulation='16qam',
-                            sample_size=2e6,
-                            train_size=0.8,
-                            epochs=1,
-                            batch_size=256,
-                            shuffle=True)
+    # autoencoder.fit_encoder(modulation='16qam',
+    #                         sample_size=2e6,
+    #                         train_size=0.8,
+    #                         epochs=1,
+    #                         batch_size=256,
+    #                         shuffle=True)
 
-    view_encoder(autoencoder.encoder, n)
-    autoencoder.fit_decoder(modulation='16qam', samples=2e6)
+    # view_encoder(autoencoder.encoder, n)
+    # autoencoder.fit_decoder(modulation='16qam', samples=2e6)
     autoencoder.train()
     view_encoder(autoencoder.encoder, n)
 

+ 2 - 2
models/basic.py

@@ -6,13 +6,13 @@ from scipy.spatial import cKDTree
 from os import path
 import tensorflow as tf
 
-ALPHABET_DIR = "./alphabets"
+ALPHABET_DIR = path.join(path.dirname(__file__), "../alphabets")
 
 
 def load_alphabet(name, polar=True):
     apath = path.join(ALPHABET_DIR, name + '.a')
     if not path.exists(apath):
-        raise ValueError(f"Alphabet '{name}' does not exist")
+        raise ValueError(f"Alphabet '{name}' was not found in {path.abspath(apath)}")
     data = []
     indexes = []
     with open(apath, 'r') as f:

+ 32 - 0
models/data.py

@@ -0,0 +1,32 @@
+import tensorflow as tf
+from tensorflow.keras.utils import Sequence
+
+# This creates pool of cpu resources
+# physical_devices = tf.config.experimental.list_physical_devices("CPU")
+# tf.config.experimental.set_virtual_device_configuration(
+#     physical_devices[0], [
+#         tf.config.experimental.VirtualDeviceConfiguration(),
+#         tf.config.experimental.VirtualDeviceConfiguration()
+#     ])
+import misc
+
+
class BinaryOneHotGenerator(Sequence):
    """Keras Sequence yielding one-hot encoded random bit blocks.

    A fresh random dataset is drawn at the end of every epoch, and the whole
    dataset is served as a single batch with input == target (autoencoder
    style training).
    """

    def __init__(self, size=1e5, shape=2):
        """
        :param size: Number of bits generated per epoch (rounded up to a
            multiple of ``shape``).
        :param shape: Number of bits per message/block.
        """
        size = int(size)
        # Round up so the bit stream reshapes cleanly into rows of `shape`.
        if size % shape:
            size += shape - (size % shape)
        self.size = size
        self.shape = shape
        self.x = None
        self.on_epoch_end()

    def on_epoch_end(self):
        # Draw a brand-new random dataset for the next epoch.
        x_train = misc.generate_random_bit_array(self.size).reshape((-1, self.shape))
        self.x = misc.bit_matrix2one_hot(x_train)

    def __len__(self):
        # BUG FIX: __len__ must return the number of batches per epoch.
        # __getitem__ ignores the index and returns the entire dataset, so
        # there is exactly one batch; returning self.size here made Keras
        # iterate the full dataset `size` times every epoch.
        return 1

    def __getitem__(self, idx):
        # Autoencoder-style training pair: input and target are identical.
        return self.x, self.x

+ 111 - 0
models/end_to_end.py

@@ -6,6 +6,117 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
+
+
class ExtractCentralMessage(layers.Layer):
    """Selects the samples of the central message out of a transmission block."""

    def __init__(self, messages_per_block, samples_per_symbol):
        """
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()

        # Build a selection matrix: zeros everywhere, with an identity block
        # placed over the rows belonging to the central message.
        total_rows = messages_per_block * samples_per_symbol
        selector = np.zeros((total_rows, samples_per_symbol))
        start = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        stop = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        selector[start:stop, :] = np.identity(samples_per_symbol)

        self.w = tf.convert_to_tensor(selector, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        # Right-multiplying by the selector keeps only the central message.
        return tf.matmul(inputs, self.w)
+
+
class DigitizationLayer(layers.Layer):
    """Models ADC/DAC digitization: a brick-wall low-pass filter (finite
    converter bandwidth) followed by additive quantization noise."""

    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()

        self.noise_layer = layers.GaussianNoise(q_stddev)
        freq = np.fft.fftfreq(num_of_samples, d=1/fs)

        # Brick-wall mask: 1 in the passband (|f| <= cutoff), 0 outside.
        # IMPROVED: vectorized; replaces a per-element np.ndenumerate loop.
        mask = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)

        self.lpf_multiplier = tf.convert_to_tensor(mask, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Apply the low-pass filter in the frequency domain.
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        # training=True keeps quantization noise active at inference time.
        noisy = self.noise_layer.call(real_t, training=True)
        return noisy
+
+
class OpticalChannel(layers.Layer):
    """Simulated IM/DD optical link: DAC digitization, chromatic dispersion,
    square-law photo-detection, receiver noise, then ADC digitization."""

    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()

        self.noise_layer = layers.GaussianNoise(rx_stddev)
        # Shared digitization model, reused for both the DAC and ADC stages.
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()

        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
        # Frequency-domain chromatic dispersion operator:
        # exp(0.5j * beta2 * L * (2*pi*f)^2).
        # NOTE(review): relies on `math` being imported at module level, which
        # is not visible in this hunk — confirm.
        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic Dispersion (applied in the frequency domain)
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)

        # Squared-Law Detection (photo-diode measures optical power)
        pd_out = tf.square(tf.abs(disp_t))

        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)

        # Adding photo-diode receiver noise (training=True keeps it on at inference)
        rx_signal = self.noise_layer.call(real_val, training=True)

        # ADC LPF and noise
        adc_out = self.digitization_layer(rx_signal)

        return adc_out
+from tensorflow.keras import layers, losses
 from tensorflow.keras import backend as K
 from custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
 import itertools

+ 63 - 0
models/layers.py

@@ -0,0 +1,63 @@
+"""
+Custom Keras Layers for general use
+"""
+import itertools
+
+from tensorflow.keras import layers
+import tensorflow as tf
+import numpy as np
+
+
class AwgnChannel(layers.Layer):
    """Additive white Gaussian noise channel as a Keras layer.

    Noise is applied unconditionally (training=True), so the channel stays
    noisy at inference time as well.
    """

    def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
        """
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param noise_dB: If given, overrides rx_stddev with a value derived
            from a dB figure.
        """
        super(AwgnChannel, self).__init__(**kwargs)
        if noise_dB is not None:
            # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
            # NOTE(review): 10 ** (dB / 10) is a *power* ratio, not an
            # amplitude; an amplitude std-dev would normally be
            # 10 ** (dB / 20). Confirm which convention is intended
            # (cf. the commented-out formula above).
            rx_stddev = 10 ** (noise_dB / 10.0)
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True forces the noise on even outside of training.
        return self.noise_layer.call(inputs, training=True)
+
+
class ScaleAndOffset(layers.Layer):
    """Applies the element-wise affine map ``x -> x * scale + offset``."""

    def __init__(self, scale=1, offset=0, **kwargs):
        super(ScaleAndOffset, self).__init__(**kwargs)
        self.scale = scale
        self.offset = offset

    def call(self, inputs, **kwargs):
        scaled = inputs * self.scale
        return scaled + self.offset
+
+
class BitsToSymbol(layers.Layer):
    """Maps rows of bits (MSB first) to one-hot symbol vectors."""

    def __init__(self, cardinality, **kwargs):
        """
        :param cardinality: Size of the symbol alphabet; must be a power of two.
        """
        super().__init__(**kwargs)
        self.cardinality = cardinality
        # BUG FIX: np.log() takes no base argument (its second positional
        # parameter is `out`, so np.log(x, 2) raises TypeError). Use np.log2
        # to get the number of bits per symbol.
        n = int(np.log2(self.cardinality))
        # Column vector of powers of two, MSB first: [2^(n-1), ..., 2^0].
        self.powers = tf.convert_to_tensor(
            np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
            dtype=tf.float32
        )

    def call(self, inputs, **kwargs):
        # Weighted sum of the bits yields the integer symbol index.
        idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
        return tf.one_hot(idx, self.cardinality)
+
+
class SymbolToBits(layers.Layer):
    """Maps one-hot symbol vectors back to their constituent bits."""

    def __init__(self, cardinality, **kwargs):
        """
        :param cardinality: Size of the symbol alphabet; must be a power of two.
        """
        super().__init__(**kwargs)
        # BUG FIX: np.log() takes no base argument (np.log(x, 2) raises
        # TypeError because the second positional parameter is `out`).
        n = int(np.log2(cardinality))
        # Table of all n-bit patterns; after transposing, column i holds the
        # bit pattern of symbol i.
        patterns = [list(bits) for bits in itertools.product([0, 1], repeat=n)]
        self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(patterns), dtype=tf.float32))

    def call(self, inputs, **kwargs):
        # Multiplying by the one-hot input selects that symbol's bit column.
        return tf.matmul(self.all_syms, inputs)

+ 299 - 0
models/quantized_net.py

@@ -0,0 +1,299 @@
+from numpy import (
+    array,
+    zeros,
+    dot,
+    median,
+    log2,
+    linspace,
+    argmin,
+    abs,
+)
+from scipy.linalg import norm
+from tensorflow.keras.backend import function as Kfunction
+from tensorflow.keras.models import Model, clone_model
+from collections import namedtuple
+from typing import List, Generator
+from time import time
+
+from models.autoencoder import Autoencoder
+
# Lightweight records passed around during quantization.
QuantizedNeuron = namedtuple("QuantizedNeuron", ["layer_idx", "neuron_idx", "q"])
QuantizedFilter = namedtuple(
    "QuantizedFilter", ["layer_idx", "filter_idx", "channel_idx", "q_filtr"]
)
SegmentedData = namedtuple("SegmentedData", ["wX_seg", "qX_seg"])
+
+
class QuantizedNeuralNetwork:
    """Greedy, data-driven post-training quantizer for Dense layers.

    Walks the network layer by layer, quantizing each neuron's weights onto a
    scaled alphabet while feeding data through both the analog and the
    quantized networks so that accumulated quantization error is compensated.
    """

    def __init__(
            self,
            network: Model,
            batch_size: int,
            get_data: Generator[array, None, None],
            logger=None,
            ignore_layers=None,
            bits=log2(3),
            alphabet_scalar=1,
    ):
        """
        :param network: Pre-trained Keras Model (or project Autoencoder).
        :param batch_size: Number of samples fed through per quantized layer.
        :param get_data: Generator yielding individual input samples.
        :param logger: Optional logger; falls back to print().
        :param ignore_layers: Indices of Dense layers to skip (default: none).
        :param bits: Bits per weight; the alphabet has 2**bits atoms.
        :param alphabet_scalar: Scaling factor for the alphabet radius.
        """

        self.get_data = get_data

        # The pre-trained network.
        self.trained_net = network

        # This copies the network structure but not the weights.
        if isinstance(network, Autoencoder):
            self.trained_net_layers = network.all_layers
            self.quantized_net = Autoencoder(network.N, network.channel, bipolar=network.bipolar)
            # Set all the weights to be the same a priori.
            self.quantized_net.set_weights(network.get_weights())
            self.quantized_net_layers = self.quantized_net.all_layers
        else:
            self.trained_net_layers = network.layers
            self.quantized_net = clone_model(network)
            # Set all the weights to be the same a priori.
            self.quantized_net.set_weights(network.get_weights())
            self.quantized_net_layers = self.quantized_net.layers

        self.batch_size = batch_size

        self.alphabet_scalar = alphabet_scalar

        # Create a dictionary encoding which layers are Dense, and what their dimensions are.
        # NOTE(review): built from network.layers while quantization walks
        # self.trained_net_layers — for an Autoencoder these may not line up;
        # confirm (layer_dims is currently unused within this class).
        self.layer_dims = {
            layer_idx: layer.get_weights()[0].shape
            for layer_idx, layer in enumerate(network.layers)
            if layer.__class__.__name__ == "Dense"
        }

        # This determines the alphabet. There will be 2**bits atoms in our alphabet.
        self.bits = bits

        # Construct the (unscaled) alphabet. Layers will scale this alphabet based on the
        # distribution of that layer's weights.
        self.alphabet = linspace(-1, 1, num=int(round(2 ** (bits))))

        self.logger = logger

        # BUG FIX: was a mutable default argument ([]), which Python shares
        # across all instances of the class.
        self.ignore_layers = [] if ignore_layers is None else ignore_layers

    def _log(self, msg: str):
        # Route messages through the supplied logger, or stdout when absent.
        if self.logger:
            self.logger.info(msg)
        else:
            print(msg)

    def _bit_round(self, t: float, rad: float) -> float:
        """Rounds a quantity to the nearest atom in the (scaled) quantization alphabet.

        Parameters
        -----------
        t : float
            The value to quantize.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        bit : float
            The quantized value.
        """

        # Scale the alphabet appropriately.
        layer_alphabet = rad * self.alphabet
        return layer_alphabet[argmin(abs(layer_alphabet - t))]

    def _quantize_weight(
            self, w: float, u: array, X: array, X_tilde: array, rad: float
    ) -> float:
        """Quantizes a single weight of a neuron.

        Parameters
        -----------
        w : float
            The weight.
        u : array
            Residual vector.
        X : array
            Vector from the analog network's random walk.
        X_tilde : array
            Vector from the quantized network's random walk.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        bit : float
            The quantized value.
        """

        # Degenerate direction: nothing to project onto.
        if norm(X_tilde, 2) < 10 ** (-16):
            return 0

        # Residual is (numerically) orthogonal: fall back to direct rounding.
        if abs(dot(X_tilde, u)) < 10 ** (-10):
            return self._bit_round(w, rad)

        # Project the accumulated residual plus this weight's contribution
        # onto the quantized walk direction, then round to the alphabet.
        return self._bit_round(dot(X_tilde, u + w * X) / (norm(X_tilde, 2) ** 2), rad)

    def _quantize_neuron(
            self,
            layer_idx: int,
            neuron_idx: int,
            wX: array,
            qX: array,
            rad=1,
    ) -> QuantizedNeuron:
        """Quantizes a single neuron in a Dense layer.

        Parameters
        -----------
        layer_idx : int
            Index of the Dense layer.
        neuron_idx : int
            Index of the neuron in the Dense layer.
        wX : array
            Layer input for the analog network.
        qX : array
            Layer input for the quantized network.
        rad : float
            Scaling factor for the quantization alphabet.

        Returns
        -------
        QuantizedNeuron: NamedTuple
            A tuple with the layer and neuron index, as well as the quantized neuron.
        """

        N_ell = wX.shape[1]
        u = zeros(self.batch_size)
        w = self.trained_net_layers[layer_idx].get_weights()[0][:, neuron_idx]
        q = zeros(N_ell)
        # Greedy walk over the neuron's weights, carrying the residual u.
        for t in range(N_ell):
            q[t] = self._quantize_weight(w[t], u, wX[:, t], qX[:, t], rad)
            u += w[t] * wX[:, t] - q[t] * qX[:, t]

        return QuantizedNeuron(layer_idx=layer_idx, neuron_idx=neuron_idx, q=q)

    def _get_layer_data(self, layer_idx: int, hf=None):
        """Gets the input data for the layer at a given index.

        Parameters
        -----------
        layer_idx : int
            Index of the layer.
        hf: hdf5 File object in write mode.
            If provided, will write output to hdf5 file instead of returning directly.

        Returns
        -------
        tuple: (array, array)
            A tuple of arrays, with the first entry being the input for the analog network
            and the latter being the input for the quantized network.
        """

        layer = self.trained_net_layers[layer_idx]
        layer_data_shape = layer.input_shape[1:] if layer.input_shape[0] is None else layer.input_shape
        wX = zeros((self.batch_size, *layer_data_shape))
        qX = zeros((self.batch_size, *layer_data_shape))
        if layer_idx == 0:
            # The first layer sees the raw samples; both networks share them.
            for sample_idx in range(self.batch_size):
                try:
                    wX[sample_idx, :] = next(self.get_data)
                except StopIteration:
                    # No more samples!
                    break
            qX = wX
        else:
            # Define functions which will give you the output of the previous hidden layer
            # for both networks.
            prev_trained_output = Kfunction(
                [self.trained_net_layers[0].input],
                [self.trained_net_layers[layer_idx - 1].output],
            )
            prev_quant_output = Kfunction(
                [self.quantized_net_layers[0].input],
                [self.quantized_net_layers[layer_idx - 1].output],
            )
            input_layer = self.trained_net_layers[0]
            input_shape = input_layer.input_shape[1:] if input_layer.input_shape[0] is None else input_layer.input_shape
            batch = zeros((self.batch_size, *input_shape))

            # TODO: Add hf option here. Feed batches of data through rather than all at once. You may want
            # to reconsider how much memory you preallocate for batch, wX, and qX.
            for sample_idx in range(self.batch_size):
                try:
                    batch[sample_idx, :] = next(self.get_data)
                except StopIteration:
                    # No more samples!
                    break

            wX = prev_trained_output([batch])[0]
            qX = prev_quant_output([batch])[0]

        return (wX, qX)

    def _update_weights(self, layer_idx: int, Q: array):
        """Updates the weights of the quantized neural network given a layer index and
        quantized weights.

        Parameters
        -----------
        layer_idx : int
            Index of the layer.
        Q : array
            The quantized weights.
        """

        # Update the quantized network. Use the same bias vector as in the analog network for now.
        if self.trained_net_layers[layer_idx].use_bias:
            bias = self.trained_net_layers[layer_idx].get_weights()[1]
            self.quantized_net_layers[layer_idx].set_weights([Q, bias])
        else:
            self.quantized_net_layers[layer_idx].set_weights([Q])

    def _quantize_layer(self, layer_idx: int):
        """Quantizes a Dense layer of a multi-layer perceptron.

        Parameters
        -----------
        layer_idx : int
            Index of the Dense layer.
        """

        W = self.trained_net_layers[layer_idx].get_weights()[0]
        N_ell_plus_1 = W.shape[1]
        # Placeholder for the weight matrix in the quantized network.
        Q = zeros(W.shape)
        wX, qX = self._get_layer_data(layer_idx)

        # Set the radius of the alphabet.
        rad = self.alphabet_scalar * median(abs(W.flatten()))

        for neuron_idx in range(N_ell_plus_1):
            self._log(f"\tQuantizing neuron {neuron_idx} of {N_ell_plus_1}...")
            tic = time()
            qNeuron = self._quantize_neuron(layer_idx, neuron_idx, wX, qX, rad)
            Q[:, neuron_idx] = qNeuron.q

            self._log(f"\tdone. {time() - tic :.2f} seconds.")

            # Push the partially-quantized matrix into the quantized network
            # so later layers see the quantized activations.
            self._update_weights(layer_idx, Q)

    def quantize_network(self):
        """Quantizes all Dense layers that are not specified by the list of ignored layers."""

        # This must be done sequentially: each layer's quantization depends
        # on the quantized activations of the layers before it.
        for layer_idx, layer in enumerate(self.trained_net_layers):
            if (
                    layer.__class__.__name__ == "Dense"
                    and layer_idx not in self.ignore_layers
            ):
                # Only quantize dense layers.
                self._log(f"Quantizing layer {layer_idx}...")
                self._quantize_layer(layer_idx)
                self._log(f"done. {layer_idx}...")

+ 523 - 0
tests/min_test.py

@@ -0,0 +1,523 @@
+"""
+These are some unstructured tests. Feel free to use this code for anything else
+"""
+
+import logging
+import pathlib
+from itertools import chain
+from sys import stdout
+
+from tensorflow.python.framework.errors_impl import NotFoundError
+
+import defs
+from graphs import get_SNR, get_AWGN_ber
+from models import basic
+from models.autoencoder import Autoencoder, view_encoder
+import matplotlib.pyplot as plt
+import tensorflow as tf
+import misc
+import numpy as np
+
+from models.basic import AlphabetDemod, AlphabetMod
+from models.optical_channel import OpticalChannel
+from models.quantized_net import QuantizedNeuralNetwork
+
+
+
+
+
+def _test_optics_autoencoder():
+    """Trains a 4-bit autoencoder over a simulated optical channel and plots
+    its BER-vs-SNR curve next to a 4PAM baseline, saving the figure as EPS.
+    """
+    # Simulated optical channel; parameter units follow OpticalChannel's API
+    # (rates presumably in Hz, dispersion/length in fibre units — verify there).
+    ch = OpticalChannel(
+        noise_level=-10,
+        dispersion=-21.7,
+        symbol_rate=10e9,
+        sample_rate=400e9,
+        length=10,
+        pulse_shape='rcos',
+        sqrt_out=True
+    )
+
+    # NOTE(review): tf.executing_eagerly() only *queries* eager mode and its
+    # return value is discarded — this line appears to be a no-op.
+    tf.executing_eagerly()
+
+    aenc = Autoencoder(4, channel=ch)
+    aenc.train(samples=1e6)
+    # Autoencoder BER curve, measured against an AWGN reference via get_AWGN_ber.
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='AE')
+
+    # Conventional 4PAM baseline (note: uses the default ber_func, unlike above).
+    plt.plot(*get_SNR(
+        AlphabetMod('4pam', 10e6),
+        AlphabetDemod('4pam', 10e6),
+        samples=30000,
+        steps=50,
+        start=-5,
+        stop=15,
+        length=1,
+        pulse_shape='rcos'
+    ), '-', label='4PAM')
+
+    # BER is plotted on a log scale, as is conventional for BER-vs-SNR curves.
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("Autoencoder Performance")
+    plt.legend()
+    plt.savefig('optics_autoencoder.eps', format='eps')
+    plt.show()
+
+
+
+def _test_autoencoder_pretrain():
+    """Compares a 16QAM-pre-trained autoencoder before and after end-to-end
+    training against a plain 16QAM baseline (BER vs SNR, AWGN channel).
+    """
+    # aenc = Autoencoder(4, -25)
+    # aenc.train(samples=1e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='Random AE')
+
+    # Pre-train stage: only the decoder is fitted to the 16QAM alphabet here
+    # (the encoder fit is left commented out).
+    aenc = Autoencoder(4, -25)
+    # aenc.fit_encoder('16qam', 3e4)
+    aenc.fit_decoder('16qam', 1e5)
+
+    # Curve 1: pre-trained only (no end-to-end training yet).
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM Pre-trained AE')
+
+    # Curve 2: same model after end-to-end training.
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(
+        aenc.get_modulator(),
+        aenc.get_demodulator(),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM Post-trained AE')
+
+    # Curve 3: conventional 16QAM baseline.
+    plt.plot(*get_SNR(
+        AlphabetMod('16qam', 10e6),
+        AlphabetDemod('16qam', 10e6),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='16QAM')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("4Bit Autoencoder Performance")
+    plt.legend()
+    plt.show()
+
+
+class LiteTFMod(defs.Modulator):
+    def __init__(self, name, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
+        self.interpreter.allocate_tensors()
+        pass
+
+    def forward(self, binary: np.ndarray):
+        reshaped = binary.reshape((-1, (self.N * self.autoencoder.parallel)))
+        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
+
+        input_index = self.interpreter.get_input_details()[0]["index"]
+        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
+        input_shape = self.interpreter.get_input_details()[0]["shape"]
+        output_index = self.interpreter.get_output_details()[0]["index"]
+        output_shape = self.interpreter.get_output_details()[0]["shape"]
+
+        x = np.zeros((len(reshaped_ho), output_shape[1]))
+        for i, ho in enumerate(reshaped_ho):
+            self.interpreter.set_tensor(input_index, ho.reshape(input_shape).astype(input_dtype))
+            self.interpreter.invoke()
+            x[i] = self.interpreter.get_tensor(output_index)
+
+        if self.autoencoder.bipolar:
+            x = x * 2 - 1
+
+        if self.autoencoder.parallel > 1:
+            x = x.reshape((-1, self.autoencoder.signal_dim))
+
+        f = np.zeros(x.shape[0])
+        if self.autoencoder.signal_dim <= 1:
+            p = np.zeros(x.shape[0])
+        else:
+            p = x[:, 1]
+        x3 = misc.rect2polar(np.c_[x[:, 0], p, f])
+        return basic.RFSignal(x3)
+
+
+class LiteTFDemod(defs.Demodulator):
+    """Demodulator counterpart of LiteTFMod: runs a TFLite-converted decoder
+    loaded from "/tmp/tflite/<name>.tflite", one sample per invoke.
+    """
+
+    def __init__(self, name, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        self.interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
+        # Tensors must be allocated before set_tensor/invoke can be called.
+        self.interpreter.allocate_tensors()
+
+    def forward(self, values: defs.Signal) -> np.ndarray:
+        """Decodes a received signal back into a flat bit array."""
+        # 1-D constellations only need the x component; otherwise take the
+        # full rectangular representation of the signal.
+        if self.autoencoder.signal_dim <= 1:
+            val = values.rect_x
+        else:
+            val = values.rect
+        if self.autoencoder.parallel > 1:
+            val = val.reshape((-1, self.autoencoder.parallel))
+
+        input_index = self.interpreter.get_input_details()[0]["index"]
+        input_dtype = self.interpreter.get_input_details()[0]["dtype"]
+        input_shape = self.interpreter.get_input_details()[0]["shape"]
+        output_index = self.interpreter.get_output_details()[0]["index"]
+        output_shape = self.interpreter.get_output_details()[0]["shape"]
+
+        # Drive the interpreter one sample at a time (slow, but simple).
+        decoded = np.zeros((len(val), output_shape[1]))
+        for i, v in enumerate(val):
+            self.interpreter.set_tensor(input_index, v.reshape(input_shape).astype(input_dtype))
+            self.interpreter.invoke()
+            decoded[i] = self.interpreter.get_tensor(output_index)
+        # Pick the most likely symbol (argmax over the one-hot-style output)
+        # and expand each symbol index back into its bit pattern.
+        result = misc.int2bit_array(decoded.argmax(axis=1), self.N * self.autoencoder.parallel)
+        return result.reshape(-1, )
+
+
+def _test_autoencoder_perf():
+    assert float(tf.__version__[:3]) >= 2.3
+
+    # aenc = Autoencoder(3, -15)
+    # aenc.train(samples=1e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='3Bit AE')
+
+    # aenc = Autoencoder(4, -25, bipolar=True, dtype=tf.float64)
+    # aenc.train(samples=5e5)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4Bit AE F64')
+
+    aenc = Autoencoder(4, -25, bipolar=True)
+    aenc.train(epoch_size=1e3, epochs=10)
+    # #
+    m = aenc.N * aenc.parallel
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(100*m).reshape((-1, m)))
+    x_train_enc = aenc.encoder(x_train)
+    x_train = tf.cast(x_train, tf.float32)
+    #
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE F32')
+    # # #
+    def save_tfline(model, name, types=None, ops=None, io_types=None, train_x=None):
+        converter = tf.lite.TFLiteConverter.from_keras_model(model)
+        if types is not None:
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.target_spec.supported_types = types
+        if ops is not None:
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.target_spec.supported_ops = ops
+        if io_types is not None:
+            converter.inference_input_type = io_types
+            converter.inference_output_type = io_types
+        if train_x is not None:
+            def representative_data_gen():
+                for input_value in tf.data.Dataset.from_tensor_slices(train_x).batch(1).take(100):
+                    yield [input_value]
+            converter.representative_dataset = representative_data_gen
+        tflite_model = converter.convert()
+        tflite_models_dir = pathlib.Path("/tmp/tflite/")
+        tflite_models_dir.mkdir(exist_ok=True, parents=True)
+        tflite_model_file = tflite_models_dir / (name + ".tflite")
+        tflite_model_file.write_bytes(tflite_model)
+
+    print("Saving models")
+
+    save_tfline(aenc.encoder, "default_enc")
+    save_tfline(aenc.decoder, "default_dec")
+    #
+    # save_tfline(aenc.encoder, "float16_enc", [tf.float16])
+    # save_tfline(aenc.decoder, "float16_dec", [tf.float16])
+    #
+    # save_tfline(aenc.encoder, "bfloat16_enc", [tf.bfloat16])
+    # save_tfline(aenc.decoder, "bfloat16_dec", [tf.bfloat16])
+
+    INT16X8 = tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
+    save_tfline(aenc.encoder, "int16x8_enc", ops=[INT16X8], train_x=x_train)
+    save_tfline(aenc.decoder, "int16x8_dec", ops=[INT16X8], train_x=x_train_enc)
+
+    # save_tfline(aenc.encoder, "int8_enc", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train)
+    # save_tfline(aenc.decoder, "int8_dec", ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8], io_types=tf.uint8, train_x=x_train_enc)
+
+    print("Testing BER vs SNR")
+    plt.plot(*get_SNR(
+        LiteTFMod("default_enc", aenc),
+        LiteTFDemod("default_dec", aenc),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='4AE F32')
+
+    # plt.plot(*get_SNR(
+    #     LiteTFMod("float16_enc", aenc),
+    #     LiteTFDemod("float16_dec", aenc),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE F16')
+    # #
+    # plt.plot(*get_SNR(
+    #     LiteTFMod("bfloat16_enc", aenc),
+    #     LiteTFDemod("bfloat16_dec", aenc),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='4AE BF16')
+    #
+    plt.plot(*get_SNR(
+        LiteTFMod("int16x8_enc", aenc),
+        LiteTFDemod("int16x8_dec", aenc),
+        ber_func=get_AWGN_ber,
+        samples=100000,
+        steps=50,
+        start=-5,
+        stop=15
+    ), '-', label='4AE I16x8')
+
+    # plt.plot(*get_SNR(
+    #     AlphabetMod('16qam', 10e6),
+    #     AlphabetDemod('16qam', 10e6),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15,
+    # ), '-', label='16qam')
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.ylabel('BER')
+    plt.title("Autoencoder with different precision data types")
+    plt.legend()
+    plt.savefig('autoencoder_compression.eps', format='eps')
+    plt.show()
+
+    view_encoder(aenc.encoder, 4)
+
+    # aenc = Autoencoder(5, -25)
+    # aenc.train(samples=2e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='5Bit AE')
+    #
+    # aenc = Autoencoder(6, -25)
+    # aenc.train(samples=2e6)
+    # plt.plot(*get_SNR(
+    #     aenc.get_modulator(),
+    #     aenc.get_demodulator(),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15
+    # ), '-', label='6Bit AE')
+    #
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         ber_func=get_AWGN_ber,
+    #         samples=100000,
+    #         steps=50,
+    #         start=-5,
+    #         stop=15,
+    #     ), '-', label=scheme.upper())
+    #
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.title("Autoencoder vs defined modulations")
+    # plt.legend()
+    # plt.show()
+
+
+def _test_autoencoder_perf2():
+    """Sweeps autoencoder bit widths (2-5 bits) against classic modulation
+    schemes and plots BER vs SNR, saving the figure as EPS.
+    """
+    aenc = Autoencoder(2, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='2Bit AE')
+
+    aenc = Autoencoder(3, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='3Bit AE')
+
+    aenc = Autoencoder(4, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='4Bit AE')
+
+    aenc = Autoencoder(5, -20)
+    aenc.train(samples=3e6)
+    plt.plot(*get_SNR(aenc.get_modulator(), aenc.get_demodulator(), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15), '-', label='5Bit AE')
+
+    # Baseline curves for conventional schemes.
+    for a in ['qpsk', '8psk', '16qam', '32qam', '64qam']:
+        try:
+            plt.plot(*get_SNR(AlphabetMod(a, 10e6), AlphabetDemod(a, 10e6), ber_func=get_AWGN_ber, samples=100000, steps=50, start=-5, stop=15,), '-', label=a.upper())
+        except KeyboardInterrupt:
+            # Let the operator skip the remaining baseline curves.
+            break
+        except Exception:
+            # Deliberately broad best-effort: a scheme that fails is skipped
+            # rather than aborting the whole comparison plot.
+            pass
+
+    plt.yscale('log')
+    plt.grid()
+    plt.xlabel('SNR dB')
+    plt.title("Autoencoder vs defined modulations")
+    plt.legend()
+    plt.savefig('autoencoder_mods.eps', format='eps')
+    plt.show()
+
+    # view_encoder(aenc.encoder, 2)
+
+
+def _test_autoencoder_perf_qnn():
+    """Quantizes a trained 4-bit autoencoder with QuantizedNeuralNetwork and
+    shows the constellation before (FP32) and after quantization.
+
+    Trained weights are cached in the 'autoencoder' checkpoint; training only
+    runs when that checkpoint is missing.
+    """
+    # Log quantization progress both to a file and to stdout.
+    fh = logging.FileHandler("model_quantizing.log", mode="w+")
+    fh.setLevel(logging.INFO)
+    sh = logging.StreamHandler(stream=stdout)
+    sh.setLevel(logging.INFO)
+
+    logger = logging.getLogger(__name__)
+    logger.setLevel(level=logging.INFO)
+    logger.addHandler(fh)
+    logger.addHandler(sh)
+
+    aenc = Autoencoder(4, -25, bipolar=True)
+    # aenc.encoder.save_weights('ae_enc.bin')
+    # aenc.decoder.save_weights('ae_dec.bin')
+    # aenc.encoder.load_weights('ae_enc.bin')
+    # aenc.decoder.load_weights('ae_dec.bin')
+    # Reuse cached weights when available; otherwise train and cache them.
+    try:
+        aenc.load_weights('autoencoder')
+    except NotFoundError:
+        aenc.train(epoch_size=1e3, epochs=10)
+        aenc.save_weights('autoencoder')
+
+    aenc.compile(optimizer='adam', loss=tf.losses.MeanSquaredError())
+
+    # m = bits per parallel symbol group.
+    m = aenc.N * aenc.parallel
+    view_encoder(aenc.encoder, 4, title="FP32 Alphabet")
+
+    batch_size = 25000
+    x_train = misc.bit_matrix2one_hot(misc.generate_random_bit_array(batch_size*m).reshape((-1, m)))
+    x_test = misc.bit_matrix2one_hot(misc.generate_random_bit_array(5000*m).reshape((-1, m)))
+    # NOTE(review): `bits` is computed but never used — the QNN below is built
+    # with bits=np.log2(16) instead.
+    bits = [np.log2(i) for i in (32,)][0]
+    alphabet_scalars = 2  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    num_layers = sum([layer.__class__.__name__ in ('Dense',) for layer in aenc.all_layers])
+
+    # for b in (3, 6, 8, 12, 16, 24, 32, 48, 64):
+    # Chain num_layers + 1 passes over x_train into one generator — presumably
+    # QuantizedNeuralNetwork consumes one pass per Dense layer; verify whether
+    # the extra pass is needed.
+    get_data = (sample for sample in x_train)
+    for i in range(num_layers):
+        get_data = chain(get_data, (sample for sample in x_train))
+
+    qnn = QuantizedNeuralNetwork(
+        network=aenc,
+        batch_size=batch_size,
+        get_data=get_data,
+        logger=logger,
+        bits=np.log2(16),
+        alphabet_scalar=alphabet_scalars,
+    )
+
+    qnn.quantize_network()
+    qnn.quantized_net.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+    view_encoder(qnn.quantized_net, 4, title=f"Quantised {16}b alphabet")
+
+
+
+
+    # view_encoder(qnn_enc.quantized_net, 4, title=f"Quantised {b}b alphabet")
+
+
+    # _, q_accuracy = qnn.quantized_net.evaluate(x_test, x_test, verbose=True)
+    pass
+
+if __name__ == '__main__':
+    # Manual test driver: exactly one experiment is enabled at a time; the
+    # rest are kept commented out for quick switching between them.
+
+
+    # plt.plot(*get_SNR(
+    #     AlphabetMod('16qam', 10e6),
+    #     AlphabetDemod('16qam', 10e6),
+    #     ber_func=get_AWGN_ber,
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15,
+    # ), '-', label='16qam AWGN')
+    #
+    # plt.plot(*get_SNR(
+    #     AlphabetMod('16qam', 10e6),
+    #     AlphabetDemod('16qam', 10e6),
+    #     samples=100000,
+    #     steps=50,
+    #     start=-5,
+    #     stop=15,
+    # ), '-', label='16qam OPTICAL')
+    #
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.show()
+
+    # _test_autoencoder_perf()
+    _test_autoencoder_perf_qnn()
+    # _test_autoencoder_perf2()
+    # _test_autoencoder_pretrain()
+    # _test_optics_autoencoder()
+