35 Commits e71f9723b5 ... f56129796c

Autore SHA1 Messaggio Data
  Tharmetharan Balendran f56129796c fixed broken keras import 5 anni fa
  Tharmetharan Balendran 7e9ac81eb0 added comments 5 anni fa
  Tharmetharan Balendran be6243754e Working implementation of end-to-end AE 5 anni fa
  Tharmetharan Balendran 25393c2ae8 AE training and testing 5 anni fa
  Tharmetharan Balendran 8873f6b053 initial implementation of end-to-end AE 5 anni fa
  Min 6fa2f7d82c Merge branch 'photonics' 5 anni fa
  Min fe3da6cd95 Quick fix signal model for autoencoder 5 anni fa
  Min efbf971909 improved singal class compatability 5 anni fa
  Min 2677c5e41f Merge remote-tracking branch 'origin/matched_filter' into photonics 5 anni fa
  Min ba192057f0 Added signal definition 5 anni fa
  Tharmetharan Balendran 2e157f7def optical channel bugfix 5 anni fa
  Min 294eb2b46c Train decoder to match some modulation 5 anni fa
  Tharmetharan Balendran 372278f5c7 encoder training 5 anni fa
  Min 5d6c327cdf Multithreaded SNR calc + SNR for optics 5 anni fa
  Min 7d5831b735 Merge branch 'matched_filter' 5 anni fa
  Min b7fdda8b86 Merge remote-tracking branch 'origin/swarm_optimisation' 5 anni fa
  Min 8810bce77c Non-dividable sample number fix 5 anni fa
  Oliver Jaison 9955b46ead Function for SNR 5 anni fa
  Tharmetharan Balendran 6d04f26a97 ber plot for optical channel 5 anni fa
  Tharmetharan Balendran 24307ed05e pulse shaping achieved 5 anni fa
  Min 28162f72c0 Photonics structure prototype 5 anni fa
  Min 2b1d868ff6 Merge remote-tracking branch 'origin/chromatic_dispersion' 5 anni fa
  Min 58a577e847 Merge remote-tracking branch 'origin/chromatic_dispersion' 5 anni fa
  Min 34c1776176 Added 32QAM 5 anni fa
  Tharmetharan Balendran 17670fda2d Photodiode detection added 5 anni fa
  Min 7632f0d3fc Testing autoencoders with other modulations 5 anni fa
  Min ea0fb8a718 Alphabet mod/demod 5 anni fa
  Min 2e9803f2d3 quick fix for plt 5 anni fa
  Min 5de0515f99 Added one-hot to autoencoder 5 anni fa
  Min 7151c964d3 Added one-hot functions 5 anni fa
  Min 112e968014 autoencoder WIP 5 anni fa
  Tharmetharan Balendran 6d656d9375 chromatic dispersion applied 5 anni fa
  Min dbaa4ca270 Working autoencoder prototype 5 anni fa
  Min 274c6a1a69 WIP 5 anni fa
  Tharmetharan Balendran bd50e367c2 python venv gitignore 5 anni fa
18 ha cambiato i file con 1386 aggiunte e 143 eliminazioni
  1. 4 0
      .gitignore
  2. 20 0
      alphabets/16qam.a
  3. 38 0
      alphabets/32qam.a
  4. 5 0
      alphabets/4pam.a
  5. 72 0
      alphabets/64qam.a
  6. 9 0
      alphabets/8psk.a
  7. 3 0
      alphabets/bpsk.a
  8. 5 0
      alphabets/qpsk.a
  9. 33 6
      defs.py
  10. 88 0
      graphs.py
  11. 113 43
      main.py
  12. 38 1
      misc.py
  13. 258 0
      models/autoencoder.py
  14. 107 93
      models/basic.py
  15. 369 0
      models/end_to_end.py
  16. 177 0
      models/optical_channel.py
  17. 32 0
      photonics.py
  18. 15 0
      tests/misc_test.py

+ 4 - 0
.gitignore

@@ -2,3 +2,7 @@ __pycache__
 .idea
 *.pyc
 *.pyo
+
+# Environments
+venv/
+tests/local_test.py

+ 20 - 0
alphabets/16qam.a

@@ -0,0 +1,20 @@
+R
+1000,-1,1
+1001,-0.33333,1
+1011,0.33333,1
+1010,1,1
+
+1100,-1,0.33333
+1101,-0.33333,0.33333
+1111,0.33333,0.33333
+1110,1,0.33333
+
+0100,-1,-0.33333
+0101,-0.33333,-0.33333
+0111,0.33333,-0.33333
+0110,1,-0.33333
+
+0000,-1,-1
+0001,-0.33333,-1
+0011,0.33333,-1
+0010,1,-1

+ 38 - 0
alphabets/32qam.a

@@ -0,0 +1,38 @@
+RI
+0,-0.6,1
+1,-0.2,1
+29,+0.2,1
+28,+0.6,1
+
+4,-1,0.6
+8,-0.6,0.6
+12,-0.2,0.6
+16,+0.2,0.6
+20,+0.6,0.6
+24,+1,0.6
+
+5,-1,0.2
+9,-0.6,0.2
+13,-0.2,0.2
+17,+0.2,0.2
+21,+0.6,0.2
+25,+1,0.2
+
+6,-1,-0.2
+10,-0.6,-0.2
+14,-0.2,-0.2
+18,+0.2,-0.2
+22,+0.6,-0.2
+26,+1,-0.2
+
+7,-1,-0.6
+11,-0.6,-0.6
+15,-0.2,-0.6
+19,+0.2,-0.6
+23,+0.6,-0.6
+27,+1,-0.6
+
+3,-0.6,-1
+2,-0.2,-1
+30,+0.2,-1
+31,+0.6,-1

+ 5 - 0
alphabets/4pam.a

@@ -0,0 +1,5 @@
+R
+00,0,0
+01,0.25,0
+10,0.5,0
+11,1,0

+ 72 - 0
alphabets/64qam.a

@@ -0,0 +1,72 @@
+R
+000000,-1,1
+000001,-0.7142857,1
+000011,-0.4285714,1
+000010,-0.1428571,1
+000110,0.1428571,1
+000111,0.4285714,1
+000101,0.7142857,1
+000100,1,1
+
+001000,-1,0.7142857
+001001,-0.7142857,0.7142857
+001011,-0.4285714,0.7142857
+001010,-0.1428571,0.7142857
+001110,0.1428571,0.7142857
+001111,0.4285714,0.7142857
+001101,0.7142857,0.7142857
+001100,1,0.7142857
+
+011000,-1,0.4285714
+011001,-0.7142857,0.4285714
+011011,-0.4285714,0.4285714
+011010,-0.1428571,0.4285714
+011110,0.1428571,0.4285714
+011111,0.4285714,0.4285714
+011101,0.7142857,0.4285714
+011100,1,0.4285714
+
+010000,-1,0.1428571
+010001,-0.7142857,0.1428571
+010011,-0.4285714,0.1428571
+010010,-0.1428571,0.1428571
+010110,0.1428571,0.1428571
+010111,0.4285714,0.1428571
+010101,0.7142857,0.1428571
+010100,1,0.1428571
+
+110000,-1,-0.1428571
+110001,-0.7142857,-0.1428571
+110011,-0.4285714,-0.1428571
+110010,-0.1428571,-0.1428571
+110110,0.1428571,-0.1428571
+110111,0.4285714,-0.1428571
+110101,0.7142857,-0.1428571
+110100,1,-0.1428571
+
+111000,-1,-0.4285714
+111001,-0.7142857,-0.4285714
+111011,-0.4285714,-0.4285714
+111010,-0.1428571,-0.4285714
+111110,0.1428571,-0.4285714
+111111,0.4285714,-0.4285714
+111101,0.7142857,-0.4285714
+111100,1,-0.4285714
+
+101000,-1,-0.7142857
+101001,-0.7142857,-0.7142857
+101011,-0.4285714,-0.7142857
+101010,-0.1428571,-0.7142857
+101110,0.1428571,-0.7142857
+101111,0.4285714,-0.7142857
+101101,0.7142857,-0.7142857
+101100,1,-0.7142857
+
+100000,-1,-1
+100001,-0.7142857,-1
+100011,-0.4285714,-1
+100010,-0.1428571,-1
+100110,0.1428571,-1
+100111,0.4285714,-1
+100101,0.7142857,-1
+100100,1,-1

