# end_to_end.py

import json
import math
import os
from datetime import datetime as dt

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

import graphs
from models.data import BinaryTimeDistributedOneHotGenerator
from models.layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 custom_loss_fn=False,
                 quantize=False,
                 alpha=1):
        """
        The autoencoder that aims to find an encoding of the input messages. Note that a "block" consists of
        multiple "messages" to introduce memory into the simulation, as this is essential for modelling
        inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param custom_loss_fn: If True, train with the combined symbol/bit loss implemented in `cost`
        :param quantize: Reserved flag for quantization-aware training (unused in this class)
        :param alpha: Weight of the bit-wise cross-entropy term in the combined loss function
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        self.bits_per_symbol = int(math.log(self.cardinality, 2))
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        self.alpha = alpha
        # Labelled N in paper - conditional += 1 to ensure an odd value
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")
        # Boolean identifying if the bit mapping is to be learnt
        self.custom_loss_fn = custom_loss_fn
        # Other parameters/metrics
        self.symbol_error_rate = None
        self.bit_error_rate = None
        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)
        # Model hyper-parameters
        leaky_relu_alpha = 0
        relu_clip_val = 1.0
        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
        ], name="encoding_model")
        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            layers.Dense(2 * self.cardinality),
            layers.ReLU(),
            layers.Dense(2 * self.cardinality),
            layers.ReLU(),
            layers.Dense(self.cardinality, activation='softmax')
        ], name="decoding_model")
        self.decoder.build((1, self.samples_per_symbol))

    def save_end_to_end(self, name):
        # Extract all params and save
        params = {"fs": self.channel.layers[1].fs,
                  "cardinality": self.cardinality,
                  "samples_per_symbol": self.samples_per_symbol,
                  "messages_per_block": self.messages_per_block,
                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
                  "fiber_length": float(self.channel.layers[1].fiber_length),
                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
                  "rx_stddev": self.channel.layers[1].rx_stddev,
                  "sig_avg": self.channel.layers[1].sig_avg,
                  "enob": self.channel.layers[1].enob,
                  "custom_loss_fn": self.custom_loss_fn
                  }
        if not name:
            name = dt.utcnow().strftime("%Y%m%d-%H%M%S")
        dir_str = os.path.join("exports", name)
        if not os.path.exists(dir_str):
            os.makedirs(dir_str)
        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
            json.dump(params, outfile)
        ################################################################################################################
        # This section exports the weights of the encoder/decoder formatted using python variable instantiation syntax
        ################################################################################################################
        enc_weights, dec_weights = self.extract_weights()
        enc_weights = [x.tolist() for x in enc_weights]
        dec_weights = [x.tolist() for x in dec_weights]
        enc_w = enc_weights[::2]
        enc_b = enc_weights[1::2]
        dec_w = dec_weights[::2]
        dec_b = dec_weights[1::2]
        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
            outfile.write("enc_weights = ")
            outfile.write(str(enc_w))
            outfile.write("\n\nenc_bias = ")
            outfile.write(str(enc_b))
        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
            outfile.write("dec_weights = ")
            outfile.write(str(dec_w))
            outfile.write("\n\ndec_bias = ")
            outfile.write(str(dec_b))
        ################################################################################################################
        self.encoder.save(os.path.join(dir_str, 'encoder'))
        self.decoder.save(os.path.join(dir_str, 'decoder'))

    def extract_weights(self):
        enc_weights = self.encoder.get_weights()
        dec_weights = self.decoder.get_weights()
        return enc_weights, dec_weights

    def encode_stream(self, x):
        # NumPy re-implementation of the encoder forward pass. The encoder's hidden LeakyReLU layers use
        # alpha=0, which is equivalent to the plain ReLU applied here; the output layer uses a sigmoid.
        enc_weights, _ = self.extract_weights()
        for i in range(len(enc_weights) // 2):
            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
            if i == len(enc_weights) // 2 - 1:
                x = tf.keras.activations.sigmoid(x).numpy()
            else:
                x = tf.keras.activations.relu(x).numpy()
        return x

    def cost(self, y_true, y_pred):
        # Combined loss: categorical cross-entropy on the symbols plus an alpha-weighted binary
        # cross-entropy on the corresponding bit labels
        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
        return symbol_cost + self.alpha * bit_cost

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is utilized for generating the
        test/train data.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted
        consecutively to model ISI. The central message in a block is returned as the label for training.
        :param return_vals: If true, the raw decimal values of the input sequence will also be returned
        """
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3, **kwargs):
        """
        Method to train the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param epochs: Number of passes over the training data
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        # X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        train_data = BinaryTimeDistributedOneHotGenerator(
            num_of_blocks, cardinality=self.cardinality, blocks=self.messages_per_block)
        test_data = BinaryTimeDistributedOneHotGenerator(
            num_of_blocks * .3, cardinality=self.cardinality, blocks=self.messages_per_block)
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        if self.custom_loss_fn:
            loss_fn = self.cost
        else:
            loss_fn = losses.CategoricalCrossentropy()
        self.compile(optimizer=opt,
                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )
        return self.fit(
            train_data,
            epochs=epochs,
            shuffle=True,
            validation_data=test_data,
            **kwargs
        )

    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True, distance=None):
        """
        Estimates the symbol and bit error rates over `num_of_blocks` random blocks, evaluated in batches of 1000.
        If `length_plot` is set, the trained encoder/decoder pair is additionally swept over a range of fiber
        lengths. (The `distance` argument is currently unused.)
        """
        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
        test_data = BinaryTimeDistributedOneHotGenerator(
            1000, cardinality=self.cardinality, blocks=self.messages_per_block)
        num_of_blocks = int(num_of_blocks / 1000)
        if num_of_blocks <= 0:
            num_of_blocks = 1
        ber = []
        ser = []
        for i in range(num_of_blocks):
            y_out = self.call(test_data.x)
            y_pred = tf.argmax(y_out, axis=1)
            y_true = tf.argmax(test_data.y, axis=1)
            ser.append(1 - accuracy_score(y_true, y_pred))
            bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
            bits_true = SymbolsToBits(self.cardinality)(test_data.y).numpy().flatten()
            ber.append(1 - accuracy_score(bits_true, bits_pred))
            test_data.on_epoch_end()
            print(f"\rTested {i + 1} of {num_of_blocks} blocks", end="")
        print(f"\rTested all {num_of_blocks} blocks")
        self.symbol_error_rate = sum(ser) / len(ser)
        self.bit_error_rate = sum(ber) / len(ber)
        if length_plot:
            lengths = np.linspace(0, 70, 50)
            ber_l = []
            for l in lengths:
                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
                                            num_of_samples=self.channel.layers[1].num_of_samples,
                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
                                            fiber_length=l,
                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
                                            rx_stddev=self.channel.layers[1].rx_stddev,
                                            sig_avg=self.channel.layers[1].sig_avg,
                                            enob=self.channel.layers[1].enob)
                test_channel = tf.keras.Sequential([
                    layers.Flatten(),
                    tx_channel,
                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
                ], name="test_channel_variable_length")
                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
                encoded = self.encoder(X_test_l)
                after_ch = test_channel(encoded)
                y_out_l = self.decoder(after_ch)
                y_pred_l = tf.argmax(y_out_l, axis=1)
                # y_true_l = tf.argmax(y_test_l, axis=1)
                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
                ber_l.append(bit_error_rate_l)
            plt.plot(lengths, ber_l)
            plt.yscale('log')
            if plt_show:
                plt.show()
        print("SYMBOL ERROR RATE: {:e}".format(self.symbol_error_rate))
        print("BIT ERROR RATE: {:e}".format(self.bit_error_rate))
        return self.symbol_error_rate, self.bit_error_rate

    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message. This is displayed as a figure with a
        subplot for each message/symbol.
        """
        mid_idx = int((self.messages_per_block - 1) / 2)
        # Generate inputs for encoder
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1
        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes them. The encoded block is also passed through the
        digitization layer without any quantization noise, to isolate the effect of the low-pass filtering.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        chan_out = self.channel.layers[1](flat_enc)
        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                sig_avg=0)
        # Apply LPF and plot the magnitude spectrum of its output
        lpf_out = lpf(flat_enc)
        a = np.fft.fft(lpf_out.numpy()).flatten()
        f = np.fft.fftfreq(a.shape[-1]).flatten()
        plt.plot(f, np.abs(a))
        plt.show()
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
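

# Hedged sketch, not part of the original file: a quick consistency check that the NumPy forward pass in
# `encode_stream` reproduces the Keras encoder. The helper name `check_encode_stream` and the tolerance are
# illustrative assumptions. The encoder's hidden LeakyReLU(alpha=0) layers are equivalent to the plain ReLU
# applied by `encode_stream`, so the two paths should agree up to floating-point error.
def check_encode_stream(model, num_of_blocks=4, atol=1e-4):
    # Random one-hot blocks shaped like the encoder input
    x, _ = model.generate_random_inputs(num_of_blocks)
    keras_out = model.encoder(x).numpy()
    numpy_out = model.encode_stream(x)
    return np.allclose(keras_out, numpy_out, atol=atol)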


def load_model(model_name=None):
    if model_name is None:
        models = os.listdir("exports")
        if not models:
            raise Exception("Unable to find a trained model. Please first train and save a model.")
        model_name = models[-1]
    param_file_path = os.path.join("exports", model_name, "params.json")
    if not os.path.isfile(param_file_path):
        raise Exception("Invalid file name/directory")
    with open(param_file_path, 'r') as param_file:
        params = json.load(param_file)
    optical_channel = OpticalChannel(fs=params["fs"],
                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
                                     dispersion_factor=params["dispersion_factor"],
                                     fiber_length=params["fiber_length"],
                                     fiber_length_stddev=params["fiber_length_stddev"],
                                     lpf_cutoff=params["lpf_cutoff"],
                                     rx_stddev=params["rx_stddev"],
                                     sig_avg=params["sig_avg"],
                                     enob=params["enob"])
    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
                                   samples_per_symbol=params["samples_per_symbol"],
                                   messages_per_block=params["messages_per_block"],
                                   channel=optical_channel,
                                   custom_loss_fn=params["custom_loss_fn"])
    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
    return ae_model, params
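

# Hedged usage sketch, not in the original file: reload a saved export and re-measure its error rates.
# The helper name `evaluate_saved_model` is an assumption for illustration; it only composes `load_model`
# and `EndToEndAutoencoder.test`, both defined above. Calling it with no arguments picks the last entry
# listed under ./exports, matching the default behaviour of `load_model`.
def evaluate_saved_model(model_name=None, num_of_blocks=1e4):
    ae_model, params = load_model(model_name)
    ser, ber = ae_model.test(num_of_blocks=num_of_blocks)
    return params, ser, ber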


def run_tests(distance=50):
    # NOTE: `distance` is currently unused; the nominal length comes from params["fiber_length"]
    params = {
        "fs": 336e9,
        "cardinality": 64,
        "samples_per_symbol": 48,
        "messages_per_block": 9,
        "dispersion_factor": (-21.7 * 1e-24),
        "fiber_length": 50,
        "fiber_length_stddev": 1,
        "lpf_cutoff": 32e9,
        "rx_stddev": 0.01,
        "sig_avg": 0.5,
        "enob": 6,
        "custom_loss_fn": True
    }
    force_training = True
    model_save_name = f'{params["fiber_length"]}km-{params["cardinality"]}'  # "50km-64" # "20210401-145416"
    param_file_path = os.path.join("exports", model_save_name, "params.json")
    if os.path.isfile(param_file_path) and not force_training:
        print("Importing model {}".format(model_save_name))
        with open(param_file_path, 'r') as file:
            params = json.load(file)
    optical_channel = OpticalChannel(
        fs=params["fs"],
        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
        dispersion_factor=params["dispersion_factor"],
        fiber_length=params["fiber_length"],
        fiber_length_stddev=params["fiber_length_stddev"],
        lpf_cutoff=params["lpf_cutoff"],
        rx_stddev=params["rx_stddev"],
        sig_avg=params["sig_avg"],
        enob=params["enob"],
    )
    ae_model = EndToEndAutoencoder(
        cardinality=params["cardinality"],
        samples_per_symbol=params["samples_per_symbol"],
        messages_per_block=params["messages_per_block"],
        channel=optical_channel,
        custom_loss_fn=params["custom_loss_fn"],
        alpha=5,
    )
    checkpoint_name = f'/tmp/checkpoint/normal_{params["fiber_length"]}km'
    model_checkpoint_callback0 = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_name,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', min_delta=1e-2, patience=3, verbose=0,
        mode='auto', baseline=None, restore_best_weights=True
    )
    # model_checkpoint_callback1 = tf.keras.callbacks.ModelCheckpoint(
    #     filepath='/tmp/checkpoint/quantised',
    #     save_weights_only=True,
    #     monitor='val_accuracy',
    #     mode='max',
    #     save_best_only=True
    # )
    # if os.path.isfile(param_file_path) and not force_training:
    #     ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
    #     ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
    #     print("Loaded existing model from " + model_save_name)
    # else:
    if not os.path.isfile(checkpoint_name + '.index'):
        history = ae_model.train(num_of_blocks=1e3, epochs=30, callbacks=[model_checkpoint_callback0, early_stop])
        graphs.show_train_history(history, f"Autoencoder training at {params['fiber_length']}km")
        ae_model.save_end_to_end(model_save_name)
    ae_model.load_weights(checkpoint_name)
    ser, ber = ae_model.test(num_of_blocks=3e6)
    data = [(params["fiber_length"], ser, ber)]
    # Sweep the trained model over fiber lengths around the nominal training length
    for l in np.linspace(params["fiber_length"] - 2.5, params["fiber_length"] + 2.5, 6):
        optical_channel = OpticalChannel(
            fs=params["fs"],
            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
            dispersion_factor=params["dispersion_factor"],
            fiber_length=l,
            fiber_length_stddev=params["fiber_length_stddev"],
            lpf_cutoff=params["lpf_cutoff"],
            rx_stddev=params["rx_stddev"],
            sig_avg=params["sig_avg"],
            enob=params["enob"],
        )
        ae_model = EndToEndAutoencoder(
            cardinality=params["cardinality"],
            samples_per_symbol=params["samples_per_symbol"],
            messages_per_block=params["messages_per_block"],
            channel=optical_channel,
            custom_loss_fn=params["custom_loss_fn"],
            alpha=5,
        )
        ae_model.load_weights(checkpoint_name)
        print(f"Testing {l}km")
        ser, ber = ae_model.test(num_of_blocks=3e6)
        data.append((l, ser, ber))
    return data
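

# Hedged sketch, added for illustration and not in the original file: plot the (fiber_length, SER, BER)
# tuples returned by `run_tests` on a log scale, mirroring the length-sweep plot in
# `EndToEndAutoencoder.test`. The helper name `plot_run_tests` is an assumption.
def plot_run_tests(data, plt_show=True):
    data = sorted(data)  # order by fiber length so the line plot reads left to right
    lengths = [d[0] for d in data]
    sers = [d[1] for d in data]
    bers = [d[2] for d in data]
    plt.plot(lengths, sers, 'x-', label='SER')
    plt.plot(lengths, bers, 'o-', label='BER')
    plt.yscale('log')
    plt.xlabel('Fibre length (km)')
    plt.ylabel('Error rate')
    plt.legend()
    if plt_show:
        plt.show()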


if __name__ == '__main__':
    params = {
        "fs": 336e9,
        "cardinality": 64,
        "samples_per_symbol": 48,
        "messages_per_block": 9,
        "dispersion_factor": (-21.7 * 1e-24),
        "fiber_length": 20,
        "fiber_length_stddev": 1,
        "lpf_cutoff": 32e9,
        "rx_stddev": 0.13,
        "sig_avg": 0.5,
        "enob": 6,
        "custom_loss_fn": True
    }
    optical_channel = OpticalChannel(
        fs=params["fs"],
        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
        dispersion_factor=params["dispersion_factor"],
        fiber_length=params["fiber_length"],
        fiber_length_stddev=params["fiber_length_stddev"],
        lpf_cutoff=params["lpf_cutoff"],
        rx_stddev=params["rx_stddev"],
        sig_avg=params["sig_avg"],
        enob=params["enob"],
    )
    print(optical_channel.compute_snr())

# Disabled experiment block: the guard below never matches, so this branch is effectively switched off
if __name__ == 'asd':
    data0 = run_tests(90)
    # data1 = run_tests(70)
    # data2 = run_tests(80)
    # print('Results 60: ', data0)
    # print('Results 70: ', data1)
    print('Results 90: ', data0)
    # ae_model.test(num_of_blocks=3e6)
    # ae_model.load_weights('/tmp/checkpoint/normal')
    #
    # quantize_model = tfmot.quantization.keras.quantize_model
    # ae_model.decoder = quantize_model(ae_model.decoder)
    #
    # # ae_model.load_weights('/tmp/checkpoint/quantised')
    #
    # history = ae_model.train(num_of_blocks=1e3, epochs=20, callbacks=[model_checkpoint_callback1])
    # graphs.show_train_history(history, f"Autoencoder quantised finetune at {params['fiber_length']}km")
    # SYMBOL ERROR RATE: 2.039667e-03
    # 2.358000e-03
    # BIT ERROR RATE: 4.646000e-04
    # 6.916000e-04
    # SYMBOL ERROR RATE: 4.146667e-04
    # BIT ERROR RATE: 1.642667e-04
    # ae_model.save_end_to_end("50km-q3+")
    # ae_model.test(num_of_blocks=3e6)

# Recorded results: fibre length (km), SER, BER
# 50, 2.233333e-05, 5.000000e-06
# 60, 6.556667e-04, 1.343333e-04
# 75, 1.570333e-03, 3.144667e-04
## 80, 8.061667e-03, 1.612333e-03
# 85, 7.811333e-03, 1.601600e-03
# 90, 1.121933e-02, 2.255200e-03
## 90, 1.266433e-02, 2.767467e-03
# 64 cardinality
# 50, 5.488000e-03, 1.089000e-03
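

# Hedged sketch, added for illustration: the SNR figure stored in EndToEndAutoencoder.__init__ follows
# snr_db = 20 * log10(sig_avg / rx_stddev) with sig_avg = 0.5, i.e. signal amplitude over receiver noise
# standard deviation in decibels. `approx_snr_db` is a hypothetical helper restating that one-liner;
# e.g. rx_stddev = 0.01 (the run_tests setting) gives approx_snr_db(0.01) of roughly 33.98 dB.
def approx_snr_db(rx_stddev, sig_avg=0.5):
    return 20 * math.log(sig_avg / rx_stddev, 10)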