# end_to_end.py

import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

from models.custom_layers import (BitsToSymbols, DigitizationLayer, ExtractCentralMessage,
                                  OpticalChannel, SymbolsToBits)


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 bit_mapping=False):
        """
        The autoencoder that aims to find an encoding of the input messages. Note that a "block" consists of
        multiple "messages" to introduce memory into the simulation, as this is essential for modelling
        inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block (forced to be odd so that a
            unique central message exists)
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param bit_mapping: If True, the mapping from bits to waveforms is learnt directly; otherwise the model
            operates on one-hot encoded symbols
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        self.bits_per_symbol = int(math.log(self.cardinality, 2))
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper. An odd block length guarantees a unique central message.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer: transmit the whole block, then extract the central message for decoding
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        # Boolean identifying if the bit mapping is to be learnt
        self.bit_mapping = bit_mapping
        # Other parameters/metrics
        self.symbol_error_rate = None
        self.bit_error_rate = None
        # SNR in dB, computed from a signal amplitude of 0.5 relative to the receiver noise standard deviation
        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
        # Model hyper-parameters
        leaky_relu_alpha = 0
        relu_clip_val = 1.0
        # Layer configuration for the case when the bit mapping is to be learnt
        if self.bit_mapping:
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
                BitsToSymbols(self.cardinality),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                # layers.Dense(2 * self.cardinality),
                # layers.LeakyReLU(alpha=0.01),
                layers.Dense(self.bits_per_symbol, activation='sigmoid')
            ]
        # Layer configuration for the case when only the symbol mapping is to be learnt
        else:
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.cardinality)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                layers.Dense(2 * self.cardinality),
                layers.LeakyReLU(alpha=leaky_relu_alpha),
                layers.Dense(self.cardinality, activation='softmax')
            ]
        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            *encoding_layers
        ], name="encoding_model")
        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            *decoding_layers
        ], name="decoding_model")
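        # For reference (a summary derived from the layer shapes above, assuming the one-hot configuration):
        # the encoder maps (batch, messages_per_block, cardinality) blocks to (batch, messages_per_block,
        # samples_per_symbol) waveforms, and the decoder maps the central received message of shape
        # (batch, samples_per_symbol) back to a (batch, cardinality) softmax distribution.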
    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is utilised to generate the test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages that are transmitted
            consecutively to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If True, the raw decimal values of the input sequence are also returned
        """
        cat = [np.arange(self.cardinality)]
        # Note: scikit-learn >= 1.2 renames the `sparse` argument to `sparse_output`
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.bit_mapping:
            # Random bits, reshaped into (blocks, messages, bits per message)
            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
            out = rand_int
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
            if return_vals:
                return out_arr, out_arr, out_arr[:, mid_idx, :]
        else:
            # Random symbol indices, one-hot encoded and reshaped into (blocks, messages, cardinality)
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
            out = enc.fit_transform(rand_int)
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
            if return_vals:
                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]
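
    # For illustration (not part of the original flow): with cardinality=32 and messages_per_block=9,
    # generate_random_inputs(1000) is expected to return a pair of arrays shaped (1000, 9, 32) and
    # (1000, 32) in one-hot mode, i.e. the full block and the central-message label. A minimal shape
    # check might look like:
    #
    #     blocks, labels = ae_model.generate_random_inputs(num_of_blocks=1000)
    #     assert blocks.shape == (1000, 9, 32)
    #     assert labels.shape == (1000, 32)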
    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param epochs: Number of passes over the generated training data
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        # TODO: Investigate different optimizers (with different learning rates and other parameters):
        #       SGD, RMSprop, Adam, Adadelta, Adagrad, Adamax, Nadam, Ftrl (a sketch follows below)
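        # Hedged sketch for the TODO above (an untested assumption, not part of the original training
        # setup): any of the listed optimizers can be swapped in before compile(), for example:
        #
        #     opt = tf.keras.optimizers.Nadam(learning_rate=lr)
        #     # or
        #     opt = tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9)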
        if self.bit_mapping:
            loss_fn = losses.BinaryCrossentropy()
        else:
            loss_fn = losses.CategoricalCrossentropy()
        self.compile(optimizer=opt,
                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False)
        history = self.fit(x=X_train,
                           y=y_train,
                           batch_size=batch_size,
                           epochs=epochs,
                           shuffle=True,
                           validation_data=(X_test, y_test))
        return history
    def test(self, num_of_blocks=1e4, length_plot=False):
        """
        Evaluates the trained autoencoder on freshly generated data and records the symbol and bit error rates.
        :param num_of_blocks: Number of blocks to generate for testing
        :param length_plot: If True, the bit error rate is additionally evaluated over a sweep of fiber lengths
            and plotted on a logarithmic scale
        """
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
        y_out = self.call(X_test)
        # Note: the error-rate computation below assumes one-hot (symbol) outputs, i.e. bit_mapping=False
        y_pred = tf.argmax(y_out, axis=1)
        y_true = tf.argmax(y_test, axis=1)
        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
        if length_plot:
            # Sweep the fiber length and re-evaluate the BER with a freshly constructed channel of each length
            lengths = np.linspace(0, 70, 50)
            ber_l = []
            for l in lengths:
                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
                                            num_of_samples=self.channel.layers[1].num_of_samples,
                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
                                            fiber_length=l,
                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
                                            rx_stddev=self.channel.layers[1].rx_stddev,
                                            sig_avg=self.channel.layers[1].sig_avg,
                                            enob=self.channel.layers[1].enob)
                test_channel = tf.keras.Sequential([
                    layers.Flatten(),
                    tx_channel,
                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
                ], name="test_channel_variable_length")
                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
                y_pred_l = tf.argmax(y_out_l, axis=1)
                # y_true_l = tf.argmax(y_test_l, axis=1)
                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
                ber_l.append(bit_error_rate_l)
            plt.plot(lengths, ber_l)
            plt.yscale('log')
            plt.show()
        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message. This is displayed as a figure with a
        subplot for each message/symbol.
        """
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.bit_mapping:
            # Generate inputs for the encoder: every possible bit pattern placed in the central message slot
            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
            idx = 0
            for msg in messages:
                msg[mid_idx] = lst[idx]
                idx += 1
        else:
            # Generate inputs for the encoder: every one-hot symbol placed in the central message slot
            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
            idx = 0
            for msg in messages:
                msg[mid_idx, idx] = 1
                idx += 1
        # Pass the input through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()
    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes it. In addition, the encoded output is passed
        through a DigitizationLayer with no quantization noise so that the effect of the low-pass filtering
        can be seen on its own.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        chan_out = self.channel.layers[1](flat_enc)
        # Instantiate the LPF layer (sig_avg=0 so that only the low-pass filtering is applied)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                sig_avg=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF, together with the full channel output
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()
    def call(self, inputs, training=None, mask=None):
        # Full forward pass: encode the block, pass it through the channel, decode the central message
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


# Simulation constants
SAMPLING_FREQUENCY = 336e9
CARDINALITY = 32
SAMPLES_PER_SYMBOL = 32
MESSAGES_PER_BLOCK = 9
DISPERSION_FACTOR = -21.7 * 1e-24
FIBER_LENGTH = 0

if __name__ == '__main__':
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
                                   bit_mapping=False)
    ae_model.train(num_of_blocks=1e5, epochs=5)
    ae_model.test()
    ae_model.view_encoder()
    ae_model.view_sample_block()
    # ae_model.summary()
    ae_model.encoder.summary()
    ae_model.channel.summary()
    ae_model.decoder.summary()
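
    # Hedged sketch (not part of the original run): the same pipeline with a learnt bit mapping would be
    # configured by setting bit_mapping=True, in which case the inputs are raw bits rather than one-hot
    # symbols. `bit_model` is an illustrative name, a fresh OpticalChannel instance may be preferable to
    # sharing `optical_channel` between models, and test() as written assumes one-hot outputs, so only
    # train() is shown:
    #
    #     bit_model = EndToEndAutoencoder(cardinality=CARDINALITY,
    #                                     samples_per_symbol=SAMPLES_PER_SYMBOL,
    #                                     messages_per_block=MESSAGES_PER_BLOCK,
    #                                     channel=optical_channel,
    #                                     bit_mapping=True)
    #     bit_model.train(num_of_blocks=1e5, epochs=5)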