Просмотр исходного кода

Merge remote-tracking branch 'origin/swarm_optimisation'

Min 5 лет назад
Родитель
Сommit
b7fdda8b86
1 измененных файлов с 15 добавлено и 4 удалено
  1. 15 4
      main.py

+ 15 - 4
main.py

@@ -5,6 +5,7 @@ from sklearn.metrics import accuracy_score
 from models import basic
 from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
 import misc
+import math
 from models.autoencoder import Autoencoder, view_encoder
 from models.autoencoder import Autoencoder, view_encoder
 
 
 
 
@@ -45,6 +46,16 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
     return ber_x, ber_y
 
 
 
 
+def get_SNR(mod, demod, samples=1000, start=-8, stop=5, steps=30):
+    ber_x, ber_y = get_AWGN_ber(mod, demod, samples, start, stop, steps)
+    x_mod = mod.forward(misc.generate_random_bit_array(samples*mod.N))
+    sig_amp = x_mod[:, 0]
+    sig_power = [A ** 2 for A in sig_amp]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+    SNR = (ber_x * -1) + av_sig_pow
+    return SNR, ber_y
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
 
@@ -111,18 +122,18 @@ if __name__ == '__main__':
     #          label='AE 4bit -8dB')
     #          label='AE 4bit -8dB')
 
 
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
+        plt.plot(*get_SNR(
             AlphabetMod(scheme, 10e6),
             AlphabetMod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+            samples=100e3,
             steps=40,
             steps=40,
             start=-15
             start=-15
         ), '-', label=scheme.upper())
         ), '-', label=scheme.upper())
 
 
     plt.yscale('log')
     plt.yscale('log')
-    plt.gca().invert_xaxis()
+    # plt.gca().invert_xaxis()
     plt.grid()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
     plt.ylabel('BER')
     plt.legend()
     plt.legend()
     plt.show()
     plt.show()