ソースを参照

Multithreaded SNR calc + SNR for optics

Also testing optic channel
Min 5 年 前
コミット
5d6c327cdf
1 ファイル変更99 行追加11 行削除
  1. 99 11
      main.py

+ 99 - 11
main.py

@@ -5,7 +5,13 @@ from sklearn.metrics import accuracy_score
 from models import basic
 from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
 import misc
+import math
+import os
 from models.autoencoder import Autoencoder, view_encoder
 from models.autoencoder import Autoencoder, view_encoder
+from models.optical_channel import OpticalChannel
+from multiprocessing import Pool
+
+CPU_COUNT = os.environ.get("CPU_COUNT", os.cpu_count())
 
 
 
 
 def show_constellation(mod, chan, demod, samples=1000):
 def show_constellation(mod, chan, demod, samples=1000):
@@ -37,7 +43,7 @@ def get_ber(mod, chan, demod, samples=1000):
     return 1 - accuracy_score(x, x_demod)
     return 1 - accuracy_score(x, x_demod)
 
 
 
 
-def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
+def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
     ber_x = np.linspace(start, stop, steps)
     ber_x = np.linspace(start, stop, steps)
     ber_y = []
     ber_y = []
     for noise in ber_x:
     for noise in ber_x:
@@ -45,6 +51,47 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
     return ber_x, ber_y
 
 
 
 
+def __calc_ber(packed):
+    # This function has to be outside get_Optical_ber in order to be pickled by pool
+    mod, demod, noise, length, pulse_shape, samples = packed
+    tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
+                                length=length, pulse_shape=pulse_shape, sqrt_out=True)
+    return get_ber(mod, tx_channel, demod, samples=samples)
+
+
+def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
+    ber_x = np.linspace(start, stop, steps)
+    ber_y = []
+    print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
+    with Pool(CPU_COUNT) as pool:
+        packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
+        for i, ber in enumerate(pool.imap(__calc_ber, packed_args)):
+            ber_y.append(ber)
+            i += 1  # just offset by 1
+            print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i*100/len(ber_x):6.2f}%)", end='')
+    print()
+    return ber_x, ber_y
+
+
+def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
+    """
+    SNR for optics and RF should be calculated the same, that is A^2
+    Because P∝V² and P∝I²
+    """
+    x_mod = mod.forward(misc.generate_random_bit_array(samples * mod.N))
+    sig_amp = x_mod[:, 0]
+    sig_power = [A ** 2 for A in sig_amp]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+
+    noise_start = -start + av_sig_pow
+    noise_stop = -stop + av_sig_pow
+    ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
+    SNR = -ber_x + av_sig_pow
+    return SNR, ber_y
+
+
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
 
@@ -110,21 +157,62 @@ if __name__ == '__main__':
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
     #          label='AE 4bit -8dB')
     #          label='AE 4bit -8dB')
 
 
-    for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
-            AlphabetMod(scheme, 10e6),
-            AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+    # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
+    #     plt.plot(*get_SNR(
+    #         AlphabetMod(scheme, 10e6),
+    #         AlphabetDemod(scheme, 10e6),
+    #         samples=100e3,
+    #         steps=40,
+    #         start=-15
+    #     ), '-', label=scheme.upper())
+    # plt.yscale('log')
+    # plt.grid()
+    # plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    # plt.legend()
+    # plt.show()
+
+    for l in np.logspace(start=0, stop=3, num=5):
+        plt.plot(*get_SNR(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=2000,
+            steps=200,
+            start=-5,
+            stop=20,
+            length=l,
+            pulse_shape='rcos'
+        ), '-', label=(str(int(l))+'km'))
+
+    plt.yscale('log')
+    # plt.gca().invert_xaxis()
+    plt.grid()
+    plt.xlabel('SNR dB')
+    # plt.ylabel('BER')
+    plt.title("BER against Fiber length")
+    plt.legend()
+    plt.show()
+    # FIXME: Exit for now
+    exit()
+
+    for ps in ['rect']: #, 'rcos', 'rrcos']:
+        plt.plot(*get_Optical_ber(
+            AlphabetMod('4pam', 10e6),
+            AlphabetDemod('4pam', 10e6),
+            samples=30000,
             steps=40,
             steps=40,
-            start=-15
-        ), '-', label=scheme.upper())
+            start=-35,
+            # stop=10,
+            length=1,
+            pulse_shape=ps
+        ), '-', label=ps)
 
 
     plt.yscale('log')
     plt.yscale('log')
-    plt.gca().invert_xaxis()
     plt.grid()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
     plt.ylabel('BER')
+    plt.title("BER for different pulse shapes")
     plt.legend()
     plt.legend()
     plt.show()
     plt.show()
 
 
-    pass
+    pass