end_to_end.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. import math
  2. from tensorflow import keras
  3. import tensorflow as tf
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from matplotlib import collections as matcoll
  7. from sklearn.preprocessing import OneHotEncoder
  8. from tensorflow.keras import layers, losses
  class ExtractCentralMessage(layers.Layer):
      """Select the samples of the central message from a flattened transmission block.

      Along its last axis the input holds ``messages_per_block`` consecutive symbols of
      ``samples_per_symbol`` samples each; the layer returns the ``samples_per_symbol``
      samples of the middle symbol.
      """

      def __init__(self, messages_per_block, samples_per_symbol, **kwargs):
          """
          :param messages_per_block: Total number of messages in transmission block
          :param samples_per_symbol: Number of samples per transmitted symbol
          :param kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``)
          """
          super(ExtractCentralMessage, self).__init__(**kwargs)
          self.messages_per_block = messages_per_block
          self.samples_per_symbol = samples_per_symbol
          # Floor division yields the same bounds as int(n * (N -/+ 1) / 2) without a
          # float round-trip. The window is always samples_per_symbol wide and, for an
          # odd messages_per_block, coincides exactly with the middle symbol.
          self.begin = samples_per_symbol * (messages_per_block - 1) // 2
          self.end = samples_per_symbol * (messages_per_block + 1) // 2

      def call(self, inputs, **kwargs):
          """Return ``inputs[..., begin:end]``, i.e. the central symbol's samples.

          :raises ValueError: If the static size of the last axis is known and differs
              from ``messages_per_block * samples_per_symbol``.
          """
          expected = self.messages_per_block * self.samples_per_symbol
          width = inputs.shape[-1]
          if width is not None and width != expected:
              raise ValueError(
                  "ExtractCentralMessage expects a last dimension of {}, got {}".format(expected, width))
          # A plain slice instead of a matmul with a mostly-zero selection matrix: the same
          # values for finite inputs, without the N*n x n product, without the reduced
          # precision (TensorFloat-32) matmul TensorFlow may use on some GPUs, and without
          # NaN/Inf in neighbouring symbols leaking in through 0 * inf.
          return inputs[..., self.begin:self.end]
  class DigitizationLayer(layers.Layer):
      """Model a DAC/ADC stage as an ideal (brick-wall) low-pass filter plus Gaussian noise.

      The filter is applied with an FFT along the last axis, whose length is expected to be
      ``num_of_samples``. Noise is added on every call, including at inference time.
      """

      def __init__(self,
                   fs,
                   num_of_samples,
                   lpf_cutoff=32e9,
                   q_stddev=0.1,
                   **kwargs):
          """
          :param fs: Sampling frequency of the simulation in Hz
          :param num_of_samples: Total number of samples in the input
          :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
          :param q_stddev: Standard deviation of quantization noise at ADC/DAC
          :param kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``)
          """
          super(DigitizationLayer, self).__init__(**kwargs)
          self.noise_layer = layers.GaussianNoise(q_stddev)
          freq = np.fft.fftfreq(num_of_samples, d=1/fs)
          # Mask in FFT bin order: keep bins with |f| <= cutoff, zero every bin above it.
          mask = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
          self.lpf_multiplier = tf.convert_to_tensor(mask, dtype=tf.complex64)

      def call(self, inputs, **kwargs):
          """Low-pass filter ``inputs`` along the last axis and add Gaussian noise."""
          spectrum = tf.signal.fft(tf.cast(inputs, dtype=tf.complex64))
          filtered_t = tf.signal.ifft(self.lpf_multiplier * spectrum)
          # The mask depends only on |f|, so for real input the filtered spectrum stays
          # Hermitian and the imaginary part is rounding residue; keep the real part.
          real_t = tf.math.real(filtered_t)
          # training=True: the noise models hardware and must also apply at inference.
          return self.noise_layer(real_t, training=True)
  class OpticalChannel(layers.Layer):
      """Fiber channel model with chromatic dispersion and square-law detection.

      Forward pass: DAC (low-pass + noise) -> chromatic dispersion applied in the frequency
      domain -> square-law photodiode -> receiver Gaussian noise -> ADC (low-pass + noise).
      FFTs act on the last axis, whose length is expected to be ``num_of_samples``.
      """

      def __init__(self,
                   fs,
                   num_of_samples,
                   dispersion_factor,
                   fiber_length,
                   lpf_cutoff=32e9,
                   rx_stddev=0.01,
                   q_stddev=0.01,
                   **kwargs):
          """
          :param fs: Sampling frequency of the simulation in Hz
          :param num_of_samples: Total number of samples in the input
          :param dispersion_factor: Dispersion factor in s^2/km
          :param fiber_length: Length of fiber to model in km
          :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
          :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
          :param q_stddev: Standard deviation of quantization noise at ADC/DAC
          :param kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``)
          """
          super(OpticalChannel, self).__init__(**kwargs)
          self.noise_layer = layers.GaussianNoise(rx_stddev)
          # One digitization stage, reused for both the DAC and the ADC (identical settings).
          self.digitization_layer = DigitizationLayer(fs=fs,
                                                      num_of_samples=num_of_samples,
                                                      lpf_cutoff=lpf_cutoff,
                                                      q_stddev=q_stddev)
          self.fs = fs
          self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex128)
          # Dispersion transfer function exp(j/2 * D * L * (2*pi*f)^2), in FFT bin order.
          self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length
                                        * tf.math.square(2 * math.pi * self.freq))

      def call(self, inputs, **kwargs):
          """Propagate a batch of sampled waveforms through DAC, fiber, photodiode and ADC."""
          # DAC LPF and noise
          dac_out = self.digitization_layer(inputs)
          # Chromatic dispersion (computed in complex128)
          field_f = tf.signal.fft(tf.cast(dac_out, dtype=tf.complex128))
          field_t = tf.signal.ifft(field_f * self.multiplier)
          # Square-law detection: |z|^2 = re^2 + im^2. This avoids the sqrt/square round
          # trip and the non-differentiable point of |z| at z = 0.
          intensity = tf.square(tf.math.real(field_t)) + tf.square(tf.math.imag(field_t))
          # Photodiode receiver noise, always applied (models hardware, not regularization)
          rx_signal = self.noise_layer(tf.cast(intensity, dtype=tf.float32), training=True)
          # ADC LPF and noise
          return self.digitization_layer(rx_signal)
  class EndToEndAutoencoder(tf.keras.Model):
      """End-to-end autoencoder: encoder network -> channel layer -> decoder network.

      Each sample is a block of ``messages_per_block`` one-hot messages. The encoder maps every
      message to ``samples_per_symbol`` samples. The channel propagates the flattened block and
      keeps only the central symbol's samples. The decoder then outputs a softmax over the
      ``cardinality`` possible values of the central message.
      """

      def __init__(self,
                   cardinality,
                   samples_per_symbol,
                   messages_per_block,
                   channel):
          """
          :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
          :param samples_per_symbol: Number of samples per transmitted symbol
          :param messages_per_block: Total number of messages in transmission block. An even value is
              increased by one so that the block has a central message.
          :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
          :raises TypeError: If ``channel`` is not a ``keras.layers.Layer``.
          """
          super(EndToEndAutoencoder, self).__init__()
          if not isinstance(channel, layers.Layer):
              raise TypeError("Channel must be a subclass of keras.layers.layer!")
          # Labelled M in paper
          self.cardinality = cardinality
          # Labelled n in paper
          self.samples_per_symbol = samples_per_symbol
          # Labelled N in paper; forced odd so that a unique central message exists
          if messages_per_block % 2 == 0:
              messages_per_block += 1
          self.messages_per_block = messages_per_block
          # Channel model: flatten the block, propagate it, keep the central symbol's samples
          self.channel = tf.keras.Sequential([
              layers.Flatten(),
              channel,
              ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol),
          ])
          # Encoding network. Dense layers act on the last axis, so every message in the
          # block is encoded with the same weights. The final ReLU clips outputs to [0, 1].
          self.encoder = tf.keras.Sequential([
              layers.Input(shape=(self.messages_per_block, self.cardinality)),
              layers.Dense(2 * self.cardinality, activation='relu'),
              layers.Dense(2 * self.cardinality, activation='relu'),
              layers.Dense(self.samples_per_symbol),
              layers.ReLU(max_value=1.0),
          ])
          # Decoding network: central symbol's samples -> softmax over the messages
          self.decoder = tf.keras.Sequential([
              layers.Dense(self.samples_per_symbol, activation='relu'),
              layers.Dense(2 * self.cardinality, activation='relu'),
              layers.Dense(2 * self.cardinality, activation='relu'),
              layers.Dense(self.cardinality, activation='softmax'),
          ])

      def generate_random_inputs(self, num_of_blocks, return_vals=False):
          """
          Generate blocks of uniformly random one-hot messages.

          :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted
              consecutively to model ISI. The central message in a block is returned as the label for training.
          :param return_vals: If true, the raw decimal values of the input sequence will be returned
          :return: ``(blocks, labels)`` with shapes ``(num_of_blocks, messages_per_block, cardinality)`` and
              ``(num_of_blocks, cardinality)``. With ``return_vals`` the integer messages, shaped
              ``(num_of_blocks, messages_per_block, 1)``, are prepended to the tuple.
          """
          num_of_blocks = int(num_of_blocks)
          rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
          # Row k of the identity matrix is the one-hot vector for message k. This replaces
          # OneHotEncoder(sparse=False), whose `sparse` argument was removed in scikit-learn 1.4.
          one_hot = np.eye(self.cardinality)[rand_int[:, 0]]
          out_arr = one_hot.reshape(num_of_blocks, self.messages_per_block, self.cardinality)
          mid_idx = (self.messages_per_block - 1) // 2
          labels = out_arr[:, mid_idx, :]
          if return_vals:
              out_val = rand_int.reshape(num_of_blocks, self.messages_per_block, 1)
              return out_val, out_arr, labels
          return out_arr, labels

      def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3, epochs=1):
          """
          Generate a random dataset, compile the model and fit it.

          :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
          :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
          :param train_size: Proportion of the blocks used for training, in (0, 1]. The rest is used for
              validation; with 1 there is no validation set.
          :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
          :param epochs: Number of passes over the training set
          :return: The ``History`` object returned by ``fit``.
          :raises ValueError: If ``train_size`` is not in (0, 1].
          """
          if not 0 < train_size <= 1:
              raise ValueError("train_size must be in (0, 1], got {}".format(train_size))
          total_blocks = int(num_of_blocks)
          # round() and a complement instead of two int() truncations:
          # int(1e6 * (1 - 0.8)) == 199999 because of floating-point error.
          num_train = round(total_blocks * train_size)
          num_test = total_blocks - num_train
          x_train, y_train = self.generate_random_inputs(num_train)
          validation_data = self.generate_random_inputs(num_test) if num_test > 0 else None

          opt = keras.optimizers.Adam(learning_rate=lr)
          # The decoder ends in a softmax over `cardinality` classes and the labels are one-hot,
          # so categorical cross-entropy is the matching loss. BinaryCrossentropy would score
          # each output as an independent binary label.
          self.compile(optimizer=opt,
                       loss=losses.CategoricalCrossentropy(),
                       metrics=['accuracy'],
                       run_eagerly=False)
          return self.fit(x=x_train,
                          y=y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          shuffle=True,
                          validation_data=validation_data)

      def _channel_fs(self):
          """Return the wrapped channel layer's ``fs`` attribute, or None if it has none."""
          # self.channel is Sequential([Flatten, channel, ExtractCentralMessage]).
          return getattr(self.channel.layers[1], 'fs', None)

      def view_encoder(self):
          """Plot the waveform the encoder assigns to each of the ``cardinality`` messages.

          Each message is placed in the central slot of an otherwise all-zero block, and only the
          central symbol's samples are shown, one subplot per message. The time axis is in seconds
          when the channel layer exposes ``fs``, and in sample indices otherwise.
          """
          # Generate inputs for encoder: one block per message, central slot one-hot
          mid_idx = (self.messages_per_block - 1) // 2
          symbols = np.arange(self.cardinality)
          messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
          messages[symbols, mid_idx, symbols] = 1
          # Pass input through encoder and select middle messages
          enc_messages = np.asarray(self.encoder(messages)[:, mid_idx, :])

          # Grid width: smallest power of two >= sqrt(cardinality). Rounding the row count up
          # keeps every symbol when cardinality is not a power of two.
          num_x = 1
          while num_x < self.cardinality ** 0.5:
              num_x *= 2
          num_y = math.ceil(self.cardinality / num_x)
          # squeeze=False keeps `axs` 2-D even for a single row or column.
          _, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y), squeeze=False)

          t = np.arange(self.samples_per_symbol)
          fs = self._channel_fs()
          if fs is not None:
              t = t / fs
          for sym_idx, ax in enumerate(axs.flat):
              if sym_idx >= self.cardinality:
                  ax.set_visible(False)  # spare grid cell
                  continue
              ax.plot(t, enc_messages[sym_idx], 'x')
              ax.set_title('Symbol {}'.format(sym_idx))
              ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
              ax.label_outer()
          plt.show()

      def view_sample_block(self):
          """Plot one random encoded block, before and after an ideal low-pass filter.

          The filter is a noiseless ``DigitizationLayer`` with its default cutoff, built with the
          channel's ``fs``. If the channel layer has no ``fs`` attribute, only the raw samples are
          plotted against sample indices.
          """
          # Generate a random block of messages, then encode and flatten it
          val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
          flat_enc = layers.Flatten()(self.encoder(inp))
          num_samples = self.messages_per_block * self.samples_per_symbol

          # Time axis
          fs = self._channel_fs()
          t = np.arange(num_samples)
          if fs is not None:
              t = t / fs

          # Plot the concatenated symbols before and after LPF, with symbol boundaries
          plt.figure(figsize=(2 * self.messages_per_block, 6))
          for i in range(1, self.messages_per_block):
              plt.axvline(x=t[i * self.samples_per_symbol], color='black')
          plt.plot(t, flat_enc.numpy().T, 'x')
          if fs is not None:
              lpf = DigitizationLayer(fs=fs, num_of_samples=num_samples, q_stddev=0)
              plt.plot(t, lpf(flat_enc).numpy().T)
          plt.ylim((0, 1))
          plt.xlim((t.min(), t.max()))
          plt.title(str(val[0, :, 0]))
          plt.show()

      def call(self, inputs, training=None, mask=None):
          """Encode the message blocks, pass them through the channel and decode the central message."""
          tx = self.encoder(inputs)
          rx = self.channel(tx)
          return self.decoder(rx)
  if __name__ == '__main__':
      SAMPLING_FREQUENCY = 336e9  # Hz
      CARDINALITY = 32
      SAMPLES_PER_SYMBOL = 24
      MESSAGES_PER_BLOCK = 9
      DISPERSION_FACTOR = -21.7 * 1e-24  # s^2/km (= -21.7 ps^2/km)
      FIBER_LENGTH = 50  # km

      # EndToEndAutoencoder bumps an even block size to the next odd number. The channel's
      # FFT length must use that same block size, or its filters won't match the flattened block.
      effective_messages_per_block = (MESSAGES_PER_BLOCK + 1 if MESSAGES_PER_BLOCK % 2 == 0
                                      else MESSAGES_PER_BLOCK)

      optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                       num_of_samples=effective_messages_per_block * SAMPLES_PER_SYMBOL,
                                       dispersion_factor=DISPERSION_FACTOR,
                                       fiber_length=FIBER_LENGTH)
      ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                     samples_per_symbol=SAMPLES_PER_SYMBOL,
                                     messages_per_block=MESSAGES_PER_BLOCK,
                                     channel=optical_channel)
      ae_model.train(num_of_blocks=1e6, batch_size=100)
      ae_model.view_encoder()
      ae_model.view_sample_block()