main.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.metrics import accuracy_score
  4. from models import basic
  5. from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
  6. import misc
  7. def show_constellation(mod, chan, demod, samples=1000):
  8. x = misc.generate_random_bit_array(samples)
  9. x_mod = mod.forward(x)
  10. x_chan = chan.forward(x_mod)
  11. x_demod = demod.forward(x_chan)
  12. x_mod_rect = misc.polar2rect(x_mod)
  13. x_chan_rect = misc.polar2rect(x_chan)
  14. plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
  15. plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
  16. plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
  17. axes = plt.gca()
  18. axes.set_xlim([-2, +2])
  19. axes.set_ylim([-2, +2])
  20. plt.grid()
  21. plt.show()
  22. print('Accuracy : ' + str())
  23. def get_ber(mod, chan, demod, samples=1000):
  24. x = misc.generate_random_bit_array(samples)
  25. x_mod = mod.forward(x)
  26. x_chan = chan.forward(x_mod)
  27. x_demod = demod.forward(x_chan)
  28. return 1 - accuracy_score(x, x_demod)
  29. def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
  30. ber_x = np.linspace(start, stop, steps)
  31. ber_y = []
  32. for noise in ber_x:
  33. ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
  34. return ber_x, ber_y
  35. if __name__ == '__main__':
  36. # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
  37. # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
  38. # mod = MaryMod('8psk', 10e6)
  39. # misc.display_alphabet(mod.alphabet, a_vals=True)
  40. # mod = MaryMod('qpsk', 10e6)
  41. # misc.display_alphabet(mod.alphabet, a_vals=True)
  42. # mod = MaryMod('16qam', 10e6)
  43. # misc.display_alphabet(mod.alphabet, a_vals=True)
  44. # mod = MaryMod('64qam', 10e6)
  45. # misc.display_alphabet(mod.alphabet, a_vals=True)
  46. plt.plot(*get_AWGN_ber(AlphabetMod('64qam', 10e6), AlphabetDemod('64qam', 10e6), samples=12000, start=-15), '-', label='64-QAM')
  47. plt.plot(*get_AWGN_ber(AlphabetMod('16qam', 10e6), AlphabetDemod('16qam', 10e6), samples=12000, start=-15), '-', label='16-QAM')
  48. plt.plot(*get_AWGN_ber(AlphabetMod('qpsk', 10e6), AlphabetDemod('qpsk', 10e6), samples=12000, start=-15), '-', label='QPSK')
  49. plt.plot(*get_AWGN_ber(AlphabetMod('8psk', 10e6), AlphabetDemod('8psk', 10e6), samples=12000, start=-15), '-', label='8PSK')
  50. plt.plot(*get_AWGN_ber(BPSKMod(10e6), BPSKDemod(10e6, 10e3), samples=12000), '-', label='BPSK')
  51. plt.yscale('log')
  52. plt.gca().invert_xaxis()
  53. plt.grid()
  54. plt.xlabel('Noise dB')
  55. plt.ylabel('BER')
  56. plt.legend()
  57. plt.show()
  58. pass