import json
import math
import os
from datetime import datetime as dt

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses

from models.data import BinaryTimeDistributedOneHotGenerator
from models.layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
import tensorflow_model_optimization as tfmot

import graphs


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self, cardinality, samples_per_symbol, messages_per_block, channel, custom_loss_fn=False,
                 quantize=False, alpha=1):
        """
        The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block"
        consists of multiple "messages" to introduce memory into the simulation, as this is essential for
        modelling inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.

        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param custom_loss_fn: If True, train with the combined symbol/bit cross-entropy loss defined in ``cost``
        :param quantize: Reserved for quantization-aware training (currently unused)
        :param alpha: Weighting of the bit-wise cross-entropy term in the loss function
        """
        super(EndToEndAutoencoder, self).__init__()

        # Labelled M in paper
        self.cardinality = cardinality
        self.bits_per_symbol = int(math.log(self.cardinality, 2))
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        self.alpha = alpha

        # Labelled N in paper - conditional += 1 to ensure an odd value
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of \"tensorflow.keras.layers.Layer\"!")

        # Boolean identifying if the bit mapping is to be learnt (via the custom loss function)
        self.custom_loss_fn = custom_loss_fn

        # Other parameters/metrics
        self.symbol_error_rate = None
        self.bit_error_rate = None
        self.snr = 20 * math.log(0.5 / channel.rx_stddev, 10)

        # Model hyper-parameters
        leaky_relu_alpha = 0
        relu_clip_val = 1.0

        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
        ], name="encoding_model")

        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            layers.Dense(2 * self.cardinality),
            layers.ReLU(),
            layers.Dense(2 * self.cardinality),
            layers.ReLU(),
            layers.Dense(self.cardinality, activation='softmax')
        ], name="decoding_model")
        self.decoder.build((1, self.samples_per_symbol))
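    # Shape walk-through (illustrative; M = cardinality, N = messages_per_block, n = samples_per_symbol):
    #   encoder : (batch, N, M) -> (batch, N, n)   one waveform per one-hot message
    #   channel : (batch, N, n) -> (batch, n)      flatten, propagate, extract the central message
    #   decoder : (batch, n)    -> (batch, M)      softmax over the M candidate messages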
    def save_end_to_end(self, name):
        # extract all params and save
        params = {"fs": self.channel.layers[1].fs,
                  "cardinality": self.cardinality,
                  "samples_per_symbol": self.samples_per_symbol,
                  "messages_per_block": self.messages_per_block,
                  "dispersion_factor": self.channel.layers[1].dispersion_factor,
                  "fiber_length": float(self.channel.layers[1].fiber_length),
                  "fiber_length_stddev": float(self.channel.layers[1].fiber_length_stddev),
                  "lpf_cutoff": self.channel.layers[1].lpf_cutoff,
                  "rx_stddev": self.channel.layers[1].rx_stddev,
                  "sig_avg": self.channel.layers[1].sig_avg,
                  "enob": self.channel.layers[1].enob,
                  "custom_loss_fn": self.custom_loss_fn
                  }
        if not name:
            name = dt.utcnow().strftime("%Y%m%d-%H%M%S")
        dir_str = os.path.join("exports", name)
        if not os.path.exists(dir_str):
            os.makedirs(dir_str)
        with open(os.path.join(dir_str, 'params.json'), 'w') as outfile:
            json.dump(params, outfile)

        ################################################################################################################
        # This section exports the weights of the encoder/decoder formatted using python variable instantiation syntax
        ################################################################################################################
        enc_weights, dec_weights = self.extract_weights()
        enc_weights = [x.tolist() for x in enc_weights]
        dec_weights = [x.tolist() for x in dec_weights]
        enc_w = enc_weights[::2]
        enc_b = enc_weights[1::2]
        dec_w = dec_weights[::2]
        dec_b = dec_weights[1::2]
        with open(os.path.join(dir_str, 'enc_weights.py'), 'w') as outfile:
            outfile.write("enc_weights = ")
            outfile.write(str(enc_w))
            outfile.write("\n\nenc_bias = ")
            outfile.write(str(enc_b))
        with open(os.path.join(dir_str, 'dec_weights.py'), 'w') as outfile:
            outfile.write("dec_weights = ")
            outfile.write(str(dec_w))
            outfile.write("\n\ndec_bias = ")
            outfile.write(str(dec_b))
        ################################################################################################################

        self.encoder.save(os.path.join(dir_str, 'encoder'))
        self.decoder.save(os.path.join(dir_str, 'decoder'))

    def extract_weights(self):
        enc_weights = self.encoder.get_weights()
        dec_weights = self.decoder.get_weights()
        return enc_weights, dec_weights

    def encode_stream(self, x):
        # Re-implements the encoder forward pass in NumPy using the extracted weights
        enc_weights, dec_weights = self.extract_weights()
        for i in range(len(enc_weights) // 2):
            x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
            if i == len(enc_weights) // 2 - 1:
                x = tf.keras.activations.sigmoid(x).numpy()
            else:
                x = tf.keras.activations.relu(x).numpy()
        return x

    def cost(self, y_true, y_pred):
        # Combined loss: categorical cross-entropy on symbols plus alpha-weighted binary cross-entropy on bits
        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
        return symbol_cost + self.alpha * bit_cost
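    # Illustrative use of the combined loss (assumes one-hot ``y_true`` and softmax ``y_pred`` of shape
    # (batch, cardinality); ``x_blocks`` is a placeholder name for a batch of one-hot message blocks):
    #
    #   y_true = tf.one_hot([3, 7], depth=model.cardinality)
    #   y_pred = model(x_blocks)               # forward pass through encoder, channel and decoder
    #   loss = model.cost(y_true, y_pred)      # symbol cross-entropy + alpha * bit-wise cross-entropy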
    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is utilized for generating the
        test/train data.

        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted
                              consecutively to model ISI. The central message in a block is returned as the label
                              for training.
        :param return_vals: If true, the raw decimal values of the input sequence will be returned
        """
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3, **kwargs):
        """
        Method to train the autoencoder. Further configuration of the loss function, optimizer etc. can be made
        in here.

        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param epochs: Number of passes over the training data
        :param batch_size: Number of samples to consider on each update iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the algorithm converges
        """
        # X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        train_data = BinaryTimeDistributedOneHotGenerator(
            num_of_blocks, cardinality=self.cardinality, blocks=self.messages_per_block)
        test_data = BinaryTimeDistributedOneHotGenerator(
            num_of_blocks * .3, cardinality=self.cardinality, blocks=self.messages_per_block)

        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        if self.custom_loss_fn:
            loss_fn = self.cost
        else:
            loss_fn = losses.CategoricalCrossentropy()

        self.compile(optimizer=opt,
                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )

        return self.fit(
            train_data,
            epochs=epochs,
            shuffle=True,
            validation_data=test_data,
            **kwargs
        )
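    # Illustrative training call (the keyword values below are placeholders, not recommended settings):
    #
    #   history = model.train(num_of_blocks=1e5, epochs=10, lr=5e-4)
    #   graphs.show_train_history(history, "Autoencoder training")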
    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True, distance=None):
        # X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
        test_data = BinaryTimeDistributedOneHotGenerator(
            1000, cardinality=self.cardinality, blocks=self.messages_per_block)
        # Test in chunks of 1000 blocks to keep memory usage bounded
        num_of_blocks = int(num_of_blocks / 1000)
        if num_of_blocks <= 0:
            num_of_blocks = 1

        ber = []
        ser = []
        for i in range(num_of_blocks):
            y_out = self.call(test_data.x)
            y_pred = tf.argmax(y_out, axis=1)
            y_true = tf.argmax(test_data.y, axis=1)
            ser.append(1 - accuracy_score(y_true, y_pred))
            bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
            bits_true = SymbolsToBits(self.cardinality)(test_data.y).numpy().flatten()
            ber.append(1 - accuracy_score(bits_true, bits_pred))
            test_data.on_epoch_end()
            print(f"\rTested {i + 1} of {num_of_blocks} blocks", end="")
        print(f"\rTested all {num_of_blocks} blocks")

        self.symbol_error_rate = sum(ser) / len(ser)
        self.bit_error_rate = sum(ber) / len(ber)

        if length_plot:
            lengths = np.linspace(0, 70, 50)
            ber_l = []
            for l in lengths:
                tx_channel = OpticalChannel(fs=self.channel.layers[1].fs,
                                            num_of_samples=self.channel.layers[1].num_of_samples,
                                            dispersion_factor=self.channel.layers[1].dispersion_factor,
                                            fiber_length=l,
                                            lpf_cutoff=self.channel.layers[1].lpf_cutoff,
                                            rx_stddev=self.channel.layers[1].rx_stddev,
                                            sig_avg=self.channel.layers[1].sig_avg,
                                            enob=self.channel.layers[1].enob)
                test_channel = tf.keras.Sequential([
                    layers.Flatten(),
                    tx_channel,
                    ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
                ], name="test_channel_variable_length")
                X_test_l, y_test_l = self.generate_random_inputs(int(num_of_blocks))
                encoded = self.encoder(X_test_l)
                after_ch = test_channel(encoded)
                y_out_l = self.decoder(after_ch)
                y_pred_l = tf.argmax(y_out_l, axis=1)
                # y_true_l = tf.argmax(y_test_l, axis=1)
                bits_pred_l = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred_l, self.cardinality)).numpy().flatten()
                bits_true_l = SymbolsToBits(self.cardinality)(y_test_l).numpy().flatten()
                bit_error_rate_l = 1 - accuracy_score(bits_true_l, bits_pred_l)
                ber_l.append(bit_error_rate_l)
            plt.plot(lengths, ber_l)
            plt.yscale('log')
            if plt_show:
                plt.show()

        print("SYMBOL ERROR RATE: {:e}".format(self.symbol_error_rate))
        print("BIT ERROR RATE: {:e}".format(self.bit_error_rate))
        return self.symbol_error_rate, self.bit_error_rate

    def view_encoder(self):
        """
        A method that plots the learnt encoding for each distinct message. This is displayed as a figure with a
        subplot for each message/symbol.
        """
        mid_idx = int((self.messages_per_block - 1) / 2)

        # Generate inputs for encoder
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]

        # Compute subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)

        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()
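    # Illustrative diagnostics after training (both methods only show plots and return nothing):
    #
    #   model.view_encoder()       # one subplot per learnt symbol waveform
    #   model.view_sample_block()  # a random block before/after low-pass filtering and the channel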
    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes them. The encoded block is also passed through a
        DigitizationLayer with no quantization noise so that the effect of the low-pass filtering alone can be
        seen alongside the full channel output.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)

        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        chan_out = self.channel.layers[1](flat_enc)

        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                sig_avg=0)
        # Apply LPF
        lpf_out = lpf(flat_enc)

        # Plot the magnitude spectrum of the filtered block
        a = np.abs(np.fft.fft(lpf_out.numpy())).flatten()
        f = np.fft.fftfreq(a.shape[-1]).flatten()
        plt.plot(f, a)
        plt.show()

        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


def load_model(model_name=None):
    if model_name is None:
        models = os.listdir("exports")
        if not models:
            raise Exception("Unable to find a trained model. Please first train and save a model.")
        model_name = models[-1]

    param_file_path = os.path.join("exports", model_name, "params.json")
    if not os.path.isfile(param_file_path):
        raise Exception("Invalid File Name/Directory")
    else:
        with open(param_file_path, 'r') as param_file:
            params = json.load(param_file)

    optical_channel = OpticalChannel(fs=params["fs"],
                                     num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
                                     dispersion_factor=params["dispersion_factor"],
                                     fiber_length=params["fiber_length"],
                                     fiber_length_stddev=params["fiber_length_stddev"],
                                     lpf_cutoff=params["lpf_cutoff"],
                                     rx_stddev=params["rx_stddev"],
                                     sig_avg=params["sig_avg"],
                                     enob=params["enob"])
    ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
                                   samples_per_symbol=params["samples_per_symbol"],
                                   messages_per_block=params["messages_per_block"],
                                   channel=optical_channel,
                                   custom_loss_fn=params["custom_loss_fn"])
    ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_name, "encoder"))
    ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_name, "decoder"))
    return ae_model, params
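# Illustrative reload of a previously exported model ("50km-64" stands in for whatever name was passed
# to save_end_to_end):
#
#   ae_model, params = load_model("50km-64")
#   ser, ber = ae_model.test(num_of_blocks=1e5)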
def run_tests(distance=50):
    params = {
        "fs": 336e9,
        "cardinality": 64,
        "samples_per_symbol": 48,
        "messages_per_block": 9,
        "dispersion_factor": (-21.7 * 1e-24),
        "fiber_length": 50,
        "fiber_length_stddev": 1,
        "lpf_cutoff": 32e9,
        "rx_stddev": 0.01,
        "sig_avg": 0.5,
        "enob": 6,
        "custom_loss_fn": True
    }
    force_training = True
    model_save_name = f'{params["fiber_length"]}km-{params["cardinality"]}'  # "50km-64"  # "20210401-145416"
    param_file_path = os.path.join("exports", model_save_name, "params.json")
    if os.path.isfile(param_file_path) and not force_training:
        print("Importing model {}".format(model_save_name))
        with open(param_file_path, 'r') as file:
            params = json.load(file)

    optical_channel = OpticalChannel(
        fs=params["fs"],
        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
        dispersion_factor=params["dispersion_factor"],
        fiber_length=params["fiber_length"],
        fiber_length_stddev=params["fiber_length_stddev"],
        lpf_cutoff=params["lpf_cutoff"],
        rx_stddev=params["rx_stddev"],
        sig_avg=params["sig_avg"],
        enob=params["enob"],
    )
    ae_model = EndToEndAutoencoder(
        cardinality=params["cardinality"],
        samples_per_symbol=params["samples_per_symbol"],
        messages_per_block=params["messages_per_block"],
        channel=optical_channel,
        custom_loss_fn=params["custom_loss_fn"],
        alpha=5,
    )

    checkpoint_name = f'/tmp/checkpoint/normal_{params["fiber_length"]}km'
    model_checkpoint_callback0 = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_name,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=1e-2,
        patience=3,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=True
    )
    # model_checkpoint_callback1 = tf.keras.callbacks.ModelCheckpoint(
    #     filepath='/tmp/checkpoint/quantised',
    #     save_weights_only=True,
    #     monitor='val_accuracy',
    #     mode='max',
    #     save_best_only=True
    # )

    # if os.path.isfile(param_file_path) and not force_training:
    #     ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
    #     ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
    #     print("Loaded existing model from " + model_save_name)
    # else:
    if not os.path.isfile(checkpoint_name + '.index'):
        history = ae_model.train(num_of_blocks=1e3, epochs=30, callbacks=[model_checkpoint_callback0, early_stop])
        graphs.show_train_history(history, f"Autoencoder training at {params['fiber_length']}km")
        ae_model.save_end_to_end(model_save_name)
    ae_model.load_weights(checkpoint_name)

    ser, ber = ae_model.test(num_of_blocks=3e6)
    data = [(params["fiber_length"], ser, ber)]

    # Re-test the trained weights on channels with perturbed fibre lengths
    for l in np.linspace(params["fiber_length"] - 2.5, params["fiber_length"] + 2.5, 6):
        optical_channel = OpticalChannel(
            fs=params["fs"],
            num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
            dispersion_factor=params["dispersion_factor"],
            fiber_length=l,
            fiber_length_stddev=params["fiber_length_stddev"],
            lpf_cutoff=params["lpf_cutoff"],
            rx_stddev=params["rx_stddev"],
            sig_avg=params["sig_avg"],
            enob=params["enob"],
        )
        ae_model = EndToEndAutoencoder(
            cardinality=params["cardinality"],
            samples_per_symbol=params["samples_per_symbol"],
            messages_per_block=params["messages_per_block"],
            channel=optical_channel,
            custom_loss_fn=params["custom_loss_fn"],
            alpha=5,
        )
        ae_model.load_weights(checkpoint_name)
        print(f"Testing {l}km")
        ser, ber = ae_model.test(num_of_blocks=3e6)
        data.append((l, ser, ber))
    return data
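# Illustrative use of run_tests (the returned list holds (fibre_length, SER, BER) tuples for the nominal
# length and the +-2.5 km sweep around it):
#
#   results = run_tests()
#   for length, ser, ber in results:
#       print(f"{length:.1f} km  SER={ser:.3e}  BER={ber:.3e}")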
if __name__ == '__main__':
    params = {
        "fs": 336e9,
        "cardinality": 64,
        "samples_per_symbol": 48,
        "messages_per_block": 9,
        "dispersion_factor": (-21.7 * 1e-24),
        "fiber_length": 20,
        "fiber_length_stddev": 1,
        "lpf_cutoff": 32e9,
        "rx_stddev": 0.13,
        "sig_avg": 0.5,
        "enob": 6,
        "custom_loss_fn": True
    }
    optical_channel = OpticalChannel(
        fs=params["fs"],
        num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
        dispersion_factor=params["dispersion_factor"],
        fiber_length=params["fiber_length"],
        fiber_length_stddev=params["fiber_length_stddev"],
        lpf_cutoff=params["lpf_cutoff"],
        rx_stddev=params["rx_stddev"],
        sig_avg=params["sig_avg"],
        enob=params["enob"],
    )
    print(optical_channel.compute_snr())

if __name__ == 'asd':
    data0 = run_tests(90)
    # data1 = run_tests(70)
    # data2 = run_tests(80)
    # print('Results 60: ', data0)
    # print('Results 70: ', data1)
    print('Results 90: ', data0)

    # ae_model.test(num_of_blocks=3e6)
    # ae_model.load_weights('/tmp/checkpoint/normal')
    #
    # quantize_model = tfmot.quantization.keras.quantize_model
    # ae_model.decoder = quantize_model(ae_model.decoder)
    #
    # # ae_model.load_weights('/tmp/checkpoint/quantised')
    #
    # history = ae_model.train(num_of_blocks=1e3, epochs=20, callbacks=[model_checkpoint_callback1])
    # graphs.show_train_history(history, f"Autoencoder quantised finetune at {params['fiber_length']}km")

    # SYMBOL ERROR RATE: 2.039667e-03  # 2.358000e-03
    # BIT ERROR RATE: 4.646000e-04  # 6.916000e-04

    # SYMBOL ERROR RATE: 4.146667e-04
    # BIT ERROR RATE: 1.642667e-04
    # ae_model.save_end_to_end("50km-q3+")
    # ae_model.test(num_of_blocks=3e6)

    # Fibre, SER, BER
    # 50, 2.233333e-05, 5.000000e-06
    # 60, 6.556667e-04, 1.343333e-04
    # 75, 1.570333e-03, 3.144667e-04
    ## 80, 8.061667e-03, 1.612333e-03
    # 85, 7.811333e-03, 1.601600e-03
    # 90, 1.121933e-02, 2.255200e-03
    ## 90, 1.266433e-02, 2.767467e-03

    # 64 cardinality
    # 50, 5.488000e-03, 1.089000e-03
    pass
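# Illustrative end-to-end workflow, using the same parameter values as run_tests (not executed here; the
# OpticalChannel keyword arguments are assumed to match the implementation in models.layers):
#
#   channel = OpticalChannel(fs=336e9, num_of_samples=9 * 48, dispersion_factor=-21.7e-24,
#                            fiber_length=50, fiber_length_stddev=1, lpf_cutoff=32e9,
#                            rx_stddev=0.01, sig_avg=0.5, enob=6)
#   model = EndToEndAutoencoder(cardinality=64, samples_per_symbol=48, messages_per_block=9,
#                               channel=channel, custom_loss_fn=True, alpha=5)
#   model.train(num_of_blocks=1e3, epochs=30)
#   model.view_encoder()
#   model.save_end_to_end("50km-64")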