end_to_end.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. import math
  2. import keras
  3. import tensorflow as tf
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from matplotlib import collections as matcoll
  7. from sklearn.preprocessing import OneHotEncoder
  8. from keras import layers, losses
  9. class ExtractCentralMessage(layers.Layer):
  10. def __init__(self, messages_per_block, samples_per_symbol):
  11. """
  12. :param messages_per_block: Total number of messages in transmission block
  13. :param samples_per_symbol: Number of samples per transmitted symbol
  14. """
  15. super(ExtractCentralMessage, self).__init__()
  16. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  17. i = np.identity(samples_per_symbol)
  18. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  19. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  20. temp_w[begin:end, :] = i
  21. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  22. def call(self, inputs, **kwargs):
  23. return tf.matmul(inputs, self.w)
  24. class AwgnChannel(layers.Layer):
  25. def __init__(self, rx_stddev=0.1):
  26. """
  27. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  28. """
  29. super(AwgnChannel, self).__init__()
  30. self.noise_layer = layers.GaussianNoise(rx_stddev)
  31. def call(self, inputs, **kwargs):
  32. return self.noise_layer.call(inputs, training=True)
  33. class DigitizationLayer(layers.Layer):
  34. def __init__(self,
  35. fs,
  36. num_of_samples,
  37. lpf_cutoff=32e9,
  38. q_stddev=0.1):
  39. """
  40. :param fs: Sampling frequency of the simulation in Hz
  41. :param num_of_samples: Total number of samples in the input
  42. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  43. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  44. """
  45. super(DigitizationLayer, self).__init__()
  46. self.noise_layer = layers.GaussianNoise(q_stddev)
  47. freq = np.fft.fftfreq(num_of_samples, d=1/fs)
  48. temp = np.ones(freq.shape)
  49. for idx, val in np.ndenumerate(freq):
  50. if np.abs(val) > lpf_cutoff:
  51. temp[idx] = 0
  52. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  53. def call(self, inputs, **kwargs):
  54. complex_in = tf.cast(inputs, dtype=tf.complex64)
  55. val_f = tf.signal.fft(complex_in)
  56. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  57. filtered_t = tf.signal.ifft(filtered_f)
  58. real_t = tf.cast(filtered_t, dtype=tf.float32)
  59. noisy = self.noise_layer.call(real_t, training=True)
  60. return noisy
  61. class OpticalChannel(layers.Layer):
  62. def __init__(self,
  63. fs,
  64. num_of_samples,
  65. dispersion_factor,
  66. fiber_length,
  67. lpf_cutoff=32e9,
  68. rx_stddev=0.01,
  69. q_stddev=0.01):
  70. """
  71. :param fs: Sampling frequency of the simulation in Hz
  72. :param num_of_samples: Total number of samples in the input
  73. :param dispersion_factor: Dispersion factor in s^2/km
  74. :param fiber_length: Length of fiber to model in km
  75. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  76. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  77. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  78. """
  79. super(OpticalChannel, self).__init__()
  80. self.noise_layer = layers.GaussianNoise(rx_stddev)
  81. self.digitization_layer = DigitizationLayer(fs=fs,
  82. num_of_samples=num_of_samples,
  83. lpf_cutoff=lpf_cutoff,
  84. q_stddev=q_stddev)
  85. self.flatten_layer = layers.Flatten()
  86. self.fs = fs
  87. self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
  88. self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
  89. def call(self, inputs, **kwargs):
  90. # DAC LPF and noise
  91. dac_out = self.digitization_layer(inputs)
  92. # Chromatic Dispersion
  93. complex_val = tf.cast(dac_out, dtype=tf.complex128)
  94. val_f = tf.signal.fft(complex_val)
  95. disp_f = tf.math.multiply(val_f, self.multiplier)
  96. disp_t = tf.signal.ifft(disp_f)
  97. # Squared-Law Detection
  98. pd_out = tf.square(tf.abs(disp_t))
  99. # Casting back to floatx
  100. real_val = tf.cast(pd_out, dtype=tf.float32)
  101. # Adding photo-diode receiver noise
  102. rx_signal = self.noise_layer.call(real_val, training=True)
  103. # ADC LPF and noise
  104. adc_out = self.digitization_layer(rx_signal)
  105. return adc_out
  106. class EndToEndAutoencoder(tf.keras.Model):
  107. def __init__(self,
  108. cardinality,
  109. samples_per_symbol,
  110. messages_per_block,
  111. channel):
  112. """
  113. :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
  114. :param samples_per_symbol: Number of samples per transmitted symbol
  115. :param messages_per_block: Total number of messages in transmission block
  116. :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
  117. """
  118. super(EndToEndAutoencoder, self).__init__()
  119. # Labelled M in paper
  120. self.cardinality = cardinality
  121. # Labelled n in paper
  122. self.samples_per_symbol = samples_per_symbol
  123. # Labelled N in paper
  124. if messages_per_block % 2 == 0:
  125. messages_per_block += 1
  126. self.messages_per_block = messages_per_block
  127. # Channel Model Layer
  128. if isinstance(channel, layers.Layer):
  129. self.channel = tf.keras.Sequential([
  130. layers.Flatten(),
  131. channel,
  132. ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
  133. ])
  134. else:
  135. raise TypeError("Channel must be a subclass of keras.layers.layer!")
  136. # Encoding Neural Network
  137. self.encoder = tf.keras.Sequential([
  138. layers.Input(shape=(self.messages_per_block, self.cardinality)),
  139. layers.Dense(2 * self.cardinality, activation='relu'),
  140. layers.Dense(2 * self.cardinality, activation='relu'),
  141. layers.Dense(self.samples_per_symbol),
  142. layers.ReLU(max_value=1.0)
  143. ])
  144. # Decoding Neural Network
  145. self.decoder = tf.keras.Sequential([
  146. layers.Dense(self.samples_per_symbol, activation='relu'),
  147. layers.Dense(2 * self.cardinality, activation='relu'),
  148. layers.Dense(2 * self.cardinality, activation='relu'),
  149. layers.Dense(self.cardinality, activation='softmax')
  150. ])
  151. def generate_random_inputs(self, num_of_blocks, return_vals=False):
  152. """
  153. :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
  154. consecutively to model ISI. The central message in a block is returned as the label for training.
  155. :param return_vals: If true, the raw decimal values of the input sequence will be returned
  156. """
  157. rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
  158. cat = [np.arange(self.cardinality)]
  159. enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
  160. out = enc.fit_transform(rand_int)
  161. out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
  162. mid_idx = int((self.messages_per_block-1)/2)
  163. if return_vals:
  164. out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
  165. return out_val, out_arr, out_arr[:, mid_idx, :]
  166. return out_arr, out_arr[:, mid_idx, :]
  167. def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  168. """
  169. :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
  170. :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
  171. :param train_size: Float less than 1 representing the proportion of the dataset to use for training
  172. :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
  173. """
  174. X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
  175. X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
  176. opt = keras.optimizers.Adam(learning_rate=lr)
  177. self.compile(optimizer=opt,
  178. loss=losses.BinaryCrossentropy(),
  179. metrics=['accuracy'],
  180. loss_weights=None,
  181. weighted_metrics=None,
  182. run_eagerly=False
  183. )
  184. self.fit(x=X_train,
  185. y=y_train,
  186. batch_size=batch_size,
  187. epochs=1,
  188. shuffle=True,
  189. validation_data=(X_test, y_test)
  190. )
  191. def view_encoder(self):
  192. # Generate inputs for encoder
  193. messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
  194. mid_idx = int((self.messages_per_block-1)/2)
  195. idx = 0
  196. for msg in messages:
  197. msg[mid_idx, idx] = 1
  198. idx += 1
  199. # Pass input through encoder and select middle messages
  200. encoded = self.encoder(messages)
  201. enc_messages = encoded[:, mid_idx, :]
  202. # Compute subplot grid layout
  203. i = 0
  204. while 2**i < self.cardinality**0.5:
  205. i += 1
  206. num_x = int(2**i)
  207. num_y = int(self.cardinality / num_x)
  208. # Plot all symbols
  209. fig, axs = plt.subplots(num_y, num_x, figsize=(2.5*num_x, 2*num_y))
  210. t = np.arange(self.samples_per_symbol)
  211. if isinstance(self.channel.layers[1], OpticalChannel):
  212. t = t/self.channel.layers[1].fs
  213. sym_idx = 0
  214. for y in range(num_y):
  215. for x in range(num_x):
  216. axs[y, x].plot(t, enc_messages[sym_idx], 'x')
  217. axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
  218. sym_idx += 1
  219. for ax in axs.flat:
  220. ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
  221. for ax in axs.flat:
  222. ax.label_outer()
  223. plt.show()
  224. pass
  225. def view_sample_block(self):
  226. # Generate a random block of messages
  227. val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
  228. # Encode and flatten the messages
  229. enc = self.encoder(inp)
  230. flat_enc = layers.Flatten()(enc)
  231. # Instantiate LPF layer
  232. lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
  233. num_of_samples=self.messages_per_block*self.samples_per_symbol,
  234. q_stddev=0)
  235. # Apply LPF
  236. lpf_out = lpf(flat_enc)
  237. # Time axis
  238. t = np.arange(self.messages_per_block*self.samples_per_symbol)
  239. if isinstance(self.channel.layers[1], OpticalChannel):
  240. t = t / self.channel.layers[1].fs
  241. # Plot the concatenated symbols before and after LPF
  242. plt.figure(figsize=(2*self.messages_per_block, 6))
  243. for i in range(1, self.messages_per_block):
  244. plt.axvline(x=t[i*self.samples_per_symbol], color='black')
  245. plt.plot(t, flat_enc.numpy().T, 'x')
  246. plt.plot(t, lpf_out.numpy().T)
  247. plt.ylim((0, 1))
  248. plt.xlim((t.min(), t.max()))
  249. plt.title(str(val[0, :, 0]))
  250. plt.show()
  251. pass
  252. def call(self, inputs, training=None, mask=None):
  253. tx = self.encoder(inputs)
  254. rx = self.channel(tx)
  255. outputs = self.decoder(rx)
  256. return outputs
  257. if __name__ == '__main__':
  258. SAMPLING_FREQUENCY = 336e9
  259. CARDINALITY = 32
  260. SAMPLES_PER_SYMBOL = 24
  261. MESSAGES_PER_BLOCK = 9
  262. DISPERSION_FACTOR = -21.7 * 1e-24
  263. FIBER_LENGTH = 50
  264. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  265. num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
  266. dispersion_factor=DISPERSION_FACTOR,
  267. fiber_length=FIBER_LENGTH)
  268. ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
  269. samples_per_symbol=SAMPLES_PER_SYMBOL,
  270. messages_per_block=MESSAGES_PER_BLOCK,
  271. channel=optical_channel)
  272. ae_model.train(num_of_blocks=1e6, batch_size=100)
  273. ae_model.view_encoder()
  274. ae_model.view_sample_block()
  275. pass