+ 9 - 0
alphabets/8psk.a

@@ -0,0 +1,9 @@
+D8
+111,1,0
+110,1,45
+010,1,90
+011,1,135
+001,1,180
+000,1,225
+100,1,270
+101,1,315

+ 3 - 0
alphabets/bpsk.a

@@ -0,0 +1,3 @@
+R2
+1,0
+-1,0

+ 5 - 0
alphabets/qpsk.a

@@ -0,0 +1,5 @@
+R4
+1,1
+-1,1
+-1,-1
+1,-1

+ 33 - 6
defs.py

@@ -1,5 +1,30 @@
 import math
 import numpy as np
+import tensorflow as tf
+
+class Signal:
+
+    @property
+    def rect_x(self) -> np.ndarray:
+        return self.rect[:, 0]
+
+    @property
+    def rect_y(self) -> np.ndarray:
+        return self.rect[:, 1]
+
+    @property
+    def rect(self) -> np.ndarray:
+        raise NotImplemented("Not implemented")
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        raise NotImplemented("Not implemented")
+
+    def set_rect(self, mat: np.ndarray):
+        raise NotImplemented("Not implemented")
+
+    @property
+    def apf(self):
+        raise NotImplemented("Not implemented")
 
 
 class COMComponent:
@@ -13,12 +38,14 @@ class Channel(COMComponent):
     This model is just empty therefore just bypasses any input to output
     """
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> Signal:
+        raise NotImplemented("Need to define forward function")
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
         """
-        :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
-        :return: affected tuple of (amplitude, phase, frequency)
+        Forward operation optimised for tensorflow tensors
         """
-        raise NotImplemented("Need to define forward function")
+        raise NotImplemented("Need to define forward_tensor function")
 
 
 class ModComponent(COMComponent):
@@ -30,7 +57,7 @@ class ModComponent(COMComponent):
 
 class Modulator(ModComponent):
 
-    def forward(self, binary: np.ndarray) -> np.ndarray:
+    def forward(self, binary: np.ndarray) -> Signal:
         """
         :param binary: raw bytes as input (most be dtype=bool)
         :return: amplitude, phase, frequency
@@ -40,7 +67,7 @@ class Modulator(ModComponent):
 
 class Demodulator(ModComponent):
 
