# modulation_schemes.py

import math
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses
from scipy.signal import bessel, lfilter, filtfilt
import os

# Run on CPU only (hide all CUDA devices from TensorFlow)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix: identity rows aligned with the samples of the central symbol
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
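

# Illustrative check (assumed helper, not part of the original file): with
# messages_per_block=3 and samples_per_symbol=2, the selection matrix built in
# ExtractCentralMessage should pick out only the middle symbol's two samples
# from a flattened block.
def _example_extract_central_message():
    layer = ExtractCentralMessage(messages_per_block=3, samples_per_symbol=2)
    block = tf.constant([[0., 0., 5., 6., 0., 0.]])  # 3 symbols x 2 samples, flattened
    return layer(block)  # expected: [[5., 6.]]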


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white Gaussian noise channel model. The GaussianNoise layer is used so that fresh noise is drawn
        every time the call function is invoked, rather than identical noise being applied each time.
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        return self.noise_layer.call(inputs, training=True)


class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        This layer simulates the finite bandwidth of the hardware by means of a low pass filter. In addition,
        artefacts caused by quantization are modelled by additive white Gaussian noise of a given stddev.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Brick-wall low-pass mask in the frequency domain
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        noisy = self.noise_layer.call(real_t, training=True)
        return noisy
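

# Illustrative check (assumed helper, not part of the original file): a tone well
# above lpf_cutoff should be strongly attenuated by the brick-wall filter;
# q_stddev=0 disables the quantization-noise term so only the filtering is visible.
def _example_digitization_layer():
    fs, n = 336e9, 1024
    t = np.arange(n) / fs
    tone = np.sin(2 * np.pi * 100e9 * t)[np.newaxis, :].astype(np.float32)
    layer = DigitizationLayer(fs=fs, num_of_samples=n, lpf_cutoff=32e9, q_stddev=0)
    return layer(tone)  # output amplitude should be close to zero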


class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.03,
                 q_stddev=0.01):
        """
        A channel model that simulates chromatic dispersion, non-linear (square-law) photodiode detection, the finite
        bandwidth of the ADC/DAC, as well as additive white Gaussian noise in optical communication channels.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        # Chromatic dispersion transfer function: exp(j * (beta_2 / 2) * L * (2*pi*f)^2)
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law detection
        pd_out = tf.square(tf.abs(disp_t))
        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Adding photodiode receiver noise
        rx_signal = self.noise_layer.call(real_val, training=True)
        # ADC LPF and noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
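

# Illustrative check (assumed helper, not part of the original file): pushing a
# rectangular 4-PAM block through the channel shows dispersion-induced pulse
# spreading, square-law detection and the receiver/quantization noise. The
# constants mirror those used under __main__.
def _example_optical_channel():
    fs, samples_per_symbol, messages_per_block = 336e9, 128, 9
    channel = OpticalChannel(fs=fs,
                             num_of_samples=messages_per_block * samples_per_symbol,
                             dispersion_factor=-21.7e-24,
                             fiber_length=50)
    symbols = np.random.randint(4, size=messages_per_block)
    block = np.repeat(symbols, samples_per_symbol)[np.newaxis, :].astype(np.float32)
    return channel(block)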


class ModulationModel(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block" consists
        of multiple "messages" to introduce memory into the simulation, as this is essential for modelling inter-symbol
        interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block
        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        """
        super(ModulationModel, self).__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper (forced to be odd so that a central message exists)
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
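        # NOTE (assumed stub): self.encoder and self.decoder are referenced in call(),
        # view_encoder() and view_sample_block() but are not defined in this file as
        # shown. The dense stacks below are a minimal stand-in in the spirit of the
        # IEEE 8433895 autoencoder, not the original architecture.
        self.encoder = tf.keras.Sequential([
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol, activation='sigmoid'),
        ])
        self.decoder = tf.keras.Sequential([
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax'),
        ])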

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates blocks of random messages, used for generating the test/train data. A block contains
        multiple messages to be transmitted consecutively in order to model ISI; the central message of each block is
        the label of interest for training.
        :param num_of_blocks: Number of blocks to generate
        :param return_vals: If True, the raw decimal values of the input sequence are also returned
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        # cat = [np.arange(self.cardinality)]
        # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        # out = enc.fit_transform(rand_int)
        out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
        t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
        mid_idx = int((self.messages_per_block - 1) / 2)
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx]
        return t_out_arr, out_arr

    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message, displayed as a figure with a subplot per
        message.
        """
        # Generate one-hot inputs for the encoder
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1
        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes them. In addition, the encoded output is passed through
        the digitization layer (without any quantization noise) so that the effect of the low-pass filtering is visible.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # One-hot encode the symbol values, as expected by the encoder (cf. view_encoder), then encode and flatten
        inp_one_hot = tf.one_hot(inp, depth=self.cardinality)
        enc = self.encoder(inp_one_hot)
        flat_enc = layers.Flatten()(enc)
        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def plot(self, inputs, signal):
        # Note: relies on the module-level SAMPLING_FREQUENCY constant defined under __main__
        num_samples = self.messages_per_block * self.samples_per_symbol
        t = np.arange(num_samples) / (SAMPLING_FREQUENCY * 1e-9)  # time axis in ns
        plt.figure()
        plt.plot(t, signal.flatten()[:num_samples], label='Received Signal')
        # Square the transmitted amplitude so it is comparable with the square-law detected power
        inputs = np.square(inputs)
        plt.plot(t, inputs.flatten()[:num_samples], label='Transmitted Signal')
        plt.xlabel('Time (ns)')
        plt.ylabel('Power')
        plt.title('Effect of Chromatic Dispersion and Noise on Generated Signal')
        plt.legend()
        plt.show()

    def demodulate(self, validation, outputs):
        # Note: relies on the module-level SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL constants defined under __main__
        symbol_rate = SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL
        # Low-pass filter the received power, then undo the square-law detection
        b, a = bessel(5, 8 * symbol_rate / (SAMPLING_FREQUENCY / 2), btype='low', analog=False)
        outputs_filt = filtfilt(b, a, outputs)
        demodulated = np.sqrt(outputs_filt)
        # Average the samples within each symbol period and threshold to the nearest 4-PAM level
        average = np.mean(demodulated.reshape(-1, SAMPLES_PER_SYMBOL), axis=1)
        validation = validation.flatten()
        decisions = []
        for symbol in average:
            if symbol <= 0.5 or np.isnan(symbol):
                decisions.append(0)
            elif symbol <= 1.5:
                decisions.append(1)
            elif symbol <= 2.5:
                decisions.append(2)
            else:
                decisions.append(3)
        decisions = np.array(decisions)
        # Symbol error rate (reported as "BER" in the plots)
        error = np.count_nonzero(validation != decisions)
        return error / len(validation)

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
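

# Minimal end-to-end training sketch (assumed helper, not part of the original
# script): one-hot encodes the generated symbol blocks and fits the autoencoder
# with categorical cross-entropy against the central message of each block. Since
# the encoder/decoder above are themselves assumed stubs, treat this as
# illustrative only.
def _example_autoencoder_training(model, num_of_blocks=100, epochs=5):
    _, labels = model.generate_random_inputs(num_of_blocks=num_of_blocks)
    one_hot_blocks = tf.one_hot(labels, depth=model.cardinality)  # (blocks, N, M)
    mid_idx = (model.messages_per_block - 1) // 2
    central_labels = one_hot_blocks[:, mid_idx, :]  # (blocks, M)
    model.compile(optimizer='adam', loss=losses.CategoricalCrossentropy())
    model.fit(one_hot_blocks, central_labels, epochs=epochs, batch_size=32)
    return model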


def plot_output_graphs():
    # Uses the module-level modulation_scheme, optical_channel and size defined under __main__
    inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
    outputs = optical_channel(inputs).numpy()
    b, a = bessel(3, 8 * (SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL) / (SAMPLING_FREQUENCY / 2), btype='low', analog=False)
    outputs_filt = filtfilt(b, a, outputs)
    inp = np.sqrt(inputs)
    noisy = np.sqrt(outputs)
    demodulated = np.sqrt(outputs_filt)
    modulation_scheme.plot(inp, noisy)
    modulation_scheme.plot(inp, demodulated)
    # ber.append(modulation_scheme.demodulate(validation, outputs))


def plot_fibre_length_vs_ber():
    lengths = np.linspace(5, 70, 1000)
    ber = []
    for length in lengths:
        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                         dispersion_factor=DISPERSION_FACTOR,
                                         fiber_length=length)
        inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
        outputs = optical_channel(inputs).numpy()
        ber.append(modulation_scheme.demodulate(validation, outputs))
    plt.figure()
    plt.semilogy(lengths, ber, label='BER')
    plt.xlabel('Fibre Length (km)')
    plt.ylabel('BER')
    plt.title('Effect of Fibre Length on BER of 4PAM System')
    # Show the major grid lines with dark grey lines
    plt.grid(True, which='major', color='#666666', linestyle='-')
    # Show the minor grid lines with very faint and almost transparent grey lines
    plt.minorticks_on()
    plt.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)
    plt.legend()
    plt.show()


if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 4
    SAMPLES_PER_SYMBOL = 128
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    modulation_scheme = ModulationModel(cardinality=CARDINALITY,
                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
                                        messages_per_block=MESSAGES_PER_BLOCK,
                                        channel=optical_channel)
    size = 1000
    # inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
    # outputs = optical_channel(inputs).numpy()
    # decisions = modulation_scheme.demodulate(validation, outputs)
    # plot_output_graphs()
    plot_fibre_length_vs_ber()
    # modulation_scheme.plot(inputs, outputs, size)
    print("done")