OFDM.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import numpy as np
  2. import tensorflow as tf
  3. from tensorflow.keras import layers, losses
  4. from matplotlib import pyplot as plt
  5. from models.basic import AlphabetMod, AlphabetDemod, RFSignal
  6. from models.custom_layers import OpticalChannel
  class OFDM:
      """Simulate OFDM transmission of random bits over an ``OpticalChannel``.

      Every OFDM symbol carries one constellation point per sub-carrier. Its
      spectrum is laid out as
      ``[dc_offset, upper sideband, conj(reversed upper sideband)]``. This
      Hermitian layout makes the inverse FFT real (up to rounding), so one
      symbol occupies ``2 * num_channels + 1`` real time-domain samples.

      Args:
          num_channels: Number of data-carrying sub-carriers per OFDM symbol.
          channel_args: Keyword arguments forwarded to ``OpticalChannel``. The
              equalizer additionally reads ``"fs"`` (the sample rate used to
              build the frequency grid), ``"dispersion_factor"`` and
              ``"fiber_length"``.
          **kwargs: ``dc_offset`` (default 50) is the value written into the
              DC bin of every OFDM symbol. Other keys are ignored.
      """

      def __init__(self, num_channels, channel_args, **kwargs):
          self.num_channels = num_channels
          self.channel_args = channel_args
          self.dc_offset = kwargs.get("dc_offset", 50)
          # Populated by set_mod_schemes() below.
          self.mod_schemes = None
          self.modulators = None
          self.demodulators = None
          self.bits_per_ofdm_symbol = None
          self.set_mod_schemes()
          # Last bit stream passed to encode_bits().
          self.tx_bits = None

      @property
      def samples_per_ofdm_symbol(self):
          """Number of real time-domain samples in one OFDM symbol (DC + both sidebands)."""
          return 2 * self.num_channels + 1

      def set_mod_schemes(self, new_mod_schemes=None):
          """Select the modulation scheme of every sub-carrier and rebuild the (de)modulators.

          Args:
              new_mod_schemes: Sequence with one scheme name per sub-carrier. The
                  names are passed to ``AlphabetMod``/``AlphabetDemod``.
                  ``None`` selects ``"qpsk"`` on every sub-carrier.

          Raises:
              ValueError: If the number of schemes differs from ``num_channels``.
                  The check runs before any state is modified, because a
                  mismatch would otherwise be truncated silently by the decoder.
          """
          if new_mod_schemes is None:
              new_mod_schemes = ["qpsk"] * self.num_channels
          if len(new_mod_schemes) != self.num_channels:
              raise ValueError(
                  f"Expected {self.num_channels} modulation schemes, got {len(new_mod_schemes)}"
              )
          self.mod_schemes = new_mod_schemes
          # One (de)modulator per distinct scheme, shared by all sub-carriers using it.
          unique_schemes = set(self.mod_schemes)
          self.modulators = {x: AlphabetMod(x, 0.0) for x in unique_schemes}
          self.demodulators = {x: AlphabetDemod(x, 0.0) for x in unique_schemes}
          # Each sub-carrier consumes `.N` bits of its modulator (see encode_ofdm_symbol).
          self.bits_per_ofdm_symbol = sum(self.modulators[mod].N for mod in self.mod_schemes)

      def transmit_and_equalize(self, tx_stream):
          """Send ``tx_stream`` through an ``OpticalChannel`` and apply dispersion compensation.

          Args:
              tx_stream: Real time-domain samples, e.g. the output of ``encode_bits``.

          Returns:
              Real, non-negative numpy array: the magnitude of the compensated
              signal.
          """
          optical_channel = OpticalChannel(**self.channel_args, num_of_samples=len(tx_stream))
          rx_stream = np.asarray(optical_channel(tx_stream))
          # NOTE(review): the channel output is presumably a detected intensity,
          # so its square root is taken. Noise can push samples slightly below
          # zero, and one NaN from sqrt of a negative value would contaminate
          # every bin of the whole-stream FFT below. Clip at zero first.
          rx_field = np.sqrt(np.maximum(rx_stream, 0.0))
          rx_stream_f = np.fft.fft(rx_field)
          f = np.fft.fftfreq(rx_stream_f.shape[0], d=1 / self.channel_args["fs"])
          # TODO : Equalization needs to be tested
          accumulated_dispersion = self.channel_args["dispersion_factor"] * self.channel_args["fiber_length"]
          mul = np.exp(-0.5j * accumulated_dispersion * np.square(2 * np.pi * f))
          rx_compensated_f = rx_stream_f * mul
          return np.abs(np.fft.ifft(rx_compensated_f))

      def find_optimal_schemes(self):
          """Placeholder: not implemented yet; returns None.

          TODO: Choose an appropriate modulation scheme per sub-carrier after a
          single pass of QPSK.
          """

      def encode_ofdm_symbol(self, bits):
          """Map one OFDM symbol's worth of bits to real time-domain samples.

          Args:
              bits: Numpy array of exactly ``bits_per_ofdm_symbol`` bits. They are
                  consumed sub-carrier by sub-carrier in ``mod_schemes`` order.

          Returns:
              Real numpy array of ``2 * num_channels + 1`` samples.

          Raises:
              ValueError: If ``bits`` has the wrong size, or the number of mapped
                  sub-carriers differs from ``num_channels``.
          """
          if bits.size != self.bits_per_ofdm_symbol:
              raise ValueError("Number of bits passed does not match number of bits per OFDM symbol!")
          bit_idx = 0
          enc_stream_upper_f = []
          for mod in self.mod_schemes:
              modulator = self.modulators[mod]
              num_bits = modulator.N
              # Entries [0] and [1] of the first `.rect` row become the real and
              # imaginary parts of this sub-carrier's symbol.
              alph_sym = modulator.forward(bits[bit_idx:bit_idx + num_bits]).rect[0]
              bit_idx += num_bits
              enc_stream_upper_f.append(alph_sym[0] + 1j * alph_sym[1])
          if len(enc_stream_upper_f) != self.num_channels:
              raise ValueError(
                  "An error occurred! Computed upper sideband does not contain the specified number of channels"
              )
          enc_stream_upper_f = np.asarray(enc_stream_upper_f)
          # Mirror the upper sideband (reversed and conjugated) so the spectrum is
          # Hermitian and its inverse FFT is real up to rounding.
          enc_stream_lower_f = np.conjugate(np.flip(enc_stream_upper_f))
          enc_stream_f = np.concatenate((self.dc_offset, enc_stream_upper_f, enc_stream_lower_f), axis=None)
          return np.real(np.fft.ifft(enc_stream_f))

      def encode_bits(self, bit_stream):
          """Encode a bit stream as consecutive OFDM symbols.

          Args:
              bit_stream: Numpy array whose size is a multiple of
                  ``bits_per_ofdm_symbol``. It is stored in ``self.tx_bits``.

          Returns:
              1-D numpy array with the time-domain samples of every symbol,
              concatenated in order.

          Raises:
              ValueError: If the size is not a multiple of ``bits_per_ofdm_symbol``.
          """
          if bit_stream.size % self.bits_per_ofdm_symbol != 0:
              raise ValueError("Bit stream size is not an integer multiple of bits per OFDM symbol!")
          self.tx_bits = bit_stream
          bit_blocks = bit_stream.reshape(-1, self.bits_per_ofdm_symbol)
          # asarray(...).reshape(-1) (not concatenate) so an empty stream yields an empty array.
          return np.asarray([self.encode_ofdm_symbol(bits) for bits in bit_blocks]).reshape(-1)

      def decode_ofdm_symbol(self, stream):
          """Recover the bits carried by one OFDM symbol.

          Args:
              stream: Numpy array of exactly ``2 * num_channels + 1`` samples.

          Returns:
              1-D integer numpy array with the demodulated bits of every
              sub-carrier, in ``mod_schemes`` order.

          Raises:
              ValueError: If ``stream`` has the wrong number of samples.
          """
          if stream.size != self.samples_per_ofdm_symbol:
              raise ValueError("Number of samples passed is not equal to number of samples in one OFDM symbol!")
          stream_f = np.fft.fft(stream)
          # Bin 0 holds the DC offset; bins 1..num_channels are the data sub-carriers.
          stream_upper_f = stream_f[1:self.num_channels + 1]
          bits = []
          for chan, mod in zip(stream_upper_f, self.mod_schemes):
              # NOTE(review): the (3, 3) zeros only seed the RFSignal; its
              # rectangular coordinates are overwritten on the next line.
              rf_sym = RFSignal(np.zeros((3, 3)))
              rf_sym.set_rect_xy(np.real(chan), np.imag(chan))
              bits.append(self.demodulators[mod].forward(rf_sym).astype(int))
          if not bits:
              return np.zeros(0, dtype=int)
          # Concatenate raveled arrays: with mixed schemes the per-channel bit
          # counts differ, and np.asarray would fail on the ragged list.
          return np.concatenate([np.ravel(b) for b in bits])

      def decode_bits(self, rx_stream):
          """Decode a sample stream made of consecutive OFDM symbols back into bits.

          Raises:
              ValueError: If the size is not a multiple of the samples per OFDM symbol.
          """
          samples = self.samples_per_ofdm_symbol
          if rx_stream.size % samples != 0:
              raise ValueError("Number of samples passed is not an integer multiple of the samples per OFDM symbol")
          ofdm_symbols = rx_stream.reshape(-1, samples)
          return np.asarray([self.decode_ofdm_symbol(sym) for sym in ofdm_symbols]).flatten()

      @staticmethod
      def bit_error_rate(tx_bits, rx_bits):
          """Return the fraction of positions where ``rx_bits`` differs from ``tx_bits``.

          Both inputs are flattened before comparison.

          Raises:
              ValueError: If the streams have different sizes or are empty.
          """
          tx_bits = np.ravel(np.asarray(tx_bits))
          rx_bits = np.ravel(np.asarray(rx_bits))
          if tx_bits.size != rx_bits.size:
              raise ValueError(
                  f"Bit streams differ in length: {tx_bits.size} transmitted vs {rx_bits.size} received"
              )
          if tx_bits.size == 0:
              raise ValueError("Cannot compute the bit error rate of an empty bit stream")
          return np.count_nonzero(tx_bits != rx_bits) / tx_bits.size

      def _plot_spectra(self, tx_stream, rx_stream):
          """Plot the real parts of the transmitted and received spectra (blocks in plt.show())."""
          f = np.fft.fftfreq(len(tx_stream), d=1 / self.channel_args["fs"])
          plt.plot(f, np.real(np.fft.fft(tx_stream)), 'x')
          plt.plot(f, np.real(np.fft.fft(rx_stream)), 'x')
          plt.ylim([-5, 5])
          plt.show()

      def computeBER(self, n_ofdm_symbols, plot=True):
          """Estimate the bit error rate over ``n_ofdm_symbols`` random OFDM symbols.

          Random bits are encoded, passed through ``transmit_and_equalize`` and
          decoded. The resulting BER is printed and returned.

          Args:
              n_ofdm_symbols: Number of OFDM symbols to simulate.
              plot: If True (the default), plot the transmitted and received
                  spectra before decoding. The call blocks in ``plt.show()``.

          Returns:
              The bit error rate (fraction of wrongly decoded bits) as a float.
          """
          bit_stream = np.random.randint(2, size=n_ofdm_symbols * self.bits_per_ofdm_symbol)
          tx_stream = self.encode_bits(bit_stream)
          rx_stream = self.transmit_and_equalize(tx_stream)
          if plot:
              self._plot_spectra(tx_stream, rx_stream)
          rx_bit_stream = self.decode_bits(rx_stream)
          # BER counts mismatching bits; summing matches would give 1 - BER.
          ber = self.bit_error_rate(bit_stream, rx_bit_stream)
          print(ber)
          return ber
  if __name__ == '__main__':
      BANDWIDTH = 100e9
      # Keyword arguments for OpticalChannel. The OFDM equalizer also reads
      # "fs", "dispersion_factor" and "fiber_length" from this dict.
      channel_args = dict(
          fs=BANDWIDTH,
          dispersion_factor=-21.7 * 1e-24,
          fiber_length=10,
          fiber_length_stddev=0,
          lpf_cutoff=BANDWIDTH / 2,
          rx_stddev=0.01,
          sig_avg=0.5,
          enob=10,
      )
      # TODO : For High Cardinality or number of OFDM Symbols the decoder fails and BER ~ 0.5
      ofdm_obj = OFDM(64, channel_args)  # 64 data sub-carriers
      ofdm_obj.computeBER(100)  # simulate 100 OFDM symbols