OFDM.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import scipy
  4. from scipy import interpolate
  5. import tensorflow as tf
  6. from tensorflow.keras import layers, losses
  7. import math
  8. class ExtractCentralMessage(layers.Layer):
  9. def __init__(self, messages_per_block, samples_per_symbol):
  10. """
  11. A keras layer that extracts the central message(symbol) in a block.
  12. :param messages_per_block: Total number of messages in transmission block
  13. :param samples_per_symbol: Number of samples per transmitted symbol
  14. """
  15. super(ExtractCentralMessage, self).__init__()
  16. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  17. i = np.identity(samples_per_symbol)
  18. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  19. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  20. temp_w[begin:end, :] = i
  21. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  22. def call(self, inputs, **kwargs):
  23. return tf.matmul(inputs, self.w)
  24. class AwgnChannel(layers.Layer):
  25. def __init__(self, rx_stddev=0.1):
  26. """
  27. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  28. being applied every time the call function is called.
  29. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  30. """
  31. super(AwgnChannel, self).__init__()
  32. self.noise_layer = layers.GaussianNoise(rx_stddev)
  33. def call(self, inputs, **kwargs):
  34. return self.noise_layer.call(inputs, training=True)
  35. class DigitizationLayer(layers.Layer):
  36. def __init__(self,
  37. fs,
  38. num_of_samples,
  39. lpf_cutoff=32e9,
  40. q_stddev=0.1):
  41. """
  42. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  43. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  44. :param fs: Sampling frequency of the simulation in Hz
  45. :param num_of_samples: Total number of samples in the input
  46. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  47. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  48. """
  49. super(DigitizationLayer, self).__init__()
  50. self.noise_layer = layers.GaussianNoise(q_stddev)
  51. freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
  52. temp = np.ones(freq.shape)
  53. for idx, val in np.ndenumerate(freq):
  54. if np.abs(val) > lpf_cutoff:
  55. temp[idx] = 0
  56. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  57. def call(self, inputs, **kwargs):
  58. complex_in = tf.cast(inputs, dtype=tf.complex64)
  59. val_f = tf.signal.fft(complex_in)
  60. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  61. filtered_t = tf.signal.ifft(filtered_f)
  62. real_t = tf.cast(filtered_t, dtype=tf.float32)
  63. noisy = self.noise_layer.call(real_t, training=True)
  64. return noisy
  65. class OpticalChannel(layers.Layer):
  66. def __init__(self,
  67. fs,
  68. num_of_samples,
  69. dispersion_factor,
  70. fiber_length,
  71. lpf_cutoff=32e9,
  72. rx_stddev=0.03,
  73. q_stddev=0.01):
  74. """
  75. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  76. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  77. :param fs: Sampling frequency of the simulation in Hz
  78. :param num_of_samples: Total number of samples in the input
  79. :param dispersion_factor: Dispersion factor in s^2/km
  80. :param fiber_length: Length of fiber to model in km
  81. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  82. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  83. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  84. """
  85. super(OpticalChannel, self).__init__()
  86. self.noise_layer = layers.GaussianNoise(rx_stddev)
  87. self.digitization_layer = DigitizationLayer(fs=fs,
  88. num_of_samples=num_of_samples,
  89. lpf_cutoff=lpf_cutoff,
  90. q_stddev=q_stddev)
  91. self.flatten_layer = layers.Flatten()
  92. self.fs = fs
  93. self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
  94. self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))
  95. def call(self, inputs, **kwargs):
  96. # DAC LPF and noise
  97. dac_out = self.digitization_layer(inputs)
  98. # Chromatic Dispersion
  99. complex_val = tf.cast(dac_out, dtype=tf.complex128)
  100. val_f = tf.signal.fft(complex_val)
  101. disp_f = tf.math.multiply(val_f, self.multiplier)
  102. disp_t = tf.signal.ifft(disp_f)
  103. # Squared-Law Detection
  104. pd_out = tf.square(tf.abs(disp_t))
  105. # Casting back to floatx
  106. real_val = tf.cast(pd_out, dtype=tf.float32)
  107. # Adding photo-diode receiver noise
  108. rx_signal = self.noise_layer.call(real_val, training=True)
  109. # ADC LPF and noise
  110. adc_out = self.digitization_layer(rx_signal)
  111. return adc_out
  112. SAMPLING_FREQUENCY = 336e9
  113. CARDINALITY = 4
  114. SAMPLES_PER_SYMBOL = 128
  115. MESSAGES_PER_BLOCK = 9
  116. DISPERSION_FACTOR = -21.7 * 1e-24
  117. FIBER_LENGTH = 50
  118. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  119. num_of_samples=80,
  120. dispersion_factor=DISPERSION_FACTOR,
  121. fiber_length=FIBER_LENGTH)
  122. K = 64 # number of OFDM subcarriers
  123. CP = K//4 # length of the cyclic prefix: 25% of the block
  124. P = 8 # number of pilot carriers per OFDM block
  125. pilotValue = 3+3j # The known value each pilot transmits
  126. allCarriers = np.arange(K) # indices of all subcarriers ([0, 1, ... K-1])
  127. pilotCarriers = allCarriers[::K//P] # Pilots is every (K/P)th carrier.
  128. # For convenience of channel estimation, let's make the last carriers also be a pilot
  129. pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
  130. P = P+1
  131. # data carriers are all remaining carriers
  132. dataCarriers = np.delete(allCarriers, pilotCarriers)
  133. print ("allCarriers: %s" % allCarriers)
  134. print ("pilotCarriers: %s" % pilotCarriers)
  135. print ("dataCarriers: %s" % dataCarriers)
  136. plt.plot(pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
  137. plt.plot(dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
  138. mu = 4 # bits per symbol (i.e. 16QAM)
  139. payloadBits_per_OFDM = len(dataCarriers)*mu # number of payload bits per OFDM symbol
  140. mapping_table = {
  141. (0,0,0,0) : -3-3j,
  142. (0,0,0,1) : -3-1j,
  143. (0,0,1,0) : -3+3j,
  144. (0,0,1,1) : -3+1j,
  145. (0,1,0,0) : -1-3j,
  146. (0,1,0,1) : -1-1j,
  147. (0,1,1,0) : -1+3j,
  148. (0,1,1,1) : -1+1j,
  149. (1,0,0,0) : 3-3j,
  150. (1,0,0,1) : 3-1j,
  151. (1,0,1,0) : 3+3j,
  152. (1,0,1,1) : 3+1j,
  153. (1,1,0,0) : 1-3j,
  154. (1,1,0,1) : 1-1j,
  155. (1,1,1,0) : 1+3j,
  156. (1,1,1,1) : 1+1j
  157. }
  158. mapping_table_dec = {
  159. (0) : -3-3j,
  160. (1) : -3-1j,
  161. (2) : -3+3j,
  162. (3) : -3+1j,
  163. (4) : -1-3j,
  164. (5) : -1-1j,
  165. (6) : -1+3j,
  166. (7) : -1+1j,
  167. (8) : 3-3j,
  168. (9) : 3-1j,
  169. (10) : 3+3j,
  170. (11) : 3+1j,
  171. (12) : 1-3j,
  172. (13) : 1-1j,
  173. (14) : 1+3j,
  174. (15) : 1+1j
  175. }
  176. for b3 in [0, 1]:
  177. for b2 in [0, 1]:
  178. for b1 in [0, 1]:
  179. for b0 in [0, 1]:
  180. B = (b3, b2, b1, b0)
  181. Q = mapping_table[B]
  182. plt.plot(Q.real, Q.imag, 'bo')
  183. plt.text(Q.real, Q.imag+0.2, "".join(str(x) for x in B), ha='center')
  184. demapping_table = {v: k for k, v in mapping_table.items()}
  185. # Replace with our channel
  186. channelResponse = np.array([1, 0, 0.3+0.3j]) # the impulse response of the wireless channel
  187. H_exact = np.fft.fft(channelResponse, K)
  188. plt.plot(allCarriers, abs(H_exact))
  189. SNRdb = 25 # signal to noise-ratio in dB at the receiver
  190. # Here
  191. #water filling, gradient decent methods for optimising the symbol mapping, instead of 16 QAM
  192. bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM, ))
  193. print ("Bits count: ", len(bits))
  194. print ("First 20 bits: ", bits[:20])
  195. print ("Mean of bits (should be around 0.5): ", np.mean(bits))
  196. def SP(bits):
  197. return bits.reshape((len(dataCarriers), mu))
  198. bits_SP = SP(bits)
  199. print ("First 5 bit groups")
  200. print (bits_SP[:5,:])
  201. def generate_random_inputs(num_of_blocks, return_vals=False):
  202. """
  203. A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
  204. :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
  205. consecutively to model ISI. The central message in a block is returned as the label for training.
  206. :param return_vals: If true, the raw decimal values of the input sequence will be returned
  207. """
  208. rand_int = np.random.randint(16, size=(num_of_blocks, 1))
  209. # cat = [np.arange(self.cardinality)]
  210. # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
  211. # out = enc.fit_transform(rand_int)
  212. # for symbol in rand_int:
  213. #out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
  214. #t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
  215. #mid_idx = int((self.messages_per_block - 1) / 2)
  216. #if return_vals:
  217. # out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
  218. # return out_val, out_arr, out_arr[:, mid_idx, :]
  219. return rand_int
  220. bits_SP1 = generate_random_inputs(num_of_blocks=1000).astype('uint8')
  221. #bits_SP11 = np.unpackbits(bits_SP1, axis=1)
  222. def Mapping(bits):
  223. return np.array([mapping_table[tuple(b)] for b in bits])
  224. def Mapping_dec(bits):
  225. return np.array([mapping_table_dec[tuple(b)] for b in bits])
  226. QAM = Mapping(bits_SP)
  227. QAM1 = Mapping_dec(bits_SP1)
  228. print("First 5 QAM symbols and bits:")
  229. print(bits_SP[:5,:])
  230. print(QAM[:5])
  231. def OFDM_symbol(QAM_payload):
  232. symbol = np.zeros(K, dtype=complex) # the overall K subcarriers
  233. symbol[pilotCarriers] = pilotValue # allocate the pilot subcarriers
  234. symbol[dataCarriers] = QAM_payload # allocate the pilot subcarriers
  235. return symbol
  236. OFDM_data = OFDM_symbol(QAM)
  237. print("Number of OFDM carriers in frequency domain: ", len(OFDM_data))
  238. def IDFT(OFDM_data):
  239. return np.fft.ifft(OFDM_data)
  240. OFDM_time = IDFT(OFDM_data)
  241. print("Number of OFDM samples in time-domain before CP: ", len(OFDM_time))
  242. def addCP(OFDM_time):
  243. cp = OFDM_time[-CP:] # take the last CP samples ...
  244. return np.hstack([cp, OFDM_time]) # ... and add them to the beginning
  245. OFDM_withCP = addCP(OFDM_time)
  246. print("Number of OFDM samples in time domain with CP: ", len(OFDM_withCP))
  247. def channel(signal):
  248. convolved = np.convolve(signal, channelResponse)
  249. signal_power = np.mean(abs(convolved ** 2))
  250. sigma2 = signal_power * 10 ** (-SNRdb / 10) # calculate noise power based on signal power and SNR
  251. print("RX Signal power: %.4f. Noise power: %.4f" % (signal_power, sigma2))
  252. # Generate complex noise with given variance
  253. noise = np.sqrt(sigma2 / 2) * (np.random.randn(*convolved.shape) + 1j * np.random.randn(*convolved.shape))
  254. return convolved + noise
  255. OFDM_TX = OFDM_withCP
  256. OFDM_RX = channel(OFDM_TX)
  257. #OFDM_RX1 = optical_channel(OFDM_TX).numpy
  258. plt.figure(figsize=(8, 2))
  259. plt.plot(abs(OFDM_TX), label='TX signal')
  260. plt.plot(abs(OFDM_RX), label='RX signal')
  261. plt.legend(fontsize=10)
  262. plt.xlabel('Time')
  263. plt.ylabel('$|x(t)|$')
  264. plt.grid(True)
  265. def removeCP(signal):
  266. return signal[CP:(CP+K)]
  267. OFDM_RX_noCP = removeCP(OFDM_RX)
  268. def DFT(OFDM_RX):
  269. return np.fft.fft(OFDM_RX)
  270. OFDM_demod = DFT(OFDM_RX_noCP)
  271. def channelEstimate(OFDM_demod):
  272. pilots = OFDM_demod[pilotCarriers] # extract the pilot values from the RX signal
  273. Hest_at_pilots = pilots / pilotValue # divide by the transmitted pilot values
  274. # Perform interpolation between the pilot carriers to get an estimate
  275. # of the channel in the data carriers. Here, we interpolate absolute value and phase
  276. # separately
  277. Hest_abs = interpolate.interp1d(pilotCarriers, abs(Hest_at_pilots), kind='linear')(allCarriers)
  278. Hest_phase = interpolate.interp1d(pilotCarriers, np.angle(Hest_at_pilots), kind='linear')(allCarriers)
  279. Hest = Hest_abs * np.exp(1j * Hest_phase)
  280. plt.plot(allCarriers, abs(H_exact), label='Correct Channel')
  281. plt.stem(pilotCarriers, abs(Hest_at_pilots), label='Pilot estimates')
  282. plt.plot(allCarriers, abs(Hest), label='Estimated channel via interpolation')
  283. plt.grid(True)
  284. plt.xlabel('Carrier index')
  285. plt.ylabel('$|H(f)|$')
  286. plt.legend(fontsize=10)
  287. plt.ylim(0, 2)
  288. return Hest
  289. Hest = channelEstimate(OFDM_demod)
  290. def equalize(OFDM_demod, Hest):
  291. return OFDM_demod / Hest
  292. equalized_Hest = equalize(OFDM_demod, Hest)
  293. def get_payload(equalized):
  294. return equalized[dataCarriers]
  295. QAM_est = get_payload(equalized_Hest)
  296. plt.plot(QAM_est.real, QAM_est.imag, 'bo');
  297. def Demapping(QAM):
  298. # array of possible constellation points
  299. constellation = np.array([x for x in demapping_table.keys()])
  300. # calculate distance of each RX point to each possible point
  301. dists = abs(QAM.reshape((-1, 1)) - constellation.reshape((1, -1)))
  302. # for each element in QAM, choose the index in constellation
  303. # that belongs to the nearest constellation point
  304. const_index = dists.argmin(axis=1)
  305. # get back the real constellation point
  306. hardDecision = constellation[const_index]
  307. # transform the constellation point into the bit groups
  308. return np.vstack([demapping_table[C] for C in hardDecision]), hardDecision
  309. PS_est, hardDecision = Demapping(QAM_est)
  310. for qam, hard in zip(QAM_est, hardDecision):
  311. plt.plot([qam.real, hard.real], [qam.imag, hard.imag], 'b-o');
  312. plt.plot(hardDecision.real, hardDecision.imag, 'ro')
  313. def PS(bits):
  314. return bits.reshape((-1,))
  315. bits_est = PS(PS_est)
  316. print("Obtained Bit error rate: ", np.sum(abs(bits-bits_est))/len(bits))