-    def forward(self, values: np.ndarray) -> np.ndarray:
+    def forward(self, values: Signal) -> np.ndarray:
         """
         :param values: value generator, each iteration returns tuple of (amplitude, phase, frequency)
         :return: binary resulting values (dtype=bool)

+ 88 - 0
graphs.py

@@ -0,0 +1,88 @@
+import math
+import os
+from multiprocessing import Pool
+
+from sklearn.metrics import accuracy_score
+
+from defs import Modulator, Demodulator, Channel
+from models.basic import AWGNChannel
+from misc import generate_random_bit_array
+from models.optical_channel import OpticalChannel
+import matplotlib.pyplot as plt
+import numpy as np
+
+CPU_COUNT = os.environ.get("CPU_COUNT", os.cpu_count())
+
+
+def show_constellation(mod: Modulator, chan: Channel, demod: Demodulator, samples=1000):
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+
+    plt.plot(x_chan.rect_x[x], x_chan.rect_y[x], '+')
+    plt.plot(x_chan.rect_x[:, 0][~x], x_chan.rect_y[:, 1][~x], '+')
+    plt.plot(x_mod.rect_x[:, 0], x_mod.rect_y[:, 1], 'ro')
+    axes = plt.gca()
+    axes.set_xlim([-2, +2])
+    axes.set_ylim([-2, +2])
+    plt.grid()
+    plt.show()
+    print('Accuracy : ' + str())
+
+
+def get_ber(mod, chan, demod, samples=1000):
+    if samples % mod.N:
+        samples += mod.N - (samples % mod.N)
+    x = generate_random_bit_array(samples)
+    x_mod = mod.forward(x)
+    x_chan = chan.forward(x_mod)
+    x_demod = demod.forward(x_chan)
+    return 1 - accuracy_score(x, x_demod)
+
+
+def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    for noise in ber_x:
+        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
+    return ber_x, ber_y
+
+
+def __calc_ber(packed):
+    # This function has to be outside get_Optical_ber in order to be pickled by pool
+    mod, demod, noise, length, pulse_shape, samples = packed
+    tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                length=length, pulse_shape=pulse_shape, sqrt_out=True)
+    return get_ber(mod, tx_channel, demod, samples=samples)
+
+
+def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
+    with Pool(CPU_COUNT) as pool:
+        packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
+        for i, ber in enumerate(pool.imap(__calc_ber, packed_args)):
+            ber_y.append(ber)
+            i += 1  # just offset by 1
+            print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i * 100 / len(ber_x):6.2f}%)", end='')
+    print()
+    return ber_x, ber_y
+
+
+def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
+    """
+    SNR for optics and RF should be calculated the same, that is A^2
+    Because P∝V² and P∝I²
+    """
+    x_mod = mod.forward(generate_random_bit_array(samples * mod.N))
+    sig_power = [A ** 2 for A in x_mod.amplitude]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+
+    noise_start = -start + av_sig_pow
+    noise_stop = -stop + av_sig_pow
+    ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
+    SNR = -ber_x + av_sig_pow
+    return SNR, ber_y

+ 113 - 43
main.py

@@ -1,58 +1,128 @@
 import matplotlib.pyplot as plt
 
-import numpy as np
-from sklearn.metrics import accuracy_score
-from models import basic
-from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, MaryMod, MaryDemod
-import misc
-
-def show_constellation(mod, chan, demod, samples=1000):
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-
-    x_mod_rect = misc.polar2rect(x_mod)
-    x_chan_rect = misc.polar2rect(x_chan)
-    plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
-    plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
-    plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
-    axes = plt.gca()
-    axes.set_xlim([-2, +2])
-    axes.set_ylim([-2, +2])
-    plt.grid()
-    plt.show()
-    print('Accuracy : ' + str())
+import graphs
+from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 
 
-def get_ber(mod, chan, demod, samples=1000):
-    x = misc.generate_random_bit_array(samples)
-    x_mod = mod.forward(x)
-    x_chan = chan.forward(x_mod)
-    x_demod = demod.forward(x_chan)
-    return 1 - accuracy_score(x, x_demod)
+if __name__ == '__main__':
+    # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
+    # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
+    # mod = MaryMod('8psk', 10e6)
+    # misc.display_alphabet(mod.alphabet, a_vals=True)
+    # mod = MaryMod('qpsk', 10e6)
+    # misc.display_alphabet(mod.alphabet, a_vals=True)
+    # mod = MaryMod('16qam', 10e6)
+    # misc.display_alphabet(mod.alphabet, a_vals=True)
+    # mod = MaryMod('64qam', 10e6)
+    # misc.display_alphabet(mod.alphabet, a_vals=True)
+    # aenc = Autoencoder(4, -25)
+    # aenc.train(samples=5e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 4bit -25dB')
 
-def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
-    ber_x = np.linspace(start, stop, steps)
-    ber_y = []
-    for noise in ber_x:
-        ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
-    return ber_x, ber_y
+    # aenc = Autoencoder(5, -25)
+    # aenc.train(samples=2e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 5bit -25dB')
 
+    # view_encoder(aenc.encoder, 5)
+    # plt.plot(*get_AWGN_ber(AlphabetMod('32qam', 10e6), AlphabetDemod('32qam', 10e6), samples=12000, start=-15), '-',
+    #          label='32-QAM')
+    # show_constellation(AlphabetMod('32qam', 10e6), AWGNChannel(-1), AlphabetDemod('32qam', 10e6))
+    # mod = AlphabetMod('32qam', 10e6)
+    # misc.display_alphabet(mod.alphabet, a_vals=True)
+    # pass
 
-if __name__ == '__main__':
+    # aenc = Autoencoder(5, -15)
+    # aenc.train(samples=2e6)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 5bit -15dB')
+    #
+    # aenc = Autoencoder(4, -25)
+    # aenc.train(samples=6e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 4bit -20dB')
+    #
+    # aenc = Autoencoder(4, -15)
+    # aenc.train(samples=6e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 4bit -15dB')
+
+    # aenc = Autoencoder(2, -20)
+    # aenc.train(samples=6e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 2bit -20dB')
+    #
+    # aenc = Autoencoder(2, -15)
+    # aenc.train(samples=6e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 2bit -15dB')
+
+    # aenc = Autoencoder(4, -10)
+    # aenc.train(samples=5e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 4bit -10dB')
+    #
+    # aenc = Autoencoder(4, -8)
+    # aenc.train(samples=5e5)
+    # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
+    #          label='AE 4bit -8dB')
+
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    # for l in np.logspace(start=0, stop=3, num=6):
+    #     plt.plot(*misc.get_SNR(
+    #         AlphabetMod('4pam', 10e6),
+    #         AlphabetDemod('4pam', 10e6),
+    #         samples=2000,
+    #         steps=200,
+    #         start=-5,
+    #         stop=20,
+    #         length=l,
+    #         pulse_shape='rcos'
+    #     ), '-', label=(str(int(l))+'km'))
+    #
+    # plt.yscale('log')
+    # # plt.gca().invert_xaxis()
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # # plt.ylabel('BER')
+    # plt.title("BER against Fiber length")
+    # plt.legend()
+    # plt.show()
+
+    for ps in ['rect', 'rcos', 'rrcos']:
+        plt.plot(*graphs.get_SNR(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=30000,
+            steps=100,
+            start=-5,
+            stop=20,
+            length=1,
+            pulse_shape=ps
+        ), '-', label=ps)
 
-    plt.plot(*get_AWGN_ber(MaryMod(6, 10e6, gray=True), MaryDemod(6, 10e6), samples=12000, start=-15), '-', label='64-QAM')
-    plt.plot(*get_AWGN_ber(MaryMod(5, 10e6, gray=True), MaryDemod(5, 10e6), samples=12000, start=-15), '-', label='32-QAM')
-    plt.plot(*get_AWGN_ber(MaryMod(4, 10e6, gray=True), MaryDemod(4, 10e6), samples=12000, start=-15), '-', label='16-QAM')
-    plt.plot(*get_AWGN_ber(BPSKMod(10e6), BPSKDemod(10e6, 10e3), samples=12000), '-', label='BPSK')
     plt.yscale('log')
-    plt.gca().invert_xaxis()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.show()
 
-    pass
+    pass

+ 38 - 1
misc.py

@@ -1,3 +1,4 @@
+
 import numpy as np
 import math
 import matplotlib.pyplot as plt
@@ -13,13 +14,47 @@ def display_alphabet(alphabet, values=None, a_vals=False, title="Alphabet conste
     N = math.ceil(math.log2(len(alphabet)))
     if a_vals:
         for i, value in enumerate(rect):
-            plt.annotate(xy=value+[0.01, 0.01], s=format(i, f'0{N}b'))
+            plt.annotate(xy=value+[0.01, 0.01], text=format(i, f'0{N}b'))
     plt.xlabel('Real')
     plt.ylabel('Imaginary')
     plt.grid()
     plt.show()
 
 
+def bit_matrix2one_hot(matrix: np.ndarray) -> np.ndarray:
+    """
+    Returns a copy of bit encoded matrix to one hot matrix. A row examples:
+    [1010] (decimal 10) => [0000 0100 0000 0000]
+    [0011] (decimal  3) => [0000 0000 0000 1000]
+    each number represents true/false value in column
+    """
+    N = matrix.shape[1]
+    encoder = 2**np.arange(N)
+    values = np.dot(matrix, encoder)
+    result = np.zeros((matrix.shape[0], 2**N), dtype=bool)
+    result[np.arange(matrix.shape[0]), values] = True
+    return result
+
+
+def one_hot2bit_matrix(matrix: np.ndarray) -> np.ndarray:
+    """
+    Returns a copy of one hot matrix to bit encoded matrix. A row examples:
+    [0000 0100 0000 0000] => [1010] (decimal 10)
+    [0000 0000 0000 1000] => [0011] (decimal  3)
+    each number represents true/false value in column
+    """
+    N = math.ceil(math.log2(matrix.shape[1]))
+    values = np.dot(matrix, np.arange(2**N))
+    return int2bit_array(values, N)
+
+
+def int2bit_array(int_arr: np.ndarray, N: int) -> np.ndarray:
+    x0 = np.array([int_arr], dtype=np.uint8)
+    x1 = np.unpackbits(x0.T, bitorder='little', axis=1)
+    result = x1[:, :N].astype(bool)  # , indices
+    return result
+
+
 def polar2rect(array, amp_column=0, phase_column=1) -> np.ndarray:
     """
     Return copy of array with amp_column and phase_column as polar coordinates replaced by rectangular coordinates
@@ -68,3 +103,5 @@ def generate_random_bit_array(size):
     arr = np.concatenate(p)
     np.random.shuffle(arr)
     return arr
+
+

+ 258 - 0
models/autoencoder.py

