Pārlūkot izejas kodu

Multithreaded SNR calc + SNR for optics

Also testing optic channel
Min 5 gadi atpakaļ
vecāks
revīzija
5d6c327cdf
1 mainītis faili ar 99 papildinājumiem un 11 dzēšanām
  1. 99 11
      main.py

+ 99 - 11
main.py

@@ -5,7 +5,13 @@ from sklearn.metrics import accuracy_score
 from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
+import math
+import os
 from models.autoencoder import Autoencoder, view_encoder
+from models.optical_channel import OpticalChannel
+from multiprocessing import Pool
+
+CPU_COUNT = os.environ.get("CPU_COUNT", os.cpu_count())
 
 
 def show_constellation(mod, chan, demod, samples=1000):
@@ -37,7 +43,7 @@ def get_ber(mod, chan, demod, samples=1000):
     return 1 - accuracy_score(x, x_demod)
 
 
-def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
+def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
     ber_x = np.linspace(start, stop, steps)
     ber_y = []
     for noise in ber_x:
@@ -45,6 +51,47 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
 
 
+def __calc_ber(packed):
+    # This function has to be outside get_Optical_ber in order to be pickled by pool
+    mod, demod, noise, length, pulse_shape, samples = packed
+    tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                length=length, pulse_shape=pulse_shape, sqrt_out=True)
+    return get_ber(mod, tx_channel, demod, samples=samples)
+
+
+def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
+    with Pool(CPU_COUNT) as pool:
+        packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
+        for i, ber in enumerate(pool.imap(__calc_ber, packed_args)):
+            ber_y.append(ber)
+            i += 1  # just offset by 1
+            print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i*100/len(ber_x):6.2f}%)", end='')
+    print()
+    return ber_x, ber_y
+
+
+def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
+    """
+    SNR for optics and RF should be calculated the same, that is A^2
+    Because P∝V² and P∝I²
+    """
+    x_mod = mod.forward(misc.generate_random_bit_array(samples * mod.N))
+    sig_amp = x_mod[:, 0]
+    sig_power = [A ** 2 for A in sig_amp]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+
+    noise_start = -start + av_sig_pow
+    noise_stop = -stop + av_sig_pow
+    ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
+    SNR = -ber_x + av_sig_pow
+    return SNR, ber_y
+
+
+
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
@@ -110,21 +157,62 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    for l in np.logspace(start=0, stop=3, num=5):
+        plt.plot(*get_SNR(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=2000,
+            steps=200,
+            start=-5,
+            stop=20,
+            length=l,
+            pulse_shape='rcos'
+        ), '-', label=(str(int(l))+'km'))
+
+    plt.yscale('log')
+    # plt.gca().invert_xaxis()
+    plt.grid()
+    plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    plt.title("BER against Fiber length")
+    plt.legend()
+    plt.show()
+    # FIXME: Exit for now
+    exit()
+
+    for ps in ['rect']: #, 'rcos', 'rrcos']:
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=30000,
             steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+            start=-35,
+            # stop=10,
+            length=1,
+            pulse_shape=ps
+        ), '-', label=ps)
 
     plt.yscale('log')
-    plt.gca().invert_xaxis()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.show()
 
-    pass
+    pass