Ver código fonte

Function for SNR

Oliver Jaison 5 anos atrás
pai
commit
9955b46ead
1 arquivos alterados com 15 adições e 4 exclusões
  1. 15 4
      main.py

+ 15 - 4
main.py

@@ -5,6 +5,7 @@ from sklearn.metrics import accuracy_score
 from models import basic
 from models import basic
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
 import misc
 import misc
+import math
 from models.autoencoder import Autoencoder, view_encoder
 from models.autoencoder import Autoencoder, view_encoder
 
 
 
 
@@ -45,6 +46,16 @@ def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
     return ber_x, ber_y
     return ber_x, ber_y
 
 
 
 
+def get_SNR(mod, demod, samples=1000, start=-8, stop=5, steps=30):
+    ber_x, ber_y = get_AWGN_ber(mod, demod, samples, start, stop, steps)
+    x_mod = mod.forward(misc.generate_random_bit_array(samples*mod.N))
+    sig_amp = x_mod[:, 0]
+    sig_power = [A ** 2 for A in sig_amp]
+    av_sig_pow = np.mean(sig_power)
+    av_sig_pow = math.log(av_sig_pow, 10)
+    SNR = (ber_x * -1) + av_sig_pow
+    return SNR, ber_y
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
     # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
 
 
@@ -111,18 +122,18 @@ if __name__ == '__main__':
     #          label='AE 4bit -8dB')
     #          label='AE 4bit -8dB')
 
 
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
     for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
-        plt.plot(*get_AWGN_ber(
+        plt.plot(*get_SNR(
             AlphabetMod(scheme, 10e6),
             AlphabetMod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
             AlphabetDemod(scheme, 10e6),
-            samples=20e3,
+            samples=100e3,
             steps=40,
             steps=40,
             start=-15
             start=-15
         ), '-', label=scheme.upper())
         ), '-', label=scheme.upper())
 
 
     plt.yscale('log')
     plt.yscale('log')
-    plt.gca().invert_xaxis()
+    # plt.gca().invert_xaxis()
     plt.grid()
     plt.grid()
-    plt.xlabel('Noise dB')
+    plt.xlabel('SNR dB')
     plt.ylabel('BER')
     plt.ylabel('BER')
     plt.legend()
     plt.legend()
     plt.show()
     plt.show()