end_to_end.py

import math
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.
        :param messages_per_block: Total number of messages in a transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Constant selection matrix: an identity block aligned with the central
        # symbol of the flattened input, zeros everywhere else.
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
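

# A minimal usage sketch (illustrative values, not part of the original file):
# with 3 messages of 2 samples each, multiplying a flattened block by the
# selection matrix returns the samples of the middle symbol only.
def _demo_extract_central_message():
    layer = ExtractCentralMessage(messages_per_block=3, samples_per_symbol=2)
    block = tf.constant([[0., 1., 2., 3., 4., 5.]])  # flattened block, shape (1, 3 * 2)
    print(layer(block).numpy())  # expected: [[2. 3.]]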


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white Gaussian noise channel model. A GaussianNoise layer is
        used so that a fresh noise realization is drawn on every call, rather
        than the same noise being applied each time.
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True ensures the noise is also applied at inference time.
        return self.noise_layer(inputs, training=True)
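

# A quick sanity sketch (illustrative values): the empirical standard deviation
# of the noise added to an all-zero input should be close to rx_stddev.
def _demo_awgn_channel():
    channel = AwgnChannel(rx_stddev=0.1)
    noisy = channel(tf.zeros((1, 10000)))
    print(float(tf.math.reduce_std(noisy)))  # expected: roughly 0.1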


class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        This layer simulates the finite bandwidth of the hardware by means of a
        low-pass filter. In addition, artefacts caused by quantization are
        modelled by the addition of white Gaussian noise of a given stddev.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Brick-wall LPF: a frequency-domain mask that zeroes every bin above
        # the cutoff frequency.
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = (np.abs(freq) <= lpf_cutoff).astype(np.float64)
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Filter in the frequency domain, then return to the time domain.
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        # The imaginary part is numerical residue, so keep only the real part.
        real_t = tf.math.real(filtered_t)
        # Quantization noise
        return self.noise_layer(real_t, training=True)
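

# A minimal filtering sketch (illustrative parameters): a 10 GHz tone should
# pass through the default 32 GHz brick-wall LPF almost unchanged, while a
# 100 GHz tone should be largely removed. q_stddev=0 isolates the filtering.
def _demo_digitization_layer():
    fs, n = 336e9, 512
    t = np.arange(n) / fs
    lpf = DigitizationLayer(fs=fs, num_of_samples=n, q_stddev=0)
    in_band = tf.constant(np.sin(2 * np.pi * 10e9 * t)[None, :], dtype=tf.float32)
    out_band = tf.constant(np.sin(2 * np.pi * 100e9 * t)[None, :], dtype=tf.float32)
    print(float(tf.math.reduce_std(lpf(in_band))))   # expected: ~0.7 (tone kept)
    print(float(tf.math.reduce_std(lpf(out_band))))  # expected: ~0 (tone removed)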


class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode detection and the
        finite bandwidth of the ADC/DAC, as well as additive white Gaussian noise in optical
        communication channels.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.fs = fs
        # All-pass dispersion operator exp(j/2 * beta_2 * L * omega^2), applied
        # in the frequency domain.
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and quantization noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law detection (the photodiode measures optical power)
        pd_out = tf.square(tf.abs(disp_t))
        # Cast back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Photodiode receiver noise
        rx_signal = self.noise_layer(real_val, training=True)
        # ADC LPF and quantization noise
        return self.digitization_layer(rx_signal)
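

# A shape sketch (illustrative parameters): the channel maps a batch of
# flattened blocks to an output of the same shape. With both noise sources set
# to zero, the remaining distortion comes from dispersion, square-law detection
# and low-pass filtering alone.
def _demo_optical_channel():
    n = 216  # e.g. 9 messages * 24 samples per symbol
    channel = OpticalChannel(fs=336e9, num_of_samples=n,
                             dispersion_factor=-21.7e-24, fiber_length=50,
                             rx_stddev=0, q_stddev=0)
    x = tf.random.uniform((4, n))
    print(channel(x).shape)  # expected: (4, 216)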


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        The autoencoder that aims to find an encoding of the input messages. Note that a "block"
        consists of multiple "messages" in order to introduce memory into the simulation, which is
        essential for modelling inter-symbol interference. The autoencoder architecture was heavily
        influenced by IEEE 8433895.
        :param cardinality: Number of different messages. Chosen such that each message encodes
            log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an
            implemented forward pass
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper. Forced to be odd so that a unique central message exists.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            # Clip the transmitted samples to [0, 1]
            layers.ReLU(max_value=1.0)
        ])
        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is utilized for generating
        the test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be
            transmitted consecutively to model ISI. The central message in a block is returned as the
            label for training.
        :param return_vals: If True, the raw decimal values of the input sequence will also be returned
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        cat = [np.arange(self.cardinality)]
        # Note: in scikit-learn >= 1.2 the `sparse` argument is named `sparse_output`.
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration of the loss function, optimizer etc.
        can be made in here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param batch_size: Number of samples to consider on each update iteration of the optimization
            algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for
            training
        :param lr: The learning rate of the optimizer. Controls the size of each update step.
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     run_eagerly=False)
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test))

    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message. This is displayed as a
        figure with a subplot per message.
        """
        # Generate inputs for the encoder: one block per message, with the
        # corresponding one-hot message placed in the central slot
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        for idx, msg in enumerate(messages):
            msg[mid_idx, idx] = 1
        # Pass the inputs through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout (power-of-two width, roughly square)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(sym_idx))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes it. In addition, the encoder output is
        passed through a DigitizationLayer without any quantization noise to show the effect of the
        low-pass filtering alone.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        # Instantiate the LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        # Full forward pass: encoder -> channel -> decoder
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        return self.decoder(rx)
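

# A minimal forward-pass sketch (illustrative sizes, AWGN channel for speed):
# for a batch of one-hot blocks the model should return one probability vector
# per block, summing to 1 across the cardinality axis.
def _demo_autoencoder_forward_pass():
    model = EndToEndAutoencoder(cardinality=4, samples_per_symbol=8,
                                messages_per_block=3, channel=AwgnChannel())
    blocks, labels = model.generate_random_inputs(num_of_blocks=5)
    probs = model(blocks)
    print(probs.shape)                            # expected: (5, 4)
    print(tf.reduce_sum(probs, axis=-1).numpy())  # expected: ~[1. 1. 1. 1. 1.]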


if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24  # -21.7 ps^2/km expressed in s^2/km
    FIBER_LENGTH = 50  # km
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel)
    ae_model.train(num_of_blocks=1e6, batch_size=100)
    ae_model.view_encoder()
    ae_model.view_sample_block()
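    # A possible follow-up step (sketch, not in the original script): estimate
    # the symbol error rate of the trained model on freshly generated blocks.
    #   blocks, labels = ae_model.generate_random_inputs(10000)
    #   preds = tf.argmax(ae_model(blocks), axis=-1)
    #   truth = tf.argmax(labels, axis=-1)
    #   ser = float(tf.math.count_nonzero(preds - truth)) / 10000
    #   print('Estimated SER: {}'.format(ser))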