OFDM.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from scipy import interpolate
  4. from models.custom_layers import OpticalChannel
  5. SAMPLING_FREQUENCY = 336e9
  6. CARDINALITY = 4
  7. SAMPLES_PER_SYMBOL = 128
  8. MESSAGES_PER_BLOCK = 9
  9. DISPERSION_FACTOR = -21.7 * 1e-24
  10. FIBER_LENGTH = 50
  11. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  12. num_of_samples=80,
  13. dispersion_factor=DISPERSION_FACTOR,
  14. fiber_length=FIBER_LENGTH)
  15. K = 64 # number of OFDM subcarriers
  16. CP = K // 4 # length of the cyclic prefix: 25% of the block
  17. P = 8 # number of pilot carriers per OFDM block
  18. pilotValue = 3 + 3j # The known value each pilot transmits
  19. allCarriers = np.arange(K) # indices of all subcarriers ([0, 1, ... K-1])
  20. pilotCarriers = allCarriers[::K // P] # Pilots is every (K/P)th carrier.
  21. # For convenience of channel estimation, let's make the last carriers also be a pilot
  22. pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
  23. P = P + 1
  24. # data carriers are all remaining carriers
  25. dataCarriers = np.delete(allCarriers, pilotCarriers)
  26. print("allCarriers: %s" % allCarriers)
  27. print("pilotCarriers: %s" % pilotCarriers)
  28. print("dataCarriers: %s" % dataCarriers)
  29. plt.plot(pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
  30. plt.plot(dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
  31. mu = 4 # bits per symbol (i.e. 16QAM)
  32. payloadBits_per_OFDM = len(dataCarriers) * mu # number of payload bits per OFDM symbol
  33. mapping_table = {
  34. (0, 0, 0, 0): -3 - 3j,
  35. (0, 0, 0, 1): -3 - 1j,
  36. (0, 0, 1, 0): -3 + 3j,
  37. (0, 0, 1, 1): -3 + 1j,
  38. (0, 1, 0, 0): -1 - 3j,
  39. (0, 1, 0, 1): -1 - 1j,
  40. (0, 1, 1, 0): -1 + 3j,
  41. (0, 1, 1, 1): -1 + 1j,
  42. (1, 0, 0, 0): 3 - 3j,
  43. (1, 0, 0, 1): 3 - 1j,
  44. (1, 0, 1, 0): 3 + 3j,
  45. (1, 0, 1, 1): 3 + 1j,
  46. (1, 1, 0, 0): 1 - 3j,
  47. (1, 1, 0, 1): 1 - 1j,
  48. (1, 1, 1, 0): 1 + 3j,
  49. (1, 1, 1, 1): 1 + 1j
  50. }
  51. mapping_table_dec = {
  52. (0): -3 - 3j,
  53. (1): -3 - 1j,
  54. (2): -3 + 3j,
  55. (3): -3 + 1j,
  56. (4): -1 - 3j,
  57. (5): -1 - 1j,
  58. (6): -1 + 3j,
  59. (7): -1 + 1j,
  60. (8): 3 - 3j,
  61. (9): 3 - 1j,
  62. (10): 3 + 3j,
  63. (11): 3 + 1j,
  64. (12): 1 - 3j,
  65. (13): 1 - 1j,
  66. (14): 1 + 3j,
  67. (15): 1 + 1j
  68. }
  69. for b3 in [0, 1]:
  70. for b2 in [0, 1]:
  71. for b1 in [0, 1]:
  72. for b0 in [0, 1]:
  73. B = (b3, b2, b1, b0)
  74. Q = mapping_table[B]
  75. plt.plot(Q.real, Q.imag, 'bo')
  76. plt.text(Q.real, Q.imag + 0.2, "".join(str(x) for x in B), ha='center')
  77. demapping_table = {v: k for k, v in mapping_table.items()}
  78. # Replace with our channel
  79. channelResponse = np.array([1, 0, 0.3 + 0.3j]) # the impulse response of the wireless channel
  80. H_exact = np.fft.fft(channelResponse, K)
  81. plt.plot(allCarriers, abs(H_exact))
  82. SNRdb = 25 # signal to noise-ratio in dB at the receiver
  83. # Here
  84. # water filling, gradient decent methods for optimising the symbol mapping, instead of 16 QAM
  85. bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
  86. print("Bits count: ", len(bits))
  87. print("First 20 bits: ", bits[:20])
  88. print("Mean of bits (should be around 0.5): ", np.mean(bits))
  89. def SP(bits):
  90. return bits.reshape((len(dataCarriers), mu))
  91. bits_SP = SP(bits)
  92. print("First 5 bit groups")
  93. print(bits_SP[:5, :])
  94. def generate_random_inputs(num_of_blocks, return_vals=False):
  95. """
  96. A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
  97. :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
  98. consecutively to model ISI. The central message in a block is returned as the label for training.
  99. :param return_vals: If true, the raw decimal values of the input sequence will be returned
  100. """
  101. rand_int = np.random.randint(16, size=(num_of_blocks, 1))
  102. # cat = [np.arange(self.cardinality)]
  103. # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
  104. # out = enc.fit_transform(rand_int)
  105. # for symbol in rand_int:
  106. # out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
  107. # t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
  108. # mid_idx = int((self.messages_per_block - 1) / 2)
  109. # if return_vals:
  110. # out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
  111. # return out_val, out_arr, out_arr[:, mid_idx, :]
  112. return rand_int
  113. bits_SP1 = generate_random_inputs(num_of_blocks=1000).astype('uint8')
  114. # bits_SP11 = np.unpackbits(bits_SP1, axis=1)
  115. def Mapping(bits):
  116. return np.array([mapping_table[tuple(b)] for b in bits])
  117. def Mapping_dec(bits):
  118. return np.array([mapping_table_dec[tuple(b)] for b in bits])
  119. QAM = Mapping(bits_SP)
  120. QAM1 = Mapping_dec(bits_SP1)
  121. print("First 5 QAM symbols and bits:")
  122. print(bits_SP[:5, :])
  123. print(QAM[:5])
  124. def OFDM_symbol(QAM_payload):
  125. symbol = np.zeros(K, dtype=complex) # the overall K subcarriers
  126. symbol[pilotCarriers] = pilotValue # allocate the pilot subcarriers
  127. symbol[dataCarriers] = QAM_payload # allocate the pilot subcarriers
  128. return symbol
  129. OFDM_data = OFDM_symbol(QAM)
  130. print("Number of OFDM carriers in frequency domain: ", len(OFDM_data))
  131. def IDFT(OFDM_data):
  132. return np.fft.ifft(OFDM_data)
  133. OFDM_time = IDFT(OFDM_data)
  134. print("Number of OFDM samples in time-domain before CP: ", len(OFDM_time))
  135. def addCP(OFDM_time):
  136. cp = OFDM_time[-CP:] # take the last CP samples ...
  137. return np.hstack([cp, OFDM_time]) # ... and add them to the beginning
  138. OFDM_withCP = addCP(OFDM_time)
  139. print("Number of OFDM samples in time domain with CP: ", len(OFDM_withCP))
  140. def channel(signal):
  141. convolved = np.convolve(signal, channelResponse)
  142. signal_power = np.mean(abs(convolved ** 2))
  143. sigma2 = signal_power * 10 ** (-SNRdb / 10) # calculate noise power based on signal power and SNR
  144. print("RX Signal power: %.4f. Noise power: %.4f" % (signal_power, sigma2))
  145. # Generate complex noise with given variance
  146. noise = np.sqrt(sigma2 / 2) * (np.random.randn(*convolved.shape) + 1j * np.random.randn(*convolved.shape))
  147. return convolved + noise
  148. OFDM_TX = OFDM_withCP
  149. OFDM_RX = channel(OFDM_TX)
  150. # OFDM_RX1 = optical_channel(OFDM_TX).numpy
  151. plt.figure(figsize=(8, 2))
  152. plt.plot(abs(OFDM_TX), label='TX signal')
  153. plt.plot(abs(OFDM_RX), label='RX signal')
  154. plt.legend(fontsize=10)
  155. plt.xlabel('Time')
  156. plt.ylabel('$|x(t)|$')
  157. plt.grid(True)
  158. def removeCP(signal):
  159. return signal[CP:(CP + K)]
  160. OFDM_RX_noCP = removeCP(OFDM_RX)
  161. def DFT(OFDM_RX):
  162. return np.fft.fft(OFDM_RX)
  163. OFDM_demod = DFT(OFDM_RX_noCP)
  164. def channelEstimate(OFDM_demod):
  165. pilots = OFDM_demod[pilotCarriers] # extract the pilot values from the RX signal
  166. Hest_at_pilots = pilots / pilotValue # divide by the transmitted pilot values
  167. # Perform interpolation between the pilot carriers to get an estimate
  168. # of the channel in the data carriers. Here, we interpolate absolute value and phase
  169. # separately
  170. Hest_abs = interpolate.interp1d(pilotCarriers, abs(Hest_at_pilots), kind='linear')(allCarriers)
  171. Hest_phase = interpolate.interp1d(pilotCarriers, np.angle(Hest_at_pilots), kind='linear')(allCarriers)
  172. Hest = Hest_abs * np.exp(1j * Hest_phase)
  173. plt.plot(allCarriers, abs(H_exact), label='Correct Channel')
  174. plt.stem(pilotCarriers, abs(Hest_at_pilots), label='Pilot estimates')
  175. plt.plot(allCarriers, abs(Hest), label='Estimated channel via interpolation')
  176. plt.grid(True)
  177. plt.xlabel('Carrier index')
  178. plt.ylabel('$|H(f)|$')
  179. plt.legend(fontsize=10)
  180. plt.ylim(0, 2)
  181. return Hest
  182. Hest = channelEstimate(OFDM_demod)
  183. def equalize(OFDM_demod, Hest):
  184. return OFDM_demod / Hest
  185. equalized_Hest = equalize(OFDM_demod, Hest)
  186. def get_payload(equalized):
  187. return equalized[dataCarriers]
  188. QAM_est = get_payload(equalized_Hest)
  189. plt.plot(QAM_est.real, QAM_est.imag, 'bo');
  190. def Demapping(QAM):
  191. # array of possible constellation points
  192. constellation = np.array([x for x in demapping_table.keys()])
  193. # calculate distance of each RX point to each possible point
  194. dists = abs(QAM.reshape((-1, 1)) - constellation.reshape((1, -1)))
  195. # for each element in QAM, choose the index in constellation
  196. # that belongs to the nearest constellation point
  197. const_index = dists.argmin(axis=1)
  198. # get back the real constellation point
  199. hardDecision = constellation[const_index]
  200. # transform the constellation point into the bit groups
  201. return np.vstack([demapping_table[C] for C in hardDecision]), hardDecision
  202. PS_est, hardDecision = Demapping(QAM_est)
  203. for qam, hard in zip(QAM_est, hardDecision):
  204. plt.plot([qam.real, hard.real], [qam.imag, hard.imag], 'b-o');
  205. plt.plot(hardDecision.real, hardDecision.imag, 'ro')
  206. def PS(bits):
  207. return bits.reshape((-1,))
  208. bits_est = PS(PS_est)
  209. print("Obtained Bit error rate: ", np.sum(abs(bits - bits_est)) / len(bits))