@@ -0,0 +1,258 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+
+from sklearn.metrics import accuracy_score
+from sklearn.model_selection import train_test_split
+from tensorflow.keras import layers, losses
+from tensorflow.keras.models import Model
+from tensorflow.python.keras.layers import LeakyReLU, ReLU
+
+from functools import partial
+import misc
+import defs
+from models import basic
+import os
+
+latent_dim = 64
+
+print("# GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
+
+
+class AutoencoderMod(defs.Modulator):
+    def __init__(self, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+
+    def forward(self, binary: np.ndarray):
+        reshaped = binary.reshape((-1, self.N))
+        reshaped_ho = misc.bit_matrix2one_hot(reshaped)
+        encoded = self.autoencoder.encoder(reshaped_ho)
+        x = encoded.numpy()
+        x2 = x * 2 - 1
+
+        f = np.zeros(x2.shape[0])
+        x3 = misc.rect2polar(np.c_[x2[:, 0], x2[:, 1], f])
+        return basic.RFSignal(x3)
+
+
+class AutoencoderDemod(defs.Demodulator):
+    def __init__(self, autoencoder):
+        super().__init__(2 ** autoencoder.N)
+        self.autoencoder = autoencoder
+
+    def forward(self, values: defs.Signal) -> np.ndarray:
+        decoded = self.autoencoder.decoder(values.rect).numpy()
+        result = misc.int2bit_array(decoded.argmax(axis=1), self.N)
+        return result.reshape(-1, )
+
+
+class Autoencoder(Model):
+    def __init__(self, N, channel, signal_dim=2):
+        super(Autoencoder, self).__init__()
+        self.N = N
+        self.encoder = tf.keras.Sequential()
+        self.encoder.add(tf.keras.Input(shape=(2 ** N,), dtype=bool))
+        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+        self.encoder.add(LeakyReLU(alpha=0.001))
+        # self.encoder.add(layers.Dropout(0.2))
+        self.encoder.add(layers.Dense(units=2 ** (N + 1)))
+        self.encoder.add(LeakyReLU(alpha=0.001))
+        self.encoder.add(layers.Dense(units=signal_dim, activation="tanh"))
+        # self.encoder.add(layers.ReLU(max_value=1.0))
+
+        self.decoder = tf.keras.Sequential()
+        self.decoder.add(tf.keras.Input(shape=(signal_dim,)))
+        self.decoder.add(layers.Dense(units=2 ** (N + 1)))
+        # leaky relu with alpha=1 gives by far best results
+        self.decoder.add(LeakyReLU(alpha=1))
+        self.decoder.add(layers.Dense(units=2 ** N, activation="softmax"))
+
+        # self.randomiser = tf.random_normal_initializer(mean=0.0, stddev=0.1, seed=None)
+
+        self.mod = None
+        self.demod = None
+        self.compiled = False
+
+        if isinstance(channel, int) or isinstance(channel, float):
+            self.channel = basic.AWGNChannel(channel)
+        else:
+            if not hasattr(channel, 'forward_tensor'):
+                raise ValueError("Channel has no forward_tensor function")
+            if not callable(channel.forward_tensor):
+                raise ValueError("Channel.forward_tensor is not callable")
+            self.channel = channel
+
+        # self.decoder.add(layers.Softmax(units=4, dtype=bool))
+
+        # [
+        #     layers.Input(shape=(28, 28, 1)),
+        #     layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),
+        #     layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)
+        # ])
+        # self.decoder = tf.keras.Sequential([
+        #     layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),
+        #     layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),
+        #     layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')
+        # ])
+
+    def call(self, x, **kwargs):
+        signal = self.encoder(x)
+        signal = signal * 2 - 1
+        signal = self.channel.forward_tensor(signal)
+        # encoded = encoded * 2 - 1
+        # encoded = tf.clip_by_value(encoded, clip_value_min=0, clip_value_max=1, name=None)
+        # noise = self.randomiser(shape=(-1, 2), dtype=tf.float32)
+        # noise = np.random.normal(0, 1, (1, 2)) * self.noise
+        # noisy = tf.convert_to_tensor(noise, dtype=tf.float32)
+        decoded = self.decoder(signal)
+        return decoded
+
+    def fit_encoder(self, modulation, sample_size, train_size=0.8, epochs=1, batch_size=1, shuffle=False):
+        alphabet = basic.load_alphabet(modulation, polar=False)
+
+        if not alphabet.shape[0] == self.N ** 2:
+            raise Exception("Cardinality of modulation scheme is different from cardinality of autoencoder!")
+
+        x_train = np.random.randint(self.N ** 2, size=int(sample_size * train_size))
+        y_train = alphabet[x_train]
+        x_train_ho = np.zeros((int(sample_size * train_size), self.N ** 2))
+        for idx, x in np.ndenumerate(x_train):
+            x_train_ho[idx, x] = 1
+
+        x_test = np.random.randint(self.N ** 2, size=int(sample_size * (1 - train_size)))
+        y_test = alphabet[x_test]
+        x_test_ho = np.zeros((int(sample_size * (1 - train_size)), self.N ** 2))
+        for idx, x in np.ndenumerate(x_test):
+            x_test_ho[idx, x] = 1
+
+        self.encoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.encoder.fit(x_train_ho, y_train,
+                         epochs=epochs,
+                         batch_size=batch_size,
+                         shuffle=shuffle,
+                         validation_data=(x_test_ho, y_test))
+
+    def fit_decoder(self, modulation, samples):
+        samples = int(samples * 1.3)
+        demod = basic.AlphabetDemod(modulation, 0)
+        x = np.random.rand(samples, 2) * 2 - 1
+        x = x.reshape((-1, 2))
+        f = np.zeros(x.shape[0])
+        xf = np.c_[x[:, 0], x[:, 1], f]
+        y = demod.forward(basic.RFSignal(misc.rect2polar(xf)))
+        y_ho = misc.bit_matrix2one_hot(y.reshape((-1, 4)))
+
+        X_train, X_test, y_train, y_test = train_test_split(x, y_ho)
+        self.decoder.compile(optimizer='adam', loss=tf.keras.losses.MeanSquaredError())
+        self.decoder.fit(X_train, y_train, shuffle=False, validation_data=(X_test, y_test))
+        y_pred = self.decoder(X_test).numpy()
+        y_pred2 = np.zeros(y_test.shape, dtype=bool)
+        y_pred2[np.arange(y_pred2.shape[0]), np.argmax(y_pred, axis=1)] = True
+
+        print("Decoder accuracy: %.4f" % accuracy_score(y_pred2, y_test))
+
+    def train(self, samples=1e6):
+        if samples % self.N:
+            samples += self.N - (samples % self.N)
+        x_train = misc.generate_random_bit_array(samples).reshape((-1, self.N))
+        x_train_ho = misc.bit_matrix2one_hot(x_train)
+
+        test_samples = samples * 0.3
+        if test_samples % self.N:
+            test_samples += self.N - (test_samples % self.N)
+        x_test_array = misc.generate_random_bit_array(test_samples)
+        x_test = x_test_array.reshape((-1, self.N))
+        x_test_ho = misc.bit_matrix2one_hot(x_test)
+
+        if not self.compiled:
+            self.compile(optimizer='adam', loss=losses.MeanSquaredError())
+            self.compiled = True
+
+        self.fit(x_train_ho, x_train_ho, shuffle=False, validation_data=(x_test_ho, x_test_ho))
+        # encoded_data = self.encoder(x_test_ho)
+        # decoded_data = self.decoder(encoded_data).numpy()
+
+    def get_modulator(self):
+        if self.mod is None:
+            self.mod = AutoencoderMod(self)
+        return self.mod
+
+    def get_demodulator(self):
+        if self.demod is None:
+            self.demod = AutoencoderDemod(self)
+        return self.demod
+
+
+def view_encoder(encoder, N, samples=1000):
+    test_values = misc.generate_random_bit_array(samples).reshape((-1, N))
+    test_values_ho = misc.bit_matrix2one_hot(test_values)
+    mvector = np.array([2 ** i for i in range(N)], dtype=int)
+    symbols = (test_values * mvector).sum(axis=1)
+    encoded = encoder(test_values_ho).numpy()
+    # encoded = misc.polar2rect(encoded)
+    for i in range(2 ** N):
+        xy = encoded[symbols == i]
+        plt.plot(xy[:, 0], xy[:, 1], 'x', markersize=12, label=format(i, f'0{N}b'))
+        plt.annotate(xy=[xy[:, 0].mean() + 0.01, xy[:, 1].mean() + 0.01], text=format(i, f'0{N}b'))
+    plt.xlabel('Real')
+    plt.ylabel('Imaginary')
+    plt.title("Autoencoder generated alphabet")
+    # plt.legend()
+    plt.show()
+
+    pass
+
+
+if __name__ == '__main__':
+    # (x_train, _), (x_test, _) = fashion_mnist.load_data()
+    #
+    # x_train = x_train.astype('float32') / 255.
+    # x_test = x_test.astype('float32') / 255.
+    #
+    # print(f"Train data: {x_train.shape}")
+    # print(f"Test data: {x_test.shape}")
+
+    n = 4
+
+    # samples = 1e6
+    # x_train = misc.generate_random_bit_array(samples).reshape((-1, n))
+    # x_train_ho = misc.bit_matrix2one_hot(x_train)
+    # x_test_array = misc.generate_random_bit_array(samples * 0.3)
+    # x_test = x_test_array.reshape((-1, n))
+    # x_test_ho = misc.bit_matrix2one_hot(x_test)
+
+    autoencoder = Autoencoder(n, -8)
+    autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
+
+    autoencoder.fit_encoder(modulation='16qam',
+                            sample_size=2e6,
+                            train_size=0.8,
+                            epochs=1,
+                            batch_size=256,
+                            shuffle=True)
+
+    view_encoder(autoencoder.encoder, n)
+    autoencoder.fit_decoder(modulation='16qam', samples=2e6)
+    autoencoder.train()
+    view_encoder(autoencoder.encoder, n)
+
+    # view_encoder(autoencoder.encoder, n)
+    # view_encoder(autoencoder.encoder, n)
+
+
+    # autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
+    #
+    # autoencoder.fit(x_train_ho, x_train_ho,
+    #                 epochs=1,
+    #                 shuffle=False,
+    #                 validation_data=(x_test_ho, x_test_ho))
+    #
+    # encoded_data = autoencoder.encoder(x_test_ho)
+    # decoded_data = autoencoder.decoder(encoded_data).numpy()
+    #
+    # result = misc.int2bit_array(decoded_data.argmax(axis=1), n)
+    # print("Accuracy: %.4f" % accuracy_score(x_test_array, result.reshape(-1, )))
+    # view_encoder(autoencoder.encoder, n)
+
+    pass

+ 107 - 93
models/basic.py

@@ -3,71 +3,80 @@ import numpy as np
 import math
 import misc
 from scipy.spatial import cKDTree
-
-def _make_gray(n):
-    if n <= 0:
-        return []
-    arr = ['0', '1']
-    i = 2
-    while True:
-        if i >= 1 << n:
-            break
-        for j in range(i - 1, -1, -1):
-            arr.append(arr[j])
-        for j in range(i):
-            arr[j] = "0" + arr[j]
-        for j in range(i, 2 * i):
-            arr[j] = "1" + arr[j]
-        i = i << 1
-    return list(map(lambda x: int(x, 2), arr))
-
-
-def _gen_mary_alphabet(size, gray=True, polar=True):
-    alphabet = np.zeros((size, 2))
-    N = math.ceil(math.sqrt(size))
-
-    # if sqrt(size) != size^2 (not a perfect square),
-    # skip defines how many corners to cut off.
-    skip = 0
-    if N ** 2 > size:
-        skip = int(math.sqrt((N ** 2 - size) // 4))
-
-    step = 2 / (N - 1)
-    skipped = 0
-    for x in range(N):
-        for y in range(N):
-            i = x * N + y - skipped
-            if i >= size:
-                break
-            # Reverse y every odd column
-            if x % 2 == 0 and N < 4:
-                y = N - y - 1
-            if skip > 0:
-                if (x < skip or x + 1 > N - skip) and \
-                        (y < skip or y + 1 > N - skip):
-                    skipped += 1
-                    continue
-            # Exception for 3-ary alphabet, skip centre point
-            if size == 8 and x == 1 and y == 1:
-                skipped += 1
+from os import path
+import tensorflow as tf
+
+ALPHABET_DIR = "./alphabets"
+
+
+def load_alphabet(name, polar=True):
+    apath = path.join(ALPHABET_DIR, name + '.a')
+    if not path.exists(apath):
+        raise ValueError(f"Alphabet '{name}' does not exist")
+    data = []
+    indexes = []
+    with open(apath, 'r') as f:
+        header = f.readline().lower()
+        if 'd' not in header and 'r' not in header:
+            raise ValueError(f"Alphabet {name} header does not specify valid format")
+        for i, row in enumerate(f.readlines()):
+            row = row.strip()
+            if len(row) == 0:
                 continue
-            alphabet[i, :] = [step * x - 1, step * y - 1]
-    if gray:
-        shape = alphabet.shape
-        d1 = 4 if N > 4 else 2 ** N // 4
-        g1 = np.array([0, 1, 3, 2])
-        g2 = g1[:d1]
-        hypershape = (d1, 4, 2)
-        if N > 4:
-            hypercube = alphabet.reshape(hypershape + (N-4, ))
-            hypercube = hypercube[:, g1, :, :][g2, :, :, :]
-        else:
-            hypercube = alphabet.reshape(hypershape)
-            hypercube = hypercube[:, g1, :][g2, :, :]
-        alphabet = hypercube.reshape(shape)
+            cols = row.split(',')
+            try:
+                if len(cols) == 3:
+                    base = 2
+                    if 'i' in header:
+                        base = 10
+                    indexes.append(int(cols[0], base))
+                    x = float(cols[1])
+                    y = float(cols[2])
+                elif len(cols) == 2:
+                    indexes.append(i)
+                    x = float(cols[0])
+                    y = float(cols[1])
+                else:
+                    raise ValueError()
+                if 'd' in header:
+                    p = y * math.pi / 180
+                    y = math.sin(p) * x
+                    x = math.cos(p) * x
+                data.append((x, y))
+            except ValueError:
+                raise ValueError(f"Alphabet {name} line {i + 1}: '{row}' has invalid values")
+
+    data2 = [None] * len(data)
+    for i, d in enumerate(data):
+        data2[indexes[i]] = d
+    arr = np.array(data2, dtype=float)
     if polar:
-        alphabet = misc.rect2polar(alphabet)
-    return alphabet
+        arr = misc.rect2polar(arr)
+    return arr
+
+
+class RFSignal(defs.Signal):
+    def __init__(self, array: np.ndarray):
+        self.amplitude = array[:, 0]
+        self.phase = array[:, 1]
+        self.frequency = array[:, 2]
+        self.symbols = array.shape[0]
+
+    @property
+    def rect(self) -> np.ndarray:
+        return misc.polar2rect(np.c_[self.amplitude, self.phase])
+
+    def set_rect_xy(self, x_mat: np.ndarray, y_mat: np.ndarray):
+        self.set_rect(np.c_[x_mat, y_mat])
+
+    def set_rect(self, mat: np.ndarray):
+        polar = misc.rect2polar(mat)
+        self.amplitude = polar[:, 0]
+        self.phase = polar[:, 1]
+
+    @property
+    def apf(self):
+        return np.c_[self.amplitude, self.phase, self.frequency]
 
 
 class BypassChannel(defs.Channel):
@@ -83,12 +92,17 @@ class AWGNChannel(defs.Channel):
         super().__init__(**kwargs)
         self.noise = 10 ** (noise_level / 10)
 
-    def forward(self, values):
-        a = np.random.normal(0, 1, values.shape[0]) * self.noise
-        p = np.random.normal(0, 1, values.shape[0]) * self.noise
-        f = np.zeros(values.shape[0])
-        noise_mat = np.c_[a, p, f]
-        return values + noise_mat
+    def forward(self, values: RFSignal) -> RFSignal:
+        values.set_rect_xy(
+            values.rect_x + np.random.normal(0, 1, values.symbols) * self.noise,
+            values.rect_y + np.random.normal(0, 1, values.symbols) * self.noise,
+        )
+        return values
+
+    def forward_tensor(self, tensor: tf.Tensor) -> tf.Tensor:
+        noise = tf.random.normal([2], mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None)
+        tensor += noise * self.noise
+        return tensor
 
 
 class BPSKMod(defs.Modulator):
@@ -97,12 +111,12 @@ class BPSKMod(defs.Modulator):
         super().__init__(2, **kwargs)
         self.f = carrier_f
 
-    def forward(self, binary: np.ndarray):
+    def forward(self, binary):
         a = np.ones(binary.shape[0])
         p = np.zeros(binary.shape[0])
         p[binary == True] = np.pi
         f = np.zeros(binary.shape[0]) + self.f
-        return np.c_[a, p, f]
+        return RFSignal(np.c_[a, p, f])
 
 
 class BPSKDemod(defs.Demodulator):
@@ -119,22 +133,22 @@ class BPSKDemod(defs.Demodulator):
     def forward(self, values):
         # TODO: Channel noise simulator for frequency component?
         # for now we only care about amplitude and phase
-        ap = np.delete(values, 2, 1)
-        ap = misc.polar2rect(ap)
+        # ap = np.delete(values, 2, 1)
+        # ap = misc.polar2rect(ap)
 
-        result = np.ones(values.shape[0], dtype=bool)
-        result[ap[:, 0] > 0] = False
+        result = np.ones(values.symbols, dtype=bool)
+        result[values.rect_x[:, 0] > 0] = False
         return result
 
 
-class MaryMod(defs.Modulator):
+class AlphabetMod(defs.Modulator):
 
-    def __init__(self, N, carrier_f, gray=True):
-        if N < 2:
-            raise ValueError("M-ary modulator N value has to be larger than 1")
-        super().__init__(2 ** N)
+    def __init__(self, modulation, carrier_f):
+        # if N < 2:
+        #     raise ValueError("M-ary modulator N value has to be larger than 1")
+        self.alphabet = load_alphabet(modulation)
+        super().__init__(self.alphabet.shape[0])
         self.f = carrier_f
-        self.alphabet = _gen_mary_alphabet(self.alphabet_size, gray)
         self.mult_mat = np.array([2 ** i for i in range(self.N)])
 
     def forward(self, binary):
@@ -148,26 +162,26 @@ class MaryMod(defs.Modulator):
         a = values[:, 0]
         p = values[:, 1]
         f = np.zeros(reshaped.shape[0]) + self.f
-        return np.c_[a, p, f]  #, indices
+        return RFSignal(np.c_[a, p, f])  # , indices
 
 
-class MaryDemod(defs.Demodulator):
+class AlphabetDemod(defs.Demodulator):
 
-    def __init__(self, N, carrier_f, gray=True):
-        if N < 2:
-            raise ValueError("M-ary modulator N value has to be larger than 1")
-        super().__init__(2 ** N)
+    def __init__(self, modulation, carrier_f):
+        # if N < 2:
+        #     raise ValueError("M-ary modulator N value has to be larger than 1")
+        self.alphabet = load_alphabet(modulation, polar=False)
+        super().__init__(self.alphabet.shape[0])
         self.f = carrier_f
-        self.N = N
-        self.alphabet = _gen_mary_alphabet(self.alphabet_size, gray=gray, polar=False)
+        # self.alphabet = _gen_mary_alphabet(self.alphabet_size, gray=gray, polar=False)
         self.ktree = cKDTree(self.alphabet)
 
     def forward(self, binary):
-        binary = binary[:, :2]  # ignore frequency
-        rbin = misc.polar2rect(binary)
-        indices = self.ktree.query(rbin)[1]
+        # binary = binary[:, :2]  # ignore frequency
+        # rbin = misc.polar2rect(binary)
+        indices = self.ktree.query(binary.rect)[1]
 
         # Converting indices to bite array
         # FIXME: unpackbits requires 8bit inputs, thus largest demodulation is 256-QAM
         values = np.unpackbits(np.array([indices], dtype=np.uint8).T, bitorder='little', axis=1)
-        return values[:, :self.N].reshape((-1,)).astype(bool)  #, indices
+        return values[:, :self.N].reshape((-1,)).astype(bool)  # , indices

+ 369 - 0
models/end_to_end.py

@@ -0,0 +1,369 @@
+import math
+
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.preprocessing import OneHotEncoder
+from tensorflow.keras import layers, losses
+
+
+class ExtractCentralMessage(layers.Layer):
+    def __init__(self, messages_per_block, samples_per_symbol):
+        """
+        A keras layer that extracts the central message(symbol) in a block.
+
+        :param messages_per_block: Total number of messages in transmission block
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        """
+        super(ExtractCentralMessage, self).__init__()
+
+        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
+        i = np.identity(samples_per_symbol)
+        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
+        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
+        temp_w[begin:end, :] = i
+
+        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
+
+    def call(self, inputs, **kwargs):
+        return tf.matmul(inputs, self.w)
+
+
+class AwgnChannel(layers.Layer):
+    def __init__(self, rx_stddev=0.1):
+        """
+        A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
+        being applied every time the call function is called.
+
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        """
+        super(AwgnChannel, self).__init__()
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+
+    def call(self, inputs, **kwargs):
+        return self.noise_layer.call(inputs, training=True)
+
+
+class DigitizationLayer(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 lpf_cutoff=32e9,
+                 q_stddev=0.1):
+        """
+        This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
+        artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
+        super(DigitizationLayer, self).__init__()
+
+        self.noise_layer = layers.GaussianNoise(q_stddev)
+        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
+        temp = np.ones(freq.shape)
+
+        for idx, val in np.ndenumerate(freq):
+            if np.abs(val) > lpf_cutoff:
+                temp[idx] = 0
+
+        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
+
+    def call(self, inputs, **kwargs):
+        complex_in = tf.cast(inputs, dtype=tf.complex64)
+        val_f = tf.signal.fft(complex_in)
+        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
+        filtered_t = tf.signal.ifft(filtered_f)
+        real_t = tf.cast(filtered_t, dtype=tf.float32)
+        noisy = self.noise_layer.call(real_t, training=True)
+        return noisy
+
+
+class OpticalChannel(layers.Layer):
+    def __init__(self,
+                 fs,
+                 num_of_samples,
+                 dispersion_factor,
+                 fiber_length,
+                 lpf_cutoff=32e9,
+                 rx_stddev=0.01,
+                 q_stddev=0.01):
+        """
+        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
+        ADC/DAC as well as additive white gaussian noise in optical communication channels.
+
+        :param fs: Sampling frequency of the simulation in Hz
+        :param num_of_samples: Total number of samples in the input
+        :param dispersion_factor: Dispersion factor in s^2/km
+        :param fiber_length: Length of fiber to model in km
+        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
+        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
+        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
+        """
+        super(OpticalChannel, self).__init__()
+
+        self.noise_layer = layers.GaussianNoise(rx_stddev)
+        self.digitization_layer = DigitizationLayer(fs=fs,
+                                                    num_of_samples=num_of_samples,
+                                                    lpf_cutoff=lpf_cutoff,
+                                                    q_stddev=q_stddev)
+        self.flatten_layer = layers.Flatten()
+
+        self.fs = fs
+        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
+        self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
+
+    def call(self, inputs, **kwargs):
+        # DAC LPF and noise
+        dac_out = self.digitization_layer(inputs)
+
+        # Chromatic Dispersion
+        complex_val = tf.cast(dac_out, dtype=tf.complex128)
+        val_f = tf.signal.fft(complex_val)
+        disp_f = tf.math.multiply(val_f, self.multiplier)
+        disp_t = tf.signal.ifft(disp_f)
+
+        # Squared-Law Detection
+        pd_out = tf.square(tf.abs(disp_t))
+
+        # Casting back to floatx
+        real_val = tf.cast(pd_out, dtype=tf.float32)
+
+        # Adding photo-diode receiver noise
+        rx_signal = self.noise_layer.call(real_val, training=True)
+
+        # ADC LPF and noise
+        adc_out = self.digitization_layer(rx_signal)
+
+        return adc_out
+
+
+class EndToEndAutoencoder(tf.keras.Model):
+    def __init__(self,
+                 cardinality,
+                 samples_per_symbol,
+                 messages_per_block,
+                 channel):
+        """
+        The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
+        of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
+        interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
+
+        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
+        :param samples_per_symbol: Number of samples per transmitted symbol
+        :param messages_per_block: Total number of messages in transmission block
+        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
+        """
+        super(EndToEndAutoencoder, self).__init__()
+
+        # Labelled M in paper
+        self.cardinality = cardinality
+        # Labelled n in paper
+        self.samples_per_symbol = samples_per_symbol
+        # Labelled N in paper
+        if messages_per_block % 2 == 0:
+            messages_per_block += 1
+        self.messages_per_block = messages_per_block
+        # Channel Model Layer
+        if isinstance(channel, layers.Layer):
+            self.channel = tf.keras.Sequential([
+                layers.Flatten(),
+                channel,
+                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
+            ])
+        else:
+            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+
+        # Encoding Neural Network
+        self.encoder = tf.keras.Sequential([
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(self.samples_per_symbol),
+            layers.ReLU(max_value=1.0)
+        ])
+
+        # Decoding Neural Network
+        self.decoder = tf.keras.Sequential([
+            layers.Dense(self.samples_per_symbol, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(2 * self.cardinality, activation='relu'),
+            layers.Dense(self.cardinality, activation='softmax')
+        ])
+
+    def generate_random_inputs(self, num_of_blocks, return_vals=False):
+        """
+        A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
+
+        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
+        consecutively to model ISI. The central message in a block is returned as the label for training.
+        :param return_vals: If true, the raw decimal values of the input sequence will be returned
+        """
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
+
+        cat = [np.arange(self.cardinality)]
+        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+
+        out = enc.fit_transform(rand_int)
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+
+        mid_idx = int((self.messages_per_block-1)/2)
+
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]
+
+        return out_arr, out_arr[:, mid_idx, :]
+
+    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
+        """
+        Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
+
+        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
+        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
+        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
+        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
+        """
+        X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
+
+        opt = tf.keras.optimizers.Adam(learning_rate=lr)
+
+        self.compile(optimizer=opt,
+                     loss=losses.BinaryCrossentropy(),
+                     metrics=['accuracy'],
+                     loss_weights=None,
+                     weighted_metrics=None,
+                     run_eagerly=False
+                     )
+
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=1,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )
+
+    def view_encoder(self):
+        '''
+        A method that views the learnt encoder for each distint message. This is displayed as a plot with  asubplot for
+        each image.
+        '''
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+        mid_idx = int((self.messages_per_block-1)/2)
+
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1
+
+        # Pass input through encoder and select middle messages
+        encoded = self.encoder(messages)
+        enc_messages = encoded[:, mid_idx, :]
+
+        # Compute subplot grid layout
+        i = 0
+        while 2**i < self.cardinality**0.5:
+            i += 1
+
+        num_x = int(2**i)
+        num_y = int(self.cardinality / num_x)
+
+        # Plot all symbols
+        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
+
+        t = np.arange(self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t/self.channel.layers[1].fs
+
+        sym_idx = 0
+        for y in range(num_y):
+            for x in range(num_x):
+                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
+                sym_idx += 1
+
+        for ax in axs.flat:
+            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
+
+        for ax in axs.flat:
+            ax.label_outer()
+
+        plt.show()
+        pass
+
+    def view_sample_block(self):
+        '''
+        Generates a random string of input message and encodes them. In addition to this, the output is passed through
+        digitization layer without any quantization noise for the low pass filtering.
+        '''
+        # Generate a random block of messages
+        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
+
+        # Encode and flatten the messages
+        enc = self.encoder(inp)
+        flat_enc = layers.Flatten()(enc)
+
+        # Instantiate LPF layer
+        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
+                                num_of_samples=self.messages_per_block*self.samples_per_symbol,
+                                q_stddev=0)
+
+        # Apply LPF
+        lpf_out = lpf(flat_enc)
+
+        # Time axis
+        t = np.arange(self.messages_per_block*self.samples_per_symbol)
+        if isinstance(self.channel.layers[1], OpticalChannel):
+            t = t / self.channel.layers[1].fs
+
+        # Plot the concatenated symbols before and after LPF
+        plt.figure(figsize=(2*self.messages_per_block, 6))
+
+        for i in range(1, self.messages_per_block):
+            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
+
+        plt.plot(t, flat_enc.numpy().T, 'x')
+        plt.plot(t, lpf_out.numpy().T)
+        plt.ylim((0, 1))
+        plt.xlim((t.min(), t.max()))
+        plt.title(str(val[0, :, 0]))
+        plt.show()
+        pass
+
+    def call(self, inputs, training=None, mask=None):
+        tx = self.encoder(inputs)
+        rx = self.channel(tx)
+        outputs = self.decoder(rx)
+        return outputs
+
+
+if __name__ == '__main__':
+
+    SAMPLING_FREQUENCY = 336e9
+    CARDINALITY = 32
+    SAMPLES_PER_SYMBOL = 24
+    MESSAGES_PER_BLOCK = 9
+    DISPERSION_FACTOR = -21.7 * 1e-24
+    FIBER_LENGTH = 50
+
+    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
+                                     dispersion_factor=DISPERSION_FACTOR,
+                                     fiber_length=FIBER_LENGTH)
+
+    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                   messages_per_block=MESSAGES_PER_BLOCK,
+                                   channel=optical_channel)
+
+    ae_model.train(num_of_blocks=1e6, batch_size=100)
+    ae_model.view_encoder()
+    ae_model.view_sample_block()
+
+    pass

+ 177 - 0
models/optical_channel.py

@@ -0,0 +1,177 @@
+import matplotlib.pyplot as plt
+
+import defs
+import numpy as np
+import math
+from numpy.fft import fft, fftfreq, ifft
+from commpy.filters import rrcosfilter, rcosfilter, rectfilter
+
+from models import basic
+
+
+class OpticalChannel(defs.Channel):
+    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
+                 sqrt_out=False, show_graphs=False, **kwargs):
+        """
+        :param noise_level: Noise level in dB
+        :param dispersion: dispersion coefficient is ps^2/km
+        :param symbol_rate: Symbol rate of modulated signal in Hz
+        :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
+        :param length: fibre length in km
+        :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
+        :param sqrt_out: Take the root of the out to compensate for photodiode detection
+        :param show_graphs: if graphs should be displayed or not
+
+        Optical Channel class constructor
+        """
+        super().__init__(**kwargs)
+        self.noise = 10 ** (noise_level / 10)
+
+        self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
+        self.symbol_rate = symbol_rate
+        self.symbol_period = 1 / self.symbol_rate
+        self.sample_rate = sample_rate
+        self.sample_period = 1 / self.sample_rate
+        self.length = length
+        self.pulse_shape = pulse_shape.strip().lower()
+        self.sqrt_out = sqrt_out
+        self.show_graphs = show_graphs
+
+    def __get_time_domain(self, symbol_vals):
+        samples_per_symbol = int(self.sample_rate / self.symbol_rate)
+        samples = int(symbol_vals.shape[0] * samples_per_symbol)
+
+        symbol_impulse = np.zeros(samples)
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        for i in range(symbol_vals.shape[0]):
+            symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
+
+        if self.pulse_shape == 'rrcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        elif self.pulse_shape == 'rcos':
+            self.filter_samples = 5 * samples_per_symbol
+            self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
+        else:
+            self.filter_samples = samples_per_symbol
+            self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
+
+        val_t = np.convolve(symbol_impulse, self.h_filter)
+        t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])
+
+        return t, val_t
+
+    def __time_to_frequency(self, values):
+        val_f = fft(values)
+        f = fftfreq(values.shape[-1])*self.sample_rate
+        return f, val_f
+
+    def __frequency_to_time(self, values):
+        val_t = ifft(values)
+        t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
+        return t, val_t
+
+    def __apply_dispersion(self, values):
+        # Obtain fft
+        f, val_f = self.__time_to_frequency(values)
+
+        if self.show_graphs:
+            plt.plot(f, val_f)
+            plt.title('frequency domain (pre-distortion)')
+            plt.show()
+
+        # Apply distortion
+        dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
+
+        if self.show_graphs:
+            plt.plot(f, dist_val_f)
+            plt.title('frequency domain (post-distortion)')
+            plt.show()
+
+        # Inverse fft
+        t, val_t = self.__frequency_to_time(dist_val_f)
+
+        return t, val_t
+
+    def __photodiode_detection(self, values):
+        t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
+        val_t = np.power(np.absolute(values), 2)
+        return t, val_t
+
+    def forward(self, values):
+        if hasattr(values, 'apf'):
+            values = values.apf
+        # Converting APF representation to time-series
+        t, val_t = self.__get_time_domain(values)
+
+        if self.show_graphs:
+            plt.plot(t, val_t)
+            plt.title('time domain (raw)')
+            plt.show()
+
+        # Adding AWGN
+        val_t += np.random.normal(0, 1, val_t.shape) * self.noise
+
+        if self.show_graphs:
+            plt.plot(t, val_t)
+            plt.title('time domain (AWGN)')
+            plt.show()
+
+        # Applying chromatic dispersion
+        t, val_t = self.__apply_dispersion(val_t)
+
+        if self.show_graphs:
+            plt.plot(t, val_t)
+            plt.title('time domain (post-distortion)')
+            plt.show()
+
+        # Photodiode Detection
+        t, val_t = self.__photodiode_detection(val_t)
+
+        # Symbol Decisions
+        idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
+                        self.symbol_period/self.sample_period, dtype='int64')
+        t_descision = self.sample_period * idx
+
+        if self.show_graphs:
+            plt.plot(t, val_t)
+            plt.title('time domain (post-detection)')
+            plt.show()
+
+            plt.plot(t, val_t)
+            for xc in t_descision:
+                plt.axvline(x=xc, color='r')
+            plt.title('time domain (post-detection with decision times)')
+            plt.show()
+
+        # TODO: Implement Frequency/Phase Modulation
+
+        out = np.zeros(values.shape)
+
+        out[:, 0] = val_t[idx]
+        out[:, 1] = values[:, 1]
+        out[:, 2] = values[:, 2]
+
+        if self.sqrt_out:
+            out[:, 0] = np.sqrt(out[:, 0])
+
+        return basic.RFSignal(out)
+
+
+if __name__ == '__main__':
+    # Simple OOK modulation
+    num_of_symbols = 100
+    symbol_vals = np.zeros((num_of_symbols, 3))
+
+    symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
+    symbol_vals[:, 2] = 40e9
+
+    channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
+                             sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
+    v = channel.forward(symbol_vals)
+
+    rx = (v > 0.5).astype(int)
+    tru = np.sum(rx == symbol_vals[:, 0].astype(int))
+    print("Accuracy: {}".format(tru/num_of_symbols))

+ 32 - 0
photonics.py

@@ -0,0 +1,32 @@
+from defs import Channel
+import numpy as np
+
+
+class PhotonicCoder:
+
+    def encode(self, data: np.ndarray) -> np.ndarray:
+        """
+        """
+        raise NotImplemented("encode function not defined")
+
+    def decode(self, data: np.ndarray) -> np.ndarray:
+        """
+        """
+        raise NotImplemented("decode function not defined")
+
+
+class PhotonicSimulation:
+    """
+    This is a main class that will contain coder/channel and all
+    necessary useful methods to run/monitor simulation
+    """
+
+    def __init__(self, coder: PhotonicCoder, channel: Channel):
+        self.coder = coder
+        self.channel = channel
+
+    def run(self, data: np.ndarray) -> np.ndarray:
+        encoded = self.coder.encode(data)
+        transmitted = self.channel.forward(encoded)
+        decoded = self.coder.decode(transmitted)
+        return decoded

+ 15 - 0
tests/misc_test.py

@@ -0,0 +1,15 @@
+import misc
+import numpy as np
+
+
+def test_bit_matrix_one_hot():
+    for n in range(2, 8):
+        x0 = misc.generate_random_bit_array(100 * n)
+        x1 = misc.bit_matrix2one_hot(x0.reshape((-1, n)))
+        x2 = misc.one_hot2bit_matrix(x1).reshape((-1,))
+        assert np.array_equal(x0, x2)
+
+
+if __name__ == "__main__":
+    test_bit_matrix_one_hot()
+    print("Everything passed")