main.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.metrics import accuracy_score
  4. from models import basic
  5. from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
  6. import misc
  7. from models.autoencoder import Autoencoder, view_encoder
  8. def show_constellation(mod, chan, demod, samples=1000):
  9. x = misc.generate_random_bit_array(samples)
  10. x_mod = mod.forward(x)
  11. x_chan = chan.forward(x_mod)
  12. x_demod = demod.forward(x_chan)
  13. x_mod_rect = misc.polar2rect(x_mod)
  14. x_chan_rect = misc.polar2rect(x_chan)
  15. plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
  16. plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
  17. plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
  18. axes = plt.gca()
  19. axes.set_xlim([-2, +2])
  20. axes.set_ylim([-2, +2])
  21. plt.grid()
  22. plt.show()
  23. print('Accuracy : ' + str())
  24. def get_ber(mod, chan, demod, samples=1000):
  25. if samples % mod.N:
  26. samples += mod.N - (samples % mod.N)
  27. x = misc.generate_random_bit_array(samples)
  28. x_mod = mod.forward(x)
  29. x_chan = chan.forward(x_mod)
  30. x_demod = demod.forward(x_chan)
  31. return 1 - accuracy_score(x, x_demod)
  32. def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
  33. ber_x = np.linspace(start, stop, steps)
  34. ber_y = []
  35. for noise in ber_x:
  36. ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
  37. return ber_x, ber_y
  38. if __name__ == '__main__':
  39. # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
  40. # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
  41. # mod = MaryMod('8psk', 10e6)
  42. # misc.display_alphabet(mod.alphabet, a_vals=True)
  43. # mod = MaryMod('qpsk', 10e6)
  44. # misc.display_alphabet(mod.alphabet, a_vals=True)
  45. # mod = MaryMod('16qam', 10e6)
  46. # misc.display_alphabet(mod.alphabet, a_vals=True)
  47. # mod = MaryMod('64qam', 10e6)
  48. # misc.display_alphabet(mod.alphabet, a_vals=True)
  49. # aenc = Autoencoder(4, -25)
  50. # aenc.train(samples=5e5)
  51. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  52. # label='AE 4bit -25dB')
  53. # aenc = Autoencoder(5, -25)
  54. # aenc.train(samples=2e5)
  55. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  56. # label='AE 5bit -25dB')
  57. # view_encoder(aenc.encoder, 5)
  58. # plt.plot(*get_AWGN_ber(AlphabetMod('32qam', 10e6), AlphabetDemod('32qam', 10e6), samples=12000, start=-15), '-',
  59. # label='32-QAM')
  60. # show_constellation(AlphabetMod('32qam', 10e6), AWGNChannel(-1), AlphabetDemod('32qam', 10e6))
  61. # mod = AlphabetMod('32qam', 10e6)
  62. # misc.display_alphabet(mod.alphabet, a_vals=True)
  63. # pass
  64. # aenc = Autoencoder(5, -15)
  65. # aenc.train(samples=2e6)
  66. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  67. # label='AE 5bit -15dB')
  68. #
  69. # aenc = Autoencoder(4, -25)
  70. # aenc.train(samples=6e5)
  71. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  72. # label='AE 4bit -20dB')
  73. #
  74. # aenc = Autoencoder(4, -15)
  75. # aenc.train(samples=6e5)
  76. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  77. # label='AE 4bit -15dB')
  78. # aenc = Autoencoder(2, -20)
  79. # aenc.train(samples=6e5)
  80. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  81. # label='AE 2bit -20dB')
  82. #
  83. # aenc = Autoencoder(2, -15)
  84. # aenc.train(samples=6e5)
  85. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  86. # label='AE 2bit -15dB')
  87. # aenc = Autoencoder(4, -10)
  88. # aenc.train(samples=5e5)
  89. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  90. # label='AE 4bit -10dB')
  91. #
  92. # aenc = Autoencoder(4, -8)
  93. # aenc.train(samples=5e5)
  94. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  95. # label='AE 4bit -8dB')
  96. for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
  97. plt.plot(*get_AWGN_ber(
  98. AlphabetMod(scheme, 10e6),
  99. AlphabetDemod(scheme, 10e6),
  100. samples=20e3,
  101. steps=40,
  102. start=-15
  103. ), '-', label=scheme.upper())
  104. plt.yscale('log')
  105. plt.gca().invert_xaxis()
  106. plt.grid()
  107. plt.xlabel('Noise dB')
  108. plt.ylabel('BER')
  109. plt.legend()
  110. plt.show()
  111. pass