end_to_end.py

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 recurrent=False):
        """
        An autoencoder that learns an encoding of the input messages. Note that a "block" consists of
        multiple "messages"; this introduces memory into the simulation, which is essential for
        modelling inter-symbol interference. The architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of distinct messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param recurrent: If True, the decoder uses stateful LSTM layers instead of dense layers
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper. Forced to be odd so that a unique central message exists.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        self.recurrent = recurrent
        if recurrent:
            # Stateful layers require a fixed batch size; blocks are fed one at a time
            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality), batch_size=1)
            # encoding_layers = [
            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
            # ]
            decoding_layers = [
                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
            ]
        else:
            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality))
            decoding_layers = [
                layers.Dense(2 * self.cardinality, activation='relu'),
                layers.Dense(2 * self.cardinality, activation='relu')
            ]
        # Encoding neural network. The final clipped ReLU bounds the transmit samples to [0, 1].
        self.encoder = tf.keras.Sequential([
            input_layer,
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ], name="encoding_model")
        # Decoding neural network. The softmax outputs a probability over the M possible messages.
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            *decoding_layers,
            layers.Dense(self.cardinality, activation='softmax')
        ], name="decoding_model")

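    # NOTE (assumption): ExtractCentralMessage is a project-specific layer whose
    # source is not shown in this file. From its usage above it is assumed to slice
    # the central message's samples out of the flattened channel output, roughly:
    #
    #     mid = (messages_per_block - 1) // 2
    #     start = mid * samples_per_symbol
    #     central = channel_output[:, start:start + samples_per_symbol]
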
    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        Generates a set of one-hot encoded messages, used to build the test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages that are
        transmitted consecutively to model ISI. The central message in a block is returned as the label
        for training.
        :param return_vals: If True, the raw decimal values of the input sequence are also returned
        """
        cat = [np.arange(self.cardinality)]
        # Note: scikit-learn >= 1.2 renames the `sparse` argument to `sparse_output`
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.recurrent and not return_vals:
            # Overlapping (sliding-window) blocks preserve the ordering of the message stream
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks + self.messages_per_block - 1, 1))
            rand_enc = enc.fit_transform(rand_int)
            out = []
            for i in range(num_of_blocks):
                out.append(rand_enc[i:i + self.messages_per_block])
            out = np.array(out)
            return out, out[:, mid_idx, :]
        else:
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
            out = enc.fit_transform(rand_int)
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
            if return_vals:
                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                return out_val, out_arr, out_arr[:, mid_idx, :]
            return out_arr, out_arr[:, mid_idx, :]

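    # Usage sketch (shapes only; the contents are random). With cardinality=4 and
    # messages_per_block=5:
    #
    #     X, y = model.generate_random_inputs(num_of_blocks=10)
    #     X.shape  # (10, 5, 4): one-hot blocks
    #     y.shape  # (10, 4):    one-hot central message of each block
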
    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Trains the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 giving the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Controls how quickly the algorithm converges.
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        # Binary cross-entropy over the one-hot labels; CategoricalCrossentropy is the
        # conventional alternative for a softmax output
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False)
        shuffle = True
        if self.recurrent and batch_size is None:
            # With recurrent (stateful) layers the blocks are fed one at a time, without
            # shuffling, so the ordering of the data is preserved.
            batch_size = 1
            shuffle = False
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=shuffle,
                 validation_data=(X_test, y_test))

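    # Example invocation (hypothetical hyperparameters); for a non-recurrent model a
    # larger batch size is viable:
    #
    #     model.train(num_of_blocks=1e4, batch_size=64, lr=5e-4)
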
    def view_encoder(self):
        """
        Views the learnt encoding for each distinct message. The encodings are displayed
        as a plot with a subplot for each message.
        """
        # Generate inputs for the encoder: one block per message, with the message of
        # interest placed in the central slot
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        for idx, msg in enumerate(messages):
            msg[mid_idx, idx] = 1
        # Pass the input through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout: num_x is the smallest power of two >= sqrt(cardinality)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(sym_idx))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes it. The encoder output is also
        passed through the digitization layer, without any quantization noise, to show the
        effect of the low-pass filtering.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        # Instantiate the LPF layer (q_stddev=0 disables the quantization noise)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF, with symbol boundaries marked
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

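    # NOTE (assumption): DigitizationLayer is project-specific and not shown in this
    # file. From its usage above it is assumed to low-pass filter a block of samples
    # at sampling frequency fs and add Gaussian quantization noise with standard
    # deviation q_stddev (disabled above via q_stddev=0).
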
    def call(self, inputs, training=None, mask=None):
        # Transmitter -> channel -> receiver
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs

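# Forward-pass shape trace (batch first; the rx shape relies on the
# ExtractCentralMessage assumption noted above):
#
#     inputs:  (batch, messages_per_block, cardinality)         one-hot messages
#     tx:      (batch, messages_per_block, samples_per_symbol)  encoded waveforms
#     rx:      (batch, samples_per_symbol)                      central message after the channel
#     outputs: (batch, cardinality)                             probability over messages
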
if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
                                   recurrent=True)
    ae_model.train(num_of_blocks=1e5)
    ae_model.view_encoder()
    ae_model.view_sample_block()
    ae_model.summary()
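
    # A hedged sketch of post-training evaluation (symbol error rate), using only the
    # helpers defined in this file; batch_size=1 matches the stateful LSTMs:
    #
    #     X, y = ae_model.generate_random_inputs(num_of_blocks=1000)
    #     preds = ae_model.predict(X, batch_size=1)
    #     ser = np.mean(np.argmax(preds, -1) != np.argmax(y, -1))
    #     print('Symbol error rate: {:.4f}'.format(ser))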