# main.py (6.0 KB)
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.metrics import accuracy_score
  4. from models import basic
  5. from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
  6. import misc
  7. from models.autoencoder import Autoencoder, view_encoder
  8. from models.optical_channel import OpticalChannel
  9. def show_constellation(mod, chan, demod, samples=1000):
  10. x = misc.generate_random_bit_array(samples)
  11. x_mod = mod.forward(x)
  12. x_chan = chan.forward(x_mod)
  13. x_demod = demod.forward(x_chan)
  14. x_mod_rect = misc.polar2rect(x_mod)
  15. x_chan_rect = misc.polar2rect(x_chan)
  16. plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
  17. plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
  18. plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
  19. axes = plt.gca()
  20. axes.set_xlim([-2, +2])
  21. axes.set_ylim([-2, +2])
  22. plt.grid()
  23. plt.show()
  24. print('Accuracy : ' + str())
  25. def get_ber(mod, chan, demod, samples=1000):
  26. if samples % mod.N:
  27. samples += mod.N - (samples % mod.N)
  28. x = misc.generate_random_bit_array(samples)
  29. x_mod = mod.forward(x)
  30. x_chan = chan.forward(x_mod)
  31. x_demod = demod.forward(x_chan)
  32. return 1 - accuracy_score(x, x_demod)
  33. def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
  34. ber_x = np.linspace(start, stop, steps)
  35. ber_y = []
  36. for noise in ber_x:
  37. ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
  38. return ber_x, ber_y
  39. def get_Optical_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30, length=100, pulse_shape='rect'):
  40. ber_x = np.linspace(start, stop, steps)
  41. ber_y = []
  42. for noise in ber_x:
  43. tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
  44. length=length, pulse_shape=pulse_shape, sqrt_out=True)
  45. ber_y.append(get_ber(mod, tx_channel, demod, samples=samples))
  46. return ber_x, ber_y
  47. if __name__ == '__main__':
  48. # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
  49. # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
  50. # mod = MaryMod('8psk', 10e6)
  51. # misc.display_alphabet(mod.alphabet, a_vals=True)
  52. # mod = MaryMod('qpsk', 10e6)
  53. # misc.display_alphabet(mod.alphabet, a_vals=True)
  54. # mod = MaryMod('16qam', 10e6)
  55. # misc.display_alphabet(mod.alphabet, a_vals=True)
  56. # mod = MaryMod('64qam', 10e6)
  57. # misc.display_alphabet(mod.alphabet, a_vals=True)
  58. # aenc = Autoencoder(4, -25)
  59. # aenc.train(samples=5e5)
  60. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  61. # label='AE 4bit -25dB')
  62. # aenc = Autoencoder(5, -25)
  63. # aenc.train(samples=2e5)
  64. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  65. # label='AE 5bit -25dB')
  66. # view_encoder(aenc.encoder, 5)
  67. # plt.plot(*get_AWGN_ber(AlphabetMod('32qam', 10e6), AlphabetDemod('32qam', 10e6), samples=12000, start=-15), '-',
  68. # label='32-QAM')
  69. # show_constellation(AlphabetMod('32qam', 10e6), AWGNChannel(-1), AlphabetDemod('32qam', 10e6))
  70. # mod = AlphabetMod('32qam', 10e6)
  71. # misc.display_alphabet(mod.alphabet, a_vals=True)
  72. # pass
  73. # aenc = Autoencoder(5, -15)
  74. # aenc.train(samples=2e6)
  75. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  76. # label='AE 5bit -15dB')
  77. #
  78. # aenc = Autoencoder(4, -25)
  79. # aenc.train(samples=6e5)
  80. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  81. # label='AE 4bit -20dB')
  82. #
  83. # aenc = Autoencoder(4, -15)
  84. # aenc.train(samples=6e5)
  85. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  86. # label='AE 4bit -15dB')
  87. # aenc = Autoencoder(2, -20)
  88. # aenc.train(samples=6e5)
  89. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  90. # label='AE 2bit -20dB')
  91. #
  92. # aenc = Autoencoder(2, -15)
  93. # aenc.train(samples=6e5)
  94. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  95. # label='AE 2bit -15dB')
  96. # aenc = Autoencoder(4, -10)
  97. # aenc.train(samples=5e5)
  98. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  99. # label='AE 4bit -10dB')
  100. #
  101. # aenc = Autoencoder(4, -8)
  102. # aenc.train(samples=5e5)
  103. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  104. # label='AE 4bit -8dB')
  105. # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
  106. # plt.plot(*get_AWGN_ber(
  107. # AlphabetMod(scheme, 10e6),
  108. # AlphabetDemod(scheme, 10e6),
  109. # samples=20e3,
  110. # steps=40,
  111. # start=-15
  112. # ), '-', label=scheme.upper())
  113. for l in np.logspace(start=0, stop=3, num=5):
  114. plt.plot(*get_Optical_ber(
  115. AlphabetMod('4pam', 10e6),
  116. AlphabetDemod('4pam', 10e6),
  117. samples=1000,
  118. steps=40,
  119. start=-15,
  120. length=l,
  121. pulse_shape='rcos'
  122. ), '-', label=(str(int(l))+'km'))
  123. plt.yscale('log')
  124. plt.gca().invert_xaxis()
  125. plt.grid()
  126. plt.xlabel('Noise dB')
  127. plt.ylabel('BER')
  128. plt.title("BER against Fiber length")
  129. plt.legend()
  130. plt.show()
  131. for ps in ['rect', 'rcos', 'rrcos']:
  132. plt.plot(*get_Optical_ber(
  133. AlphabetMod('4pam', 10e6),
  134. AlphabetDemod('4pam', 10e6),
  135. samples=1000,
  136. steps=40,
  137. start=-15,
  138. length=10,
  139. pulse_shape=ps
  140. ), '-', label=ps)
  141. plt.yscale('log')
  142. plt.gca().invert_xaxis()
  143. plt.grid()
  144. plt.xlabel('Noise dB')
  145. plt.ylabel('BER')
  146. plt.title("BER for different pulse shapes")
  147. plt.legend()
  148. plt.show()
  149. pass