main.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from sklearn.metrics import accuracy_score
  4. from models import basic
  5. from models.basic import AWGNChannel, BPSKDemod, BPSKMod, BypassChannel, AlphabetMod, AlphabetDemod
  6. import misc
  7. import math
  8. import os
  9. from models.autoencoder import Autoencoder, view_encoder
  10. from models.optical_channel import OpticalChannel
  11. from multiprocessing import Pool
  12. CPU_COUNT = os.environ.get("CPU_COUNT", os.cpu_count())
  13. def show_constellation(mod, chan, demod, samples=1000):
  14. x = misc.generate_random_bit_array(samples)
  15. x_mod = mod.forward(x)
  16. x_chan = chan.forward(x_mod)
  17. x_demod = demod.forward(x_chan)
  18. x_mod_rect = misc.polar2rect(x_mod)
  19. x_chan_rect = misc.polar2rect(x_chan)
  20. plt.plot(x_chan_rect[:, 0][x], x_chan_rect[:, 1][x], '+')
  21. plt.plot(x_chan_rect[:, 0][~x], x_chan_rect[:, 1][~x], '+')
  22. plt.plot(x_mod_rect[:, 0], x_mod_rect[:, 1], 'ro')
  23. axes = plt.gca()
  24. axes.set_xlim([-2, +2])
  25. axes.set_ylim([-2, +2])
  26. plt.grid()
  27. plt.show()
  28. print('Accuracy : ' + str())
  29. def get_ber(mod, chan, demod, samples=1000):
  30. if samples % mod.N:
  31. samples += mod.N - (samples % mod.N)
  32. x = misc.generate_random_bit_array(samples)
  33. x_mod = mod.forward(x)
  34. x_chan = chan.forward(x_mod)
  35. x_demod = demod.forward(x_chan)
  36. return 1 - accuracy_score(x, x_demod)
  37. def get_AWGN_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30):
  38. ber_x = np.linspace(start, stop, steps)
  39. ber_y = []
  40. for noise in ber_x:
  41. ber_y.append(get_ber(mod, AWGNChannel(noise), demod, samples=samples))
  42. return ber_x, ber_y
  43. def __calc_ber(packed):
  44. # This function has to be outside get_Optical_ber in order to be pickled by pool
  45. mod, demod, noise, length, pulse_shape, samples = packed
  46. tx_channel = OpticalChannel(noise_level=noise, dispersion=-21.7, symbol_rate=10e9, sample_rate=400e9,
  47. length=length, pulse_shape=pulse_shape, sqrt_out=True)
  48. return get_ber(mod, tx_channel, demod, samples=samples)
  49. def get_Optical_ber(mod, demod, samples=1000, start=-8., stop=5., steps=30, length=100, pulse_shape='rect'):
  50. ber_x = np.linspace(start, stop, steps)
  51. ber_y = []
  52. print(f"Computing Optical BER.. 0/{len(ber_x)}", end='')
  53. with Pool(CPU_COUNT) as pool:
  54. packed_args = [(mod, demod, noise, length, pulse_shape, samples) for noise in ber_x]
  55. for i, ber in enumerate(pool.imap(__calc_ber, packed_args)):
  56. ber_y.append(ber)
  57. i += 1 # just offset by 1
  58. print(f"\rComputing Optical BER.. {i}/{len(ber_x)} ({i*100/len(ber_x):6.2f}%)", end='')
  59. print()
  60. return ber_x, ber_y
  61. def get_SNR(mod, demod, ber_func=get_Optical_ber, samples=1000, start=-5, stop=15, **ber_kwargs):
  62. """
  63. SNR for optics and RF should be calculated the same, that is A^2
  64. Because P∝V² and P∝I²
  65. """
  66. x_mod = mod.forward(misc.generate_random_bit_array(samples * mod.N))
  67. sig_amp = x_mod[:, 0]
  68. sig_power = [A ** 2 for A in sig_amp]
  69. av_sig_pow = np.mean(sig_power)
  70. av_sig_pow = math.log(av_sig_pow, 10)
  71. noise_start = -start + av_sig_pow
  72. noise_stop = -stop + av_sig_pow
  73. ber_x, ber_y = ber_func(mod, demod, samples, noise_start, noise_stop, **ber_kwargs)
  74. SNR = -ber_x + av_sig_pow
  75. return SNR, ber_y
  76. if __name__ == '__main__':
  77. # show_constellation(BPSKMod(10e6), AWGNChannel(-1), BPSKDemod(10e6, 10e3))
  78. # get_ber(BPSKMod(10e6), AWGNChannel(-20), BPSKDemod(10e6, 10e3))
  79. # mod = MaryMod('8psk', 10e6)
  80. # misc.display_alphabet(mod.alphabet, a_vals=True)
  81. # mod = MaryMod('qpsk', 10e6)
  82. # misc.display_alphabet(mod.alphabet, a_vals=True)
  83. # mod = MaryMod('16qam', 10e6)
  84. # misc.display_alphabet(mod.alphabet, a_vals=True)
  85. # mod = MaryMod('64qam', 10e6)
  86. # misc.display_alphabet(mod.alphabet, a_vals=True)
  87. # aenc = Autoencoder(4, -25)
  88. # aenc.train(samples=5e5)
  89. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  90. # label='AE 4bit -25dB')
  91. # aenc = Autoencoder(5, -25)
  92. # aenc.train(samples=2e5)
  93. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  94. # label='AE 5bit -25dB')
  95. # view_encoder(aenc.encoder, 5)
  96. # plt.plot(*get_AWGN_ber(AlphabetMod('32qam', 10e6), AlphabetDemod('32qam', 10e6), samples=12000, start=-15), '-',
  97. # label='32-QAM')
  98. # show_constellation(AlphabetMod('32qam', 10e6), AWGNChannel(-1), AlphabetDemod('32qam', 10e6))
  99. # mod = AlphabetMod('32qam', 10e6)
  100. # misc.display_alphabet(mod.alphabet, a_vals=True)
  101. # pass
  102. # aenc = Autoencoder(5, -15)
  103. # aenc.train(samples=2e6)
  104. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  105. # label='AE 5bit -15dB')
  106. #
  107. # aenc = Autoencoder(4, -25)
  108. # aenc.train(samples=6e5)
  109. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  110. # label='AE 4bit -20dB')
  111. #
  112. # aenc = Autoencoder(4, -15)
  113. # aenc.train(samples=6e5)
  114. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  115. # label='AE 4bit -15dB')
  116. # aenc = Autoencoder(2, -20)
  117. # aenc.train(samples=6e5)
  118. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  119. # label='AE 2bit -20dB')
  120. #
  121. # aenc = Autoencoder(2, -15)
  122. # aenc.train(samples=6e5)
  123. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  124. # label='AE 2bit -15dB')
  125. # aenc = Autoencoder(4, -10)
  126. # aenc.train(samples=5e5)
  127. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  128. # label='AE 4bit -10dB')
  129. #
  130. # aenc = Autoencoder(4, -8)
  131. # aenc.train(samples=5e5)
  132. # plt.plot(*get_AWGN_ber(aenc.get_modulator(), aenc.get_demodulator(), samples=12000, start=-15), '-',
  133. # label='AE 4bit -8dB')
  134. # for scheme in ['64qam', '32qam', '16qam', 'qpsk', '8psk']:
  135. # plt.plot(*get_SNR(
  136. # AlphabetMod(scheme, 10e6),
  137. # AlphabetDemod(scheme, 10e6),
  138. # samples=100e3,
  139. # steps=40,
  140. # start=-15
  141. # ), '-', label=scheme.upper())
  142. # plt.yscale('log')
  143. # plt.grid()
  144. # plt.xlabel('SNR dB')
  145. # plt.ylabel('BER')
  146. # plt.legend()
  147. # plt.show()
  148. for l in np.logspace(start=0, stop=3, num=5):
  149. plt.plot(*get_SNR(
  150. AlphabetMod('4pam', 10e6),
  151. AlphabetDemod('4pam', 10e6),
  152. samples=2000,
  153. steps=200,
  154. start=-5,
  155. stop=20,
  156. length=l,
  157. pulse_shape='rcos'
  158. ), '-', label=(str(int(l))+'km'))
  159. plt.yscale('log')
  160. # plt.gca().invert_xaxis()
  161. plt.grid()
  162. plt.xlabel('SNR dB')
  163. # plt.ylabel('BER')
  164. plt.title("BER against Fiber length")
  165. plt.legend()
  166. plt.show()
  167. # FIXME: Exit for now
  168. exit()
  169. for ps in ['rect']: #, 'rcos', 'rrcos']:
  170. plt.plot(*get_Optical_ber(
  171. AlphabetMod('4pam', 10e6),
  172. AlphabetDemod('4pam', 10e6),
  173. samples=30000,
  174. steps=40,
  175. start=-35,
  176. # stop=10,
  177. length=1,
  178. pulse_shape=ps
  179. ), '-', label=ps)
  180. plt.yscale('log')
  181. plt.grid()
  182. plt.xlabel('SNR dB')
  183. plt.ylabel('BER')
  184. plt.title("BER for different pulse shapes")
  185. plt.legend()
  186. plt.show()
  187. pass