# end_to_end.py

import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

# BitsToSymbols and SymbolsToBits live in the accompanying custom_layers module;
# the remaining custom layers are defined in this file.
from custom_layers import BitsToSymbols, SymbolsToBits
class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix: an identity block aligned with the central symbol and zeros
        # elsewhere, so the matmul below keeps only the central message's samples
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
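
# A minimal sanity check of ExtractCentralMessage (illustrative sketch, not part of the
# original file): with 3 messages of 4 samples each, the layer should return samples
# 4..7, i.e. exactly the middle symbol of the flattened block.
def _demo_extract_central_message():
    layer = ExtractCentralMessage(messages_per_block=3, samples_per_symbol=4)
    block = tf.reshape(tf.range(12, dtype=tf.float32), (1, 12))  # one flattened block
    central = layer(block)
    tf.debugging.assert_near(central, [[4.0, 5.0, 6.0, 7.0]])
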
class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Brick-wall low-pass filter mask in the frequency domain
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.ones(freq.shape)
        temp[np.abs(freq) > lpf_cutoff] = 0
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Low-pass filter in the frequency domain, then add quantization noise
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        return self.noise_layer(real_t, training=True)
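
# An illustrative check of the low-pass behaviour (sketch, not part of the original file):
# a 10 GHz tone sits below the default 32 GHz cutoff and should survive, while a 100 GHz
# tone should be removed entirely. q_stddev=0 disables the quantization noise.
def _demo_digitization_layer():
    fs, n = 336e9, 288
    layer = DigitizationLayer(fs=fs, num_of_samples=n, q_stddev=0.0)
    t = np.arange(n) / fs
    x = np.sin(2 * np.pi * 10e9 * t) + np.sin(2 * np.pi * 100e9 * t)
    y = layer(tf.convert_to_tensor(x[None, :], dtype=tf.float32))
    spectrum = np.abs(np.fft.fft(y.numpy()[0]))
    freqs = np.abs(np.fft.fftfreq(n, d=1 / fs))
    assert spectrum[freqs > 32e9].max() < 1e-3 * spectrum.max()
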
class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.rx_stddev = rx_stddev  # Stored so users of the channel can derive an SNR
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        # Chromatic dispersion as an all-pass frequency-domain multiplier:
        # H(omega) = exp(j/2 * beta_2 * L * omega^2)
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC low-pass filtering and quantization noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion applied in the frequency domain
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law detection at the photo-diode
        pd_out = tf.square(tf.abs(disp_t))
        # Cast back to floatx and add photo-diode receiver noise
        real_val = tf.cast(pd_out, dtype=tf.float32)
        rx_signal = self.noise_layer(real_val, training=True)
        # ADC low-pass filtering and quantization noise
        return self.digitization_layer(rx_signal)
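
# An illustrative sketch (not part of the original file) of the channel's two dominant
# effects: chromatic dispersion spreads an isolated pulse into neighbouring symbol slots,
# and square-law detection discards the phase. Noise is disabled here for clarity.
def _demo_optical_channel():
    fs, n = 336e9, 288
    channel = OpticalChannel(fs=fs, num_of_samples=n,
                             dispersion_factor=-21.7e-24, fiber_length=50,
                             rx_stddev=0.0, q_stddev=0.0)
    pulse = np.zeros((1, n), dtype=np.float32)
    pulse[0, n // 2 - 8:n // 2 + 8] = 1.0  # a single rectangular pulse mid-block
    rx = channel(tf.convert_to_tensor(pulse))
    # Energy has leaked outside the original pulse window (inter-symbol interference)
    print(float(tf.reduce_max(rx[0, :n // 2 - 8])))
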
class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 bit_mapping=False):
        """
        The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block"
        consists of multiple "messages" to introduce memory into the simulation, as this is essential for
        modelling inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of distinct messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param bit_mapping: If True, a mapping directly from bits to waveforms is learnt
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        self.bits_per_symbol = int(math.log2(self.cardinality))
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper; forced to be odd so the block has a unique central message
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        # Boolean identifying if the bit mapping is to be learnt
        self.bit_mapping = bit_mapping
        # Other parameters/metrics
        self.symbol_error_rate = None
        self.bit_error_rate = None
        # SNR in dB, assuming a mean signal level of 0.5 (the encoder output lies in [0, 1])
        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
        # Model hyper-parameters
        leaky_relu_alpha = 0
        relu_clip_val = 1.0
        # Layer configuration for the case when the bit mapping is to be learnt
        if self.bit_mapping:
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
                BitsToSymbols(self.cardinality),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                # layers.Dense(2 * self.cardinality),
                # layers.LeakyReLU(alpha=0.01),
                layers.Dense(self.bits_per_symbol, activation='sigmoid')
            ]
        # Layer configuration for the case when only the symbol mapping is to be learnt
        else:
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.cardinality)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                layers.Dense(self.cardinality, activation='softmax')
            ]
        # Encoding neural network
        self.encoder = tf.keras.Sequential(encoding_layers, name="encoding_model")
        # Decoding neural network
        self.decoder = tf.keras.Sequential(decoding_layers, name="decoding_model")
    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        Generates a batch of randomly chosen one-hot encoded messages. This is used to generate the test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted
        consecutively in order to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If True, the raw decimal values of the input sequence are also returned
        """
        cat = [np.arange(self.cardinality)]
        # Note: scikit-learn >= 1.2 renamed the `sparse` keyword to `sparse_output`
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.bit_mapping:
            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
            out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
            if return_vals:
                # For bit inputs the raw values are the bits themselves
                return out_arr, out_arr, out_arr[:, mid_idx, :]
        else:
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
            out = enc.fit_transform(rand_int)
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
            if return_vals:
                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]
    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Trains the autoencoder. Further configuration of the loss function, optimizer etc. can be made in here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size
        :param epochs: Number of passes through the training data
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        # TODO: Investigate different optimizers (with different learning rates and other
        #  parameters): SGD, RMSprop, Adam, Adadelta, Adagrad, Adamax, Nadam, Ftrl
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        if self.bit_mapping:
            loss_fn = losses.BinaryCrossentropy()
        else:
            loss_fn = losses.CategoricalCrossentropy()
        self.compile(optimizer=opt,
                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False)
        history = self.fit(x=X_train,
                           y=y_train,
                           batch_size=batch_size,
                           epochs=epochs,
                           shuffle=True,
                           validation_data=(X_test, y_test))
        return history
    def test(self, num_of_blocks=1e4):
        """Estimates the symbol and bit error rates on freshly generated data."""
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
        y_out = self(X_test)
        y_pred = tf.argmax(y_out, axis=1)
        y_true = tf.argmax(y_test, axis=1)
        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
        # Map predicted/true symbols back to bit sequences to compute the BER
        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
    def view_encoder(self):
        """
        Plots the learnt encoding for each distinct message. This is displayed as a figure with a subplot for
        each message/symbol.
        """
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.bit_mapping:
            # Generate one block per message, with the bit pattern placed in the central slot
            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
            for idx, msg in enumerate(messages):
                msg[mid_idx] = lst[idx]
        else:
            # Generate one block per message, one-hot encoded in the central slot
            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
            for idx, msg in enumerate(messages):
                msg[mid_idx, idx] = 1
        # Pass the inputs through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout: the smallest power-of-2 width >= sqrt(cardinality)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(sym_idx))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
            ax.label_outer()
        plt.show()
    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes it. The encoded block is also passed through a
        DigitizationLayer with the quantization noise disabled, to show the effect of the low-pass filtering alone.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        chan_out = self.channel.layers[1](flat_enc)
        # Instantiate an LPF-only layer (zero quantization noise)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF, plus the channel output
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()
    def call(self, inputs, training=None, mask=None):
        # Full forward pass: encoder -> channel model -> decoder
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        return self.decoder(rx)
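
# A standalone sketch (not part of the original file) of the data layout produced by
# generate_random_inputs in the symbol-mapping case: num_of_blocks blocks of
# messages_per_block one-hot messages, with the central message used as the label.
def _demo_block_layout(num_of_blocks=2, messages_per_block=3, cardinality=4):
    rng = np.random.default_rng(0)
    vals = rng.integers(cardinality, size=(num_of_blocks, messages_per_block))
    one_hot = np.eye(cardinality)[vals]                    # (blocks, messages, cardinality)
    labels = one_hot[:, (messages_per_block - 1) // 2, :]  # central message only
    assert one_hot.shape == (num_of_blocks, messages_per_block, cardinality)
    assert labels.shape == (num_of_blocks, cardinality)
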
SAMPLING_FREQUENCY = 336e9           # Simulation sampling frequency in Hz
CARDINALITY = 32                     # M: number of distinct messages (5 bits/symbol)
SAMPLES_PER_SYMBOL = 32              # n: samples per transmitted symbol
MESSAGES_PER_BLOCK = 9               # N: messages per transmission block
DISPERSION_FACTOR = -21.7 * 1e-24    # beta_2 in s^2/km (i.e. -21.7 ps^2/km)
FIBER_LENGTH = 0                     # Fiber length in km
if __name__ == '__main__':
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
                                   bit_mapping=False)
    ae_model.train(num_of_blocks=1e5, epochs=5)
    ae_model.test()
    ae_model.view_encoder()
    ae_model.view_sample_block()
    # ae_model.summary()
    ae_model.encoder.summary()
    ae_model.channel.summary()
    ae_model.decoder.summary()
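
# Hypothetical exploratory helper (not part of the original file): retrain the autoencoder
# over a range of fiber lengths and collect the resulting bit error rates, to gauge how the
# learnt waveforms cope with increasing dispersion. Defaults here are illustrative only.
def sweep_fiber_length(lengths=(0, 10, 20, 40), num_of_blocks=1e4, epochs=1):
    results = {}
    for length in lengths:
        channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                 num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                 dispersion_factor=DISPERSION_FACTOR,
                                 fiber_length=length)
        model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                    samples_per_symbol=SAMPLES_PER_SYMBOL,
                                    messages_per_block=MESSAGES_PER_BLOCK,
                                    channel=channel)
        model.train(num_of_blocks=num_of_blocks, epochs=epochs)
        model.test()
        results[length] = model.bit_error_rate
    return results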