main.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.metrics import accuracy_score
  4. from models import basic
  5. from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
  6. import misc
  7. import math
  8. from models.autoencoder import Autoencoder, view_encoder
  9. from models.optical_channel import OpticalChannel
  10. def show_constellation(mod, chan, demod, samples=1000):
  11. x = misc.generate_random_bit_array(samples)
  12. x_mod = mod.forward(x)
  13. x_chan = chan.forward(x_mod)
  14. x_demod = demod.forward(x_chan)
  15. x_mod_rect = misc.polar2rect(x_mod)
  16. x_chan_rect = misc.polar2rect(x_chan)
  17. plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
  18. plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
  19. plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
  20. axes = plt.gca()
  21. axes.set_xlim([-2, +2])
  22. axes.set_ylim([-2, +2])
  23. plt.grid()
  24. plt.show()
  25. print('Accuracy : ' + str())
  26. def get_ber(mod, chan, demod, samples=1000):
  27. if samples % mod.N:
  28. samples += mod.N - (samples % mod.N)
  29. x = misc.generate_random_bit_array(samples)
  30. x_mod = mod.forward(x)
  31. x_chan = chan.forward(x_mod)
  32. x_demod = demod.forward(x_chan)
  33. return 1 - accuracy_score(x, x_demod)
  34. def get_AWGN_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30):
  35. ber_x = np.linspace(start, stop, steps)
  36. ber_y = []
  37. for noise in ber_x:
  38. ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
  39. return ber_x, ber_y
  40. def get_SNR(mod, demod, samples=1000, start=-8, stop=5, steps=30):
  41. ber_x, ber_y = get_AWGN_ber(mod, demod, samples, start, stop, steps)
  42. x_mod = mod.forward(misc.generate_random_bit_array(samples*mod.N))
  43. sig_amp = x_mod[:, 0]
  44. sig_power = [A ** 2 for A in sig_amp]
  45. av_sig_pow = np.mean(sig_power)
  46. av_sig_pow = math.log(av_sig_pow, 10)
  47. SNR = (ber_x * -1) + av_sig_pow
  48. return SNR, ber_y
  49. def get_Optical_ber(mod, demod, samples=1000, start=-8, stop=5, steps=30, length=100, pulse_shape='rect'):
  50. ber_x = np.linspace(start, stop, steps)
  51. ber_y = []
  52. for noise in ber_x:
  53. tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
  54. length=length, pulse_shape=pulse_shape, sqrt_out=True)
  55. ber_y.append(get_ber(mod, tx_channel, demod, samples=samples))
  56. return ber_x, ber_y
  57. if __name__ == '__main__':
  58. # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
  59. # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
  60. # mod = MaryMod('8psk', 10e6)
  61. # misc.display_alphabet(mod.alphabet, a_vals=True)
  62. # mod = MaryMod('qpsk', 10e6)
  63. # misc.display_alphabet(mod.alphabet, a_vals=True)
  64. # mod = MaryMod('16qam', 10e6)
  65. # misc.display_alphabet(mod.alphabet, a_vals=True)
  66. # mod = MaryMod('64qam', 10e6)
  67. # misc.display_alphabet(mod.alphabet, a_vals=True)
  68. # aenc = Autoencoder(4, -25)
  69. # aenc.train(samples=5e5)
  70. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  71. # label='AE 4bit -25dB')
  72. # aenc = Autoencoder(5, -25)
  73. # aenc.train(samples=2e5)
  74. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  75. # label='AE 5bit -25dB')
  76. # view_encoder(aenc.encoder, 5)
  77. # plt.plot(*get_AWGN_ber(AlphabetMod('32qam', 10e6), AlphabetDemod('32qam', 10e6), samples=12000, start=-15), '-',
  78. # label='32-QAM')
  79. # show_constellation(AlphabetMod('32qam', 10e6), AWGNChannel(-1), AlphabetDemod('32qam', 10e6))
  80. # mod = AlphabetMod('32qam', 10e6)
  81. # misc.display_alphabet(mod.alphabet, a_vals=True)
  82. # pass
  83. # aenc = Autoencoder(5, -15)
  84. # aenc.train(samples=2e6)
  85. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  86. # label='AE 5bit -15dB')
  87. #
  88. # aenc = Autoencoder(4, -25)
  89. # aenc.train(samples=6e5)
  90. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  91. # label='AE 4bit -20dB')
  92. #
  93. # aenc = Autoencoder(4, -15)
  94. # aenc.train(samples=6e5)
  95. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  96. # label='AE 4bit -15dB')
  97. # aenc = Autoencoder(2, -20)
  98. # aenc.train(samples=6e5)
  99. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  100. # label='AE 2bit -20dB')
  101. #
  102. # aenc = Autoencoder(2, -15)
  103. # aenc.train(samples=6e5)
  104. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  105. # label='AE 2bit -15dB')
  106. # aenc = Autoencoder(4, -10)
  107. # aenc.train(samples=5e5)
  108. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  109. # label='AE 4bit -10dB')
  110. #
  111. # aenc = Autoencoder(4, -8)
  112. # aenc.train(samples=5e5)
  113. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  114. # label='AE 4bit -8dB')
  115. # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
  116. # plt.plot(*get_SNR(
  117. # AlphabetMod(scheme, 10e6),
  118. # AlphabetDemod(scheme, 10e6),
  119. # samples=100e3,
  120. # steps=40,
  121. # start=-15
  122. # ), '-', label=scheme.upper())
  123. # plt.yscale('log')
  124. # plt.grid()
  125. # plt.xlabel('SNR dB')
  126. # plt.ylabel('BER')
  127. # plt.legend()
  128. # plt.show()
  129. for l in np.logspace(start=0, stop=3, num=5):
  130. plt.plot(*get_Optical_ber(
  131. AlphabetMod('4pam', 10e6),
  132. AlphabetDemod('4pam', 10e6),
  133. samples=1000,
  134. steps=40,
  135. start=-15,
  136. length=l,
  137. pulse_shape='rcos'
  138. ), '-', label=(str(int(l))+'km'))
  139. plt.yscale('log')
  140. plt.gca().invert_xaxis()
  141. plt.grid()
  142. plt.xlabel('Noise dB')
  143. plt.ylabel('BER')
  144. plt.title("BER against Fiber length")
  145. plt.legend()
  146. plt.show()
  147. for ps in ['rect', 'rcos', 'rrcos']:
  148. plt.plot(*get_Optical_ber(
  149. AlphabetMod('4pam', 10e6),
  150. AlphabetDemod('4pam', 10e6),
  151. samples=1000,
  152. steps=40,
  153. start=-15,
  154. length=10,
  155. pulse_shape=ps
  156. ), '-', label=ps)
  157. plt.yscale('log')
  158. plt.gca().invert_xaxis()
  159. plt.grid()
  160. plt.xlabel('Noise dB')
  161. plt.ylabel('BER')
  162. plt.title("BER for different pulse shapes")
  163. plt.legend()
  164. plt.show()
  165. pass