# end_to_end.py

import math

import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.
        :param messages_per_block: Total number of messages in a transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix with an identity block aligned to the central symbol
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
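
# A minimal sketch of how the selection matrix above works (values chosen purely
# for illustration): with messages_per_block=3 and samples_per_symbol=2, W is a
# 6x2 matrix holding an identity in rows 2:4, so multiplying a flattened block
# by W picks out exactly the central symbol's samples:
#
#     layer = ExtractCentralMessage(messages_per_block=3, samples_per_symbol=2)
#     x = tf.constant([[0., 1., 2., 3., 4., 5.]])
#     layer(x)  # -> [[2., 3.]]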


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white Gaussian noise channel model. The GaussianNoise layer is used
        so that fresh noise is drawn on every invocation, rather than identical noise
        being applied every time the call function is called.
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. the TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True ensures the noise is applied outside of fit() as well
        return self.noise_layer(inputs, training=True)
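
# A quick sanity check of the channel (a sketch; the measured value varies from
# run to run because fresh noise is drawn on every call):
#
#     ch = AwgnChannel(rx_stddev=0.1)
#     y = ch(tf.zeros((1, 100000)))
#     float(tf.math.reduce_std(y))  # -> approximately 0.1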


class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        This layer simulates the finite bandwidth of the hardware by means of a low-pass
        filter. In addition, artefacts caused by quantization are modelled by the
        addition of white Gaussian noise of a given stddev.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of the LPF modelling the finite bandwidth of the ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at the ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Brick-wall low-pass filter mask in the frequency domain
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.ones(freq.shape)
        temp[np.abs(freq) > lpf_cutoff] = 0
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Apply the ideal LPF by zeroing frequency bins above the cutoff
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        # Model quantization artefacts as additive white Gaussian noise
        noisy = self.noise_layer(real_t, training=True)
        return noisy
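
# With the values used in __main__ below (fs = 336 GHz, lpf_cutoff = 32 GHz),
# the mask above keeps only the bins with |f| <= 32 GHz, i.e. roughly
# 2 * 32 / 336 ≈ 19% of the spectrum, which is why the filtered waveform in
# view_sample_block() looks visibly smoother than the raw encoder output.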


class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode
        detection, the finite bandwidth of the ADC/DAC, and additive white Gaussian
        noise in optical communication channels.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of the LPF modelling the finite bandwidth of the ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. the TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at the ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        # Chromatic dispersion transfer function H(f) = exp(0.5j * beta_2 * L * (2*pi*f)^2)
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and quantization noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion, applied in the frequency domain
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law photodiode detection
        pd_out = tf.square(tf.abs(disp_t))
        # Cast back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Add photodiode receiver noise
        rx_signal = self.noise_layer(real_val, training=True)
        # ADC LPF and quantization noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
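
# Note that the dispersion multiplier above is all-pass: |H(f)| = 1 in every
# frequency bin, so dispersion alone only rotates the phase of the field. It is
# the square-law detection that turns this phase distortion into amplitude
# inter-symbol interference, which is why blocks of consecutive messages are
# simulated rather than isolated symbols. A quick check (sketch, using the
# __main__ constants):
#
#     ch = OpticalChannel(fs=336e9, num_of_samples=216,
#                         dispersion_factor=-21.7e-24, fiber_length=50)
#     tf.abs(ch.multiplier)  # -> all ones, up to numerical precision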


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        The autoencoder that aims to find an encoding of the input messages. Note that a
        "block" consists of multiple "messages" in order to introduce memory into the
        simulation, which is essential for modelling inter-symbol interference. The
        autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of distinct messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper; forced odd so the block has a well-defined central message
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ])
        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])
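
    # The final ReLU(max_value=1.0) clamps the transmit waveform to [0, 1], which
    # is consistent with an intensity-modulated optical link driven by a
    # non-negative, power-limited modulator. The decoder mirrors the encoder and
    # ends in a softmax over the cardinality (M) possible messages.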

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is used to
        generate the test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be
        transmitted consecutively to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If True, the raw decimal values of the input sequence will also be returned
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        cat = [np.arange(self.cardinality)]
        # Note: for scikit-learn >= 1.2 the keyword is sparse_output rather than sparse
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]
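
    # Shape sketch, assuming the __main__ constants (cardinality=32,
    # messages_per_block=9):
    #
    #     X, y = ae_model.generate_random_inputs(num_of_blocks=1000)
    #     X.shape  # -> (1000, 9, 32)   one-hot blocks
    #     y.shape  # -> (1000, 32)      central-message labels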

    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration of the loss function,
        optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test)
                 )
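
    # Note: the loss above is kept as binary cross-entropy, as in the original.
    # With one-hot labels and a softmax output, losses.CategoricalCrossentropy()
    # would be the more conventional choice and is a drop-in replacement here.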

    def view_encoder(self):
        """
        A method that plots the learnt encoding of each distinct message. This is
        displayed as a figure with a subplot for each symbol.
        """
        # Generate encoder inputs: one block per message, with the message of
        # interest one-hot encoded in the central slot
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        for idx, msg in enumerate(messages):
            msg[mid_idx, idx] = 1
        # Pass the input through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(sym_idx))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes it. The encoded output
        is also passed through a DigitizationLayer with zero quantization noise so that
        only the low-pass filtering is applied.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        # Instantiate the LPF layer (assumes the channel layer exposes a sampling frequency)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF, marking symbol boundaries
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        # Full forward pass: encode the block, transmit it through the channel,
        # then decode the recovered central message
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel)
    ae_model.train(num_of_blocks=1e6, batch_size=100)
    ae_model.view_encoder()
    ae_model.view_sample_block()
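
# For reference, the constants above imply a symbol rate of
# SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL = 336e9 / 24 = 14 GBd, and with
# log2(CARDINALITY) = 5 bits per message this corresponds to a gross rate of
# 70 Gb/s over the 50 km of simulated fiber.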