# end_to_end.py
import json
import math
import os
from datetime import datetime as dt

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 custom_loss_fn=False):
        """
        The autoencoder that aims to find an encoding of the input messages. Note that a "block" consists of
        multiple "messages" in order to introduce memory into the simulation, which is essential for modelling
        inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of distinct messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param custom_loss_fn: If True, train with the combined symbol- and bit-level loss defined in `cost`
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        self.bits_per_symbol = int(math.log(self.cardinality, 2))
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper - conditional +=1 to ensure odd value, so the block always has a central message
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel Model Layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
        # Boolean identifying whether the bit mapping is to be learnt
        self.custom_loss_fn = custom_loss_fn
        # Other parameters/metrics
        self.symbol_error_rate = None
        self.bit_error_rate = None
        # SNR in dB; guarded so that a noiseless channel (rx_stddev == 0) does not divide by zero
        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10) if channel.rx_stddev else float('inf')
        # Model hyper-parameters
        leaky_relu_alpha = 0
        relu_clip_val = 1.0
        # Encoding Neural Network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
            # Alternative output stage: unbounded Dense followed by a clipped ReLU
            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
        ], name="encoding_model")
        # Decoding Neural Network
        self.decoder = tf.keras.Sequential([
            layers.Dense(2 * self.cardinality),
            layers.LeakyReLU(alpha=leaky_relu_alpha),
            layers.Dense(2 * self.cardinality),
            layers.LeakyReLU(alpha=leaky_relu_alpha),
            layers.Dense(self.cardinality, activation='softmax')
        ], name="decoding_model")
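
    # Shape flow (a sketch for orientation; M = cardinality, N = messages_per_block, n = samples_per_symbol):
    #   encoder: (batch, N, M) one-hot blocks -> (batch, N, n) waveform samples in [0, 1]
    #   channel: flattened to (batch, N * n), passed through the channel layer, then
    #            ExtractCentralMessage keeps the central symbol -> (batch, n)
    #   decoder: (batch, n) -> (batch, M) message probabilities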

    def save_end_to_end(self):
        # Extract all parameters and save them alongside the trained models
        params = {"fs": self.channel.layers[1].fs,
                  "cardinality": self.cardinality,
                  "samples_per_symbol": self.samples_per_symbol,
                  "messages_per_block": self.messages_per_block,
                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
                  "fiber_length": float(self.channel.layers[1].fiber_length),
                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
                  "rx_stddev": self.channel.layers[1].rx_stddev,
                  "sig_avg": self.channel.layers[1].sig_avg,
                  "enob": self.channel.layers[1].enob,
                  "custom_loss_fn": self.custom_loss_fn
                  }
        dir_str = os.path.join("exports", dt.utcnow().strftime("%Y%m%d-%H%M%S"))
        if not os.path.exists(dir_str):
            os.makedirs(dir_str)
        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
            json.dump(params, outfile)
        ################################################################################################################
        # This section exports the encoder/decoder weights, formatted as Python variable assignments
        ################################################################################################################
        enc_weights, dec_weights = self.extract_weights()
        enc_weights = [x.tolist() for x in enc_weights]
        dec_weights = [x.tolist() for x in dec_weights]
        # get_weights() alternates kernel, bias, kernel, bias, ...
        enc_w = enc_weights[::2]
        enc_b = enc_weights[1::2]
        dec_w = dec_weights[::2]
        dec_b = dec_weights[1::2]
        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
            outfile.write("enc_weights = ")
            outfile.write(str(enc_w))
            outfile.write("\n\nenc_bias = ")
            outfile.write(str(enc_b))
        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
            outfile.write("dec_weights = ")
            outfile.write(str(dec_w))
            outfile.write("\n\ndec_bias = ")
            outfile.write(str(dec_b))
        ################################################################################################################
        self.encoder.save(os.path.join(dir_str, 'encoder'))
        self.decoder.save(os.path.join(dir_str, 'decoder'))
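
    # The exported weight files are plain Python, so (a sketch) after copying, e.g.,
    # exports/<timestamp>/enc_weights.py next to a consuming script, you can simply do:
    #   from enc_weights import enc_weights, enc_bias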

    def extract_weights(self):
        enc_weights = self.encoder.get_weights()
        dec_weights = self.decoder.get_weights()
        return enc_weights, dec_weights

    def encode_stream(self, x):
        # NumPy re-implementation of the encoder forward pass
        # (leaky_relu_alpha is 0, so a plain ReLU matches the encoder's LeakyReLU layers)
        enc_weights, _ = self.extract_weights()
        for i in range(len(enc_weights) // 2):
            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
            if i == len(enc_weights) // 2 - 1:
                x = tf.keras.activations.sigmoid(x).numpy()
            else:
                x = tf.keras.activations.relu(x).numpy()
        return x

    def decode_stream(self, x):
        # NumPy re-implementation of the decoder forward pass
        _, dec_weights = self.extract_weights()
        for i in range(len(dec_weights) // 2):
            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
            if i == len(dec_weights) // 2 - 1:
                # Convert to a tensor first: activations.softmax expects a tf.Tensor, not an ndarray
                x = tf.keras.activations.softmax(tf.convert_to_tensor(x)).numpy()
            else:
                x = tf.keras.activations.relu(x).numpy()
        return x
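
    # Usage sketch (hypothetical variable names), e.g. to sanity-check the exports without Keras:
    #   blocks, labels = ae_model.generate_random_inputs(10)   # (10, N, M) one-hot blocks
    #   waveform = ae_model.encode_stream(blocks)              # (10, N, n); should match ae_model.encoder(blocks)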

    def cost(self, y_true, y_pred):
        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
        # Relative weighting of the bit-level term
        a = 1
        return symbol_cost + a * bit_cost
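
    # In other words (a sketch of the loss): cost = CCE(y, y_hat) + a * BCE(bits(y), bits(y_hat)),
    # where bits(.) is the SymbolsToBits mapping from an M-ary distribution to per-bit probabilities.
    # With a = 1 the symbol- and bit-level terms are weighted equally.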

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train
        data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages that are transmitted
        consecutively to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If True, the raw decimal values of the input sequence will also be returned
        """
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]
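
    # For example (a sketch): with cardinality 32 and messages_per_block 9,
    #   X, y = ae_model.generate_random_inputs(1000)
    # gives X with shape (1000, 9, 32) (one-hot blocks) and y with shape (1000, 32) (the central labels).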

    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param epochs: Maximum number of passes over the training data
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges.
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        if self.custom_loss_fn:
            loss_fn = self.cost
        else:
            loss_fn = losses.CategoricalCrossentropy()
        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
        self.compile(optimizer=opt,
                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )
        history = self.fit(x=X_train,
                           y=y_train,
                           batch_size=batch_size,
                           epochs=epochs,
                           callbacks=[callback],
                           shuffle=True,
                           validation_data=(X_test, y_test)
                           )
        if len(history.history['loss']) == epochs:
            print("The model trained for the maximum number of epochs and may not have converged to a good solution. "
                  "Setting a higher epoch number and retraining is recommended.")
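
    # Typical call (a sketch; the values shown are the defaults above):
    #   ae_model.train(num_of_blocks=1e5, epochs=50, batch_size=None, train_size=0.8, lr=1e-3)
    # Training stops early once the training loss has not improved for 3 consecutive epochs
    # (the EarlyStopping callback above).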

    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
        y_out = self.call(X_test)
        y_pred = tf.argmax(y_out, axis=1)
        y_true = tf.argmax(y_test, axis=1)
        self.symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
        self.bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
        if length_plot:
            # Re-evaluate the frozen encoder/decoder over a range of fiber lengths
            lengths = np.linspace(0, 70, 50)
            ber_l = []
            for l in lengths:
                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
                                            num_of_samples=self.channel.layers[1].num_of_samples,
                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
                                            fiber_length=l,
                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
                                            rx_stddev=self.channel.layers[1].rx_stddev,
                                            sig_avg=self.channel.layers[1].sig_avg,
                                            enob=self.channel.layers[1].enob)
                test_channel = tf.keras.Sequential([
                    layers.Flatten(),
                    tx_channel,
                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
                ], name="test_channel_variable_length")
                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
                y_out_l = self.decoder(test_channel(self.encoder(X_test_l)))
                y_pred_l = tf.argmax(y_out_l, axis=1)
                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
                ber_l.append(bit_error_rate_l)
            plt.plot(lengths, ber_l)
            plt.yscale('log')
            if plt_show:
                plt.show()
        print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
        print("BIT ERROR RATE: {}".format(self.bit_error_rate))
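
    # e.g. (a sketch): ae_model.test(num_of_blocks=1e4, length_plot=True) stores the symbol/bit error
    # rates on the instance and additionally sweeps the fiber length from 0 to 70 km with the frozen
    # encoder/decoder, plotting BER on a log scale.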

    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message. This is displayed as a figure with a
        subplot for each message/symbol.
        """
        mid_idx = int((self.messages_per_block - 1) / 2)
        # Generate inputs for the encoder: one block per message, with that message in the central slot
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1
        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute subplot grid layout (e.g. cardinality 32 gives an 8 x 4 grid)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes them. In addition, the encoder output is passed
        through a digitization layer without any quantization noise so that the low-pass filtering can be seen in
        isolation.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        chan_out = self.channel.layers[1](flat_enc)
        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                sig_avg=0)
        # Apply the LPF and plot the magnitude spectrum of its output
        lpf_out = lpf(flat_enc)
        a = np.fft.fft(lpf_out.numpy()).flatten()
        f = np.fft.fftfreq(a.shape[-1]).flatten()
        plt.plot(f, np.abs(a))
        plt.show()
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()
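
    # The time-domain figure overlays three traces: the raw encoder samples ('x' markers), the
    # low-pass-filtered waveform, and the output of the full channel layer (noise, dispersion etc.)
    # applied to the same block.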

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
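

# End-to-end usage sketch (assuming an OpticalChannel configured as in the __main__ block below):
#   ae = EndToEndAutoencoder(cardinality=32, samples_per_symbol=32, messages_per_block=9,
#                            channel=optical_channel, custom_loss_fn=True)
#   ae.train(num_of_blocks=1e5)
#   ae.test()
#   ae.save_end_to_end()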


def load_model(model_name=None):
    if model_name is None:
        # Timestamped export names ("%Y%m%d-%H%M%S") sort chronologically, so take the latest
        models = sorted(os.listdir("exports"))
        if not models:
            raise Exception("Unable to find a trained model. Please train and save a model first.")
        model_name = models[-1]
    param_file_path = os.path.join("exports", model_name, "params.json")
    if not os.path.isfile(param_file_path):
        raise Exception("Invalid file name/directory")
    with open(param_file_path, 'r') as param_file:
        params = json.load(param_file)
    optical_channel = OpticalChannel(fs=params["fs"],
                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
                                     dispersion_factor=params["dispersion_factor"],
                                     fiber_length=params["fiber_length"],
                                     fiber_length_stddev=params["fiber_length_stddev"],
                                     lpf_cutoff=params["lpf_cutoff"],
                                     rx_stddev=params["rx_stddev"],
                                     sig_avg=params["sig_avg"],
                                     enob=params["enob"])
    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
                                   samples_per_symbol=params["samples_per_symbol"],
                                   messages_per_block=params["messages_per_block"],
                                   channel=optical_channel,
                                   custom_loss_fn=params["custom_loss_fn"])
    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
    return ae_model, params
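

# e.g. (a sketch): ae_model, params = load_model() restores the most recent export, since the
# timestamped directory names sort chronologically; pass model_name to pick a specific one.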


# Deliberately disabled experiment (the 'asd' guard never matches '__main__'): trains a fresh model
# at each fiber length in a sweep and plots the resulting bit error rates.
if __name__ == 'asd':
    params = {"fs": 336e9,
              "cardinality": 32,
              "samples_per_symbol": 32,
              "messages_per_block": 9,
              "dispersion_factor": (-21.7 * 1e-24),
              "fiber_length": 50,
              "fiber_length_stddev": 1,
              "lpf_cutoff": 32e9,
              "rx_stddev": 0.01,
              "sig_avg": 0.5,
              "enob": 8,
              "custom_loss_fn": True
              }
    lengths = np.linspace(40, 100, 7)
    ber = []
    for len_ in lengths:
        optical_channel = OpticalChannel(fs=params["fs"],
                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
                                         dispersion_factor=params["dispersion_factor"],
                                         fiber_length=len_,
                                         fiber_length_stddev=params["fiber_length_stddev"],
                                         lpf_cutoff=params["lpf_cutoff"],
                                         rx_stddev=0,
                                         sig_avg=0,
                                         enob=params["enob"])
        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
                                       samples_per_symbol=params["samples_per_symbol"],
                                       messages_per_block=params["messages_per_block"],
                                       channel=optical_channel,
                                       custom_loss_fn=params["custom_loss_fn"])
        ae_model.train(num_of_blocks=1e5)
        ae_model.test()
        ber.append(ae_model.bit_error_rate)
    plt.plot(lengths, ber)
    plt.title("Bit Error Rate at different trained lengths")
    plt.yscale('log')
    plt.xlabel("Fiber Length / km")
    plt.ylabel("Bit Error Rate")
    plt.show()


if __name__ == '__main__':
    params = {"fs": 336e9,
              "cardinality": 32,
              "samples_per_symbol": 32,
              "messages_per_block": 9,
              "dispersion_factor": (-21.7 * 1e-24),
              "fiber_length": 50,
              "fiber_length_stddev": 1,
              "lpf_cutoff": 32e9,
              "rx_stddev": 0.01,
              "sig_avg": 0.5,
              "enob": 8,
              "custom_loss_fn": True
              }
    force_training = False
    model_save_name = "20210317-124015"
    param_file_path = os.path.join("exports", model_save_name, "params.json")
    if os.path.isfile(param_file_path) and not force_training:
        print("Importing model {}".format(model_save_name))
        with open(param_file_path, 'r') as file:
            params = json.load(file)
    optical_channel = OpticalChannel(fs=params["fs"],
                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
                                     dispersion_factor=params["dispersion_factor"],
                                     fiber_length=params["fiber_length"],
                                     fiber_length_stddev=params["fiber_length_stddev"],
                                     lpf_cutoff=params["lpf_cutoff"],
                                     rx_stddev=params["rx_stddev"],
                                     sig_avg=params["sig_avg"],
                                     enob=params["enob"])
    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
                                   samples_per_symbol=params["samples_per_symbol"],
                                   messages_per_block=params["messages_per_block"],
                                   channel=optical_channel,
                                   custom_loss_fn=params["custom_loss_fn"])
    if os.path.isfile(param_file_path) and not force_training:
        ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
        ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
    else:
        ae_model.train(num_of_blocks=1e4)
        ae_model.save_end_to_end()
    ae_model.view_encoder()
    ae_model.test()
    # Scratch example for pushing a specific message sequence through the trained model:
    # cat = [np.arange(32)]
    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
    # inp_oh = enc.fit_transform(inp)
    # out = ae_model(inp_oh.reshape(1, 9, 32))
    # a = out.numpy()
    # plt.plot(a)
    # plt.show()