瀏覽代碼

Merge remote-tracking branch 'origin/swarm_optimisation'

Min 5 年之前
父節點
當前提交
b7fdda8b86
共有 1 個文件被更改,包括 15 次插入4 次删除
  1. 15 4
      main.py

+ 15 - 4
main.py

@@ -5,6 +5,7 @@ from sklearn.metrics import accuracy_score
 from models import basic
 from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
 import misc
+import math
 from models.autoencoder import Autoencoder, view_encoder
 from models.autoencoder import Autoencoder, view_encoder
 
 
 
 
@@ -45,6 +46,16 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
     return ber_x, ber_y
 
 
 
 
+def get_SNR(mod, demod, samples=1000, start=-8, stop=5, steps=30):
+    ber_x, ber_y = get_AWGN_ber(mod, demod, samples, start, stop, steps)
+    x_mod = mod.forward(misc.generate_random_bit_array(samples*mod.N))
+    sig_amp = x_mod[:, 0]
+    sig_power = [A ** 2 for A in sig_amp]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+    SNR = (ber_x * -1) + av_sig_pow
+    return SNR, ber_y
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
 
@@ -111,18 +122,18 @@ if __name__ == '__main__':
     #          label='AE 4bit -8dB')
     #          label='AE 4bit -8dB')
 
 
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
+        plt.plot(*get_SNR(
             AlphabetMod(scheme, 10e6),
             AlphabetMod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+            samples=100e3,
             steps=40,
             steps=40,
             start=-15
             start=-15
         ), '-', label=scheme.upper())
         ), '-', label=scheme.upper())
 
 
     plt.yscale('log')
     plt.yscale('log')
-    plt.gca().invert_xaxis()
+    # plt.gca().invert_xaxis()
     plt.grid()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
     plt.ylabel('BER')
     plt.legend()
     plt.legend()
     plt.show()
     plt.show()