# modulation_schemes.py (13 KB)

  1. import math
  2. import tensorflow as tf
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.preprocessing import OneHotEncoder
  6. from tensorflow.keras import layers, losses
  7. import os
  # Hide every CUDA device from TensorFlow so the simulation runs on the CPU only.
  os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
  class ExtractCentralMessage(layers.Layer):
      """
      A keras layer that extracts the central message (symbol) of a flattened block.

      The input is multiplied by a fixed 0/1 selection matrix of shape
      (messages_per_block * samples_per_symbol, samples_per_symbol) whose only non-zero
      entries are an identity sub-matrix at the rows of the central message. The input's
      last dimension must therefore equal messages_per_block * samples_per_symbol, and the
      output's last dimension is samples_per_symbol.
      """

      def __init__(self, messages_per_block, samples_per_symbol, **kwargs):
          """
          :param messages_per_block: Total number of messages in transmission block
          :param samples_per_symbol: Number of samples per transmitted symbol
          :param kwargs: Forwarded to keras.layers.Layer (e.g. name, dtype)
          """
          super(ExtractCentralMessage, self).__init__(**kwargs)
          # Integer arithmetic keeps the selected window aligned to a symbol boundary.
          # For an odd block this is exactly the middle message; for an even block the
          # lower-middle message is selected (instead of a window straddling two symbols).
          mid_idx = (messages_per_block - 1) // 2
          begin = mid_idx * samples_per_symbol
          end = begin + samples_per_symbol
          temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
          temp_w[begin:end, :] = np.identity(samples_per_symbol)
          self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

      def call(self, inputs, **kwargs):
          """Return only the central message's samples, i.e. ``inputs @ self.w``."""
          return tf.matmul(inputs, self.w)
  class AwgnChannel(layers.Layer):
      """
      An additive white Gaussian noise (AWGN) channel model.

      A keras GaussianNoise sub-layer is used (instead of a noise tensor drawn once) so that
      identical noise is not applied every time the layer is called.
      """

      def __init__(self, rx_stddev=0.1, **kwargs):
          """
          :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
          :param kwargs: Forwarded to keras.layers.Layer (e.g. name, dtype)
          """
          super(AwgnChannel, self).__init__(**kwargs)
          self.noise_layer = layers.GaussianNoise(rx_stddev)

      def call(self, inputs, **kwargs):
          # training=True: GaussianNoise is otherwise a no-op outside training, but the
          # channel must always be noisy. The sub-layer is invoked through __call__ (not
          # .call) so keras builds and tracks it as usual.
          return self.noise_layer(inputs, training=True)
  class DigitizationLayer(layers.Layer):
      """
      Models the ADC/DAC. An ideal (brick-wall) low-pass filter simulates the finite
      bandwidth of the hardware. Artefacts caused by quantization are modelled by adding
      white Gaussian noise of a given standard deviation.
      """

      def __init__(self,
                   fs,
                   num_of_samples,
                   lpf_cutoff=32e9,
                   q_stddev=0.1,
                   **kwargs):
          """
          :param fs: Sampling frequency of the simulation in Hz
          :param num_of_samples: Total number of samples in the input (length of its last axis)
          :param lpf_cutoff: Cutoff frequency in Hz of the LPF modelling the finite ADC/DAC bandwidth
          :param q_stddev: Standard deviation of quantization noise at ADC/DAC
          :param kwargs: Forwarded to keras.layers.Layer (e.g. name, dtype)
          """
          super(DigitizationLayer, self).__init__(**kwargs)
          self.noise_layer = layers.GaussianNoise(q_stddev)
          # Frequency mask in numpy FFT bin order: 1 where |f| <= cutoff, 0 above it.
          freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
          mask = (np.abs(freq) <= lpf_cutoff).astype(np.float64)
          self.lpf_multiplier = tf.convert_to_tensor(mask, dtype=tf.complex64)

      def call(self, inputs, **kwargs):
          complex_in = tf.cast(inputs, dtype=tf.complex64)
          # tf.signal.fft transforms the innermost axis; the mask broadcasts over leading axes.
          val_f = tf.signal.fft(complex_in)
          filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
          filtered_t = tf.signal.ifft(filtered_f)
          # Keep the real part explicitly (the same value tf.cast(complex -> float32) yields).
          real_t = tf.math.real(filtered_t)
          # training=True so quantization noise is applied at inference time as well.
          return self.noise_layer(real_t, training=True)
  class OpticalChannel(layers.Layer):
      """
      A channel model that simulates chromatic dispersion, non-linear (square-law)
      photodiode detection, finite bandwidth of ADC/DAC, and additive white Gaussian
      noise in optical communication channels.

      Processing chain in call():
      DAC (LPF + quantization noise) -> dispersion -> |x|^2 -> receiver AWGN
      -> ADC (LPF + quantization noise).
      """

      def __init__(self,
                   fs,
                   num_of_samples,
                   dispersion_factor,
                   fiber_length,
                   lpf_cutoff=32e9,
                   rx_stddev=0.01,
                   q_stddev=0.01,
                   **kwargs):
          """
          :param fs: Sampling frequency of the simulation in Hz
          :param num_of_samples: Total number of samples in the input
          :param dispersion_factor: Dispersion factor in s^2/km
          :param fiber_length: Length of fiber to model in km
          :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
          :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
          :param q_stddev: Standard deviation of quantization noise at ADC/DAC
          :param kwargs: Forwarded to keras.layers.Layer (e.g. name, dtype)
          """
          super(OpticalChannel, self).__init__(**kwargs)
          self.noise_layer = layers.GaussianNoise(rx_stddev)
          # A single instance serves both the DAC and the ADC stage in call(), so both share
          # the same cutoff and quantization-noise level.
          self.digitization_layer = DigitizationLayer(fs=fs,
                                                      num_of_samples=num_of_samples,
                                                      lpf_cutoff=lpf_cutoff,
                                                      q_stddev=q_stddev)
          # Not used by call(); kept so the attribute remains available to existing code.
          self.flatten_layer = layers.Flatten()
          self.fs = fs
          self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
          # Dispersion transfer function exp(0.5j * D * L * (2*pi*f)^2). The exponent is purely
          # imaginary, so this is a phase-only (all-pass) filter.
          self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

      def call(self, inputs, **kwargs):
          # DAC LPF and quantization noise
          dac_out = self.digitization_layer(inputs)
          # Chromatic dispersion, applied in the frequency domain at double precision
          val_f = tf.signal.fft(tf.cast(dac_out, dtype=tf.complex128))
          disp_t = tf.signal.ifft(tf.math.multiply(val_f, self.multiplier))
          # Square-law detection: |x|^2 is real-valued (float64); cast back to float32
          pd_out = tf.square(tf.abs(disp_t))
          real_val = tf.cast(pd_out, dtype=tf.float32)
          # Photodiode receiver noise; sub-layer invoked via __call__, always active
          rx_signal = self.noise_layer(real_val, training=True)
          # ADC LPF and quantization noise
          return self.digitization_layer(rx_signal)
  class ModulationModel(tf.keras.Model):
      """
      The autoencoder that aims to find an encoding of the input messages.

      NOTE(review): call(), view_encoder() and view_sample_block() use self.encoder (and
      call() also uses self.decoder). Neither is created in __init__, so they must be
      assigned elsewhere (e.g. by a subclass) before those methods are used.
      """

      def __init__(self,
                   cardinality,
                   samples_per_symbol,
                   messages_per_block,
                   channel):
          """
          A "block" consists of multiple "messages" to introduce memory into the simulation,
          as this is essential for modelling inter-symbol interference. The autoencoder
          architecture was heavily influenced by IEEE 8433895.
          :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
          :param samples_per_symbol: Number of samples per transmitted symbol
          :param messages_per_block: Total number of messages in transmission block (rounded up to the next odd number)
          :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
          :raises TypeError: if channel is not a keras Layer
          """
          super(ModulationModel, self).__init__()
          # Labelled M in paper
          self.cardinality = cardinality
          # Labelled n in paper
          self.samples_per_symbol = samples_per_symbol
          # Labelled N in paper; forced odd so that a single central message exists
          if messages_per_block % 2 == 0:
              messages_per_block += 1
          self.messages_per_block = messages_per_block
          if not isinstance(channel, layers.Layer):
              raise TypeError("Channel must be a subclass of keras.layers.layer!")
          # Channel model: flatten the block -> channel -> keep only the central message's samples
          self.channel = tf.keras.Sequential([
              layers.Flatten(),
              channel,
              ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol),
          ])

      def _central_msg_idx(self):
          """Index of the central message in a block (messages_per_block is odd after __init__)."""
          return (self.messages_per_block - 1) // 2

      def _sample_times(self, num_samples):
          """Sample indices 0..num_samples-1, converted to seconds when the channel is an OpticalChannel."""
          t = np.arange(num_samples)
          channel_layer = self.channel.layers[1]
          if isinstance(channel_layer, OpticalChannel):
              t = t / channel_layer.fs
          return t

      def generate_random_inputs(self, num_of_blocks, return_vals=False):
          """
          Generate random blocks of messages, used for the test/train data.

          Each message is an integer drawn uniformly from [0, cardinality). The central
          message of every block is returned as the training label.
          :param num_of_blocks: Number of blocks to generate. A block contains multiple messages transmitted
              consecutively to model ISI.
          :param return_vals: If true, return (values, messages, labels) where values has shape
              (num_of_blocks, messages_per_block, 1), messages has shape (num_of_blocks, messages_per_block)
              and labels has shape (num_of_blocks,).
          :return: By default (samples, labels). samples has shape
              (num_of_blocks, messages_per_block * samples_per_symbol), with each message repeated
              samples_per_symbol times. labels has shape (num_of_blocks,).
          """
          rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
          out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
          # out_arr is 2-D, so the central column is selected with two indices.
          labels = out_arr[:, self._central_msg_idx()]
          if return_vals:
              out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
              return out_val, out_arr, labels
          t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
          return t_out_arr, labels

      def view_encoder(self):
          '''
          Plot the learnt encoding of each distinct message, with one subplot per message.

          Message k is placed one-hot at the centre of an otherwise all-zero block. The block
          is passed through self.encoder, and the encoded central symbol is plotted.
          '''
          mid_idx = self._central_msg_idx()
          # One block per message: one-hot vector for symbol k at the central position.
          messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
          sym = np.arange(self.cardinality)
          messages[sym, mid_idx, sym] = 1
          # Pass input through encoder and select middle messages
          encoded = self.encoder(messages)
          enc_messages = encoded[:, mid_idx, :]
          # Grid layout: num_x is the smallest power of two >= sqrt(cardinality). The row count is
          # rounded up so every symbol gets an axis when cardinality is not a power of two.
          num_x = 1
          while num_x < self.cardinality ** 0.5:
              num_x *= 2
          num_y = math.ceil(self.cardinality / num_x)
          # squeeze=False keeps axs 2-D even for a single-row grid (e.g. cardinality == 2).
          _, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y), squeeze=False)
          t = self._sample_times(self.samples_per_symbol)
          for sym_idx, ax in enumerate(axs.flat):
              if sym_idx >= self.cardinality:
                  # Spare grid slot
                  ax.set_visible(False)
                  continue
              ax.plot(t, enc_messages[sym_idx], 'x')
              ax.set_title('Symbol {}'.format(sym_idx))
              ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
              ax.label_outer()
          plt.show()

      def view_sample_block(self):
          '''
          Generate a random block of messages, encode it, and plot it before and after low-pass filtering.

          The filter is a DigitizationLayer with zero quantization noise, so only the LPF effect is shown.
          NOTE(review): reads .fs from the channel layer, so the channel must expose one (OpticalChannel does).
          '''
          # Generate a random block of messages
          val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
          # Encode and flatten the messages
          enc = self.encoder(inp)
          flat_enc = layers.Flatten()(enc)
          num_samples = self.messages_per_block * self.samples_per_symbol
          # Noise-free LPF at the channel's sampling rate
          lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                  num_of_samples=num_samples,
                                  q_stddev=0)
          lpf_out = lpf(flat_enc)
          t = self._sample_times(num_samples)
          # Plot the concatenated symbols before and after LPF
          plt.figure(figsize=(2 * self.messages_per_block, 6))
          # Vertical lines mark the boundaries between consecutive symbols
          for i in range(1, self.messages_per_block):
              plt.axvline(x=t[i * self.samples_per_symbol], color='black')
          plt.plot(t, flat_enc.numpy().T, 'x')
          plt.plot(t, lpf_out.numpy().T)
          plt.ylim((0, 1))
          plt.xlim((t.min(), t.max()))
          plt.title(str(val[0, :, 0]))
          plt.show()

      def plot(self, inputs, signal, size):
          """
          Plot up to three blocks' worth of samples of `signal` and `inputs` on one figure.

          :param inputs: Array-like channel input; flattened before plotting
          :param signal: Array-like channel output; flattened before plotting
          :param size: Unused; kept for interface compatibility
          """
          limit = self.messages_per_block * self.samples_per_symbol * 3
          # Slicing before building the time axis avoids a length mismatch with < 3 blocks.
          sig = np.asarray(signal).flatten()[:limit]
          inp = np.asarray(inputs).flatten()[:limit]
          plt.figure()
          plt.plot(np.arange(len(sig)), sig)
          plt.plot(np.arange(len(inp)), inp)
          plt.show()

      def call(self, inputs, training=None, mask=None):
          """Forward pass encoder -> channel -> decoder (encoder/decoder must be set; see class docstring)."""
          tx = self.encoder(inputs)
          rx = self.channel(tx)
          outputs = self.decoder(rx)
          return outputs
  if __name__ == '__main__':
      SAMPLING_FREQUENCY = 336e9  # Hz
      CARDINALITY = 4
      SAMPLES_PER_SYMBOL = 24
      MESSAGES_PER_BLOCK = 9
      # ModulationModel rounds an even block length up to the next odd number. Do the same here
      # so the channel's FFT length matches the blocks the model generates.
      if MESSAGES_PER_BLOCK % 2 == 0:
          MESSAGES_PER_BLOCK += 1
      DISPERSION_FACTOR = -21.7 * 1e-24  # -21.7 ps^2/km expressed in s^2/km
      FIBER_LENGTH = 50  # km
      optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                       num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                       dispersion_factor=DISPERSION_FACTOR,
                                       fiber_length=FIBER_LENGTH)
      modulation_scheme = ModulationModel(cardinality=CARDINALITY,
                                          samples_per_symbol=SAMPLES_PER_SYMBOL,
                                          messages_per_block=MESSAGES_PER_BLOCK,
                                          channel=optical_channel)
      size = 1000
      inputs, labels = modulation_scheme.generate_random_inputs(num_of_blocks=size)
      # Send the raw repeated-symbol waveform straight through the channel (no encoder)
      outputs = optical_channel(inputs).numpy()
      modulation_scheme.plot(inputs, outputs, size)
      print("done")
      # The visualisations below require an encoder to be set on the model:
      # modulation_scheme.view_encoder()
      # modulation_scheme.view_sample_block()