import math

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import layers, losses

from models.custom_layers import ExtractCentralMessage, OpticalChannel
from models.custom_layers import BitsToSymbols, SymbolsToBits
from models.end_to_end import EndToEndAutoencoder


class BitMappingModel(tf.keras.Model):
    """Train an ``EndToEndAutoencoder`` on bits instead of symbols.

    Input bits are mapped to symbols by ``BitsToSymbols``, passed through the
    symbol-level end-to-end model, and mapped back to bits by
    ``SymbolsToBits``. Training uses binary cross-entropy on the bits.
    """

    def __init__(self, cardinality, samples_per_symbol, messages_per_block, channel):
        """Build the wrapped end-to-end model and the bit/symbol mapping layers.

        Args:
            cardinality: number of distinct symbols (labelled M in the paper).
            samples_per_symbol: samples per symbol (labelled n in the paper).
            messages_per_block: messages per block (labelled N in the paper);
                an even value is bumped up by one so the block is odd-sized.
            channel: channel layer forwarded to ``EndToEndAutoencoder``.
        """
        super(BitMappingModel, self).__init__()

        # Labelled M in paper
        self.cardinality = cardinality
        # math.log2 is exact for powers of two, unlike math.log(x, 2), whose
        # float result can fall just below the integer and be truncated by int().
        self.bits_per_symbol = int(math.log2(self.cardinality))

        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol

        # Labelled N in paper. Forced odd so the block has a single central
        # message (see generate_random_inputs).
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
                                             samples_per_symbol=self.samples_per_symbol,
                                             messages_per_block=self.messages_per_block,
                                             channel=channel,
                                             bit_mapping=False)

        # Create the mapping layers once: building fresh layer objects inside
        # ``call`` re-creates them on every call / graph trace.
        self.bits_to_symbols = BitsToSymbols(self.cardinality)
        self.symbols_to_bits = SymbolsToBits(self.cardinality)

        # Histories appended by trainIterative (one entry per e2e test run).
        self.bit_error_rate = []
        self.symbol_error_rate = []

    def call(self, inputs, training=None, mask=None):
        """Map bits -> symbols, run the end-to-end model, map symbols -> bits."""
        symbols = self.bits_to_symbols(inputs)
        decoded = self.e2e_model(symbols)
        return self.symbols_to_bits(decoded)

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """Generate blocks of uniformly random bits.

        Args:
            num_of_blocks: number of blocks to generate (must be an int).
            return_vals: if True, return the block array twice before the
                central-message bits.

        Returns:
            ``(blocks, central_bits)``, or ``(blocks, blocks, central_bits)``
            when ``return_vals`` is True. ``blocks`` has shape
            ``(num_of_blocks, messages_per_block, bits_per_symbol)`` with 0/1
            entries and ``central_bits`` is ``blocks[:, mid_idx, :]``.
        """
        mid_idx = (self.messages_per_block - 1) // 2
        blocks = np.random.randint(
            2, size=(num_of_blocks, self.messages_per_block, self.bits_per_symbol))

        if return_vals:
            return blocks, blocks, blocks[:, mid_idx, :]
        return blocks, blocks[:, mid_idx, :]

    def _split_datasets(self, num_of_blocks, train_size):
        """Generate train/test data whose block counts sum to ``num_of_blocks``."""
        total = int(num_of_blocks)
        n_train = int(total * train_size)
        # Derive the test count from the train count: computing it as
        # int(total * (1 - train_size)) loses a block to float rounding.
        n_test = total - n_train

        X_train, y_train = self.generate_random_inputs(n_train)
        X_test, y_test = self.generate_random_inputs(n_test)

        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
        return X_train, y_train, X_test, y_test

    def _fit_on_random_bits(self, epochs, num_of_blocks, batch_size, train_size, lr):
        """Compile with a fresh Adam optimizer and fit on freshly generated bits."""
        X_train, y_train, X_test, y_test = self._split_datasets(num_of_blocks, train_size)

        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )

        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=epochs,
                 shuffle=True,
                 validation_data=(X_test, y_test)
                 )

    def _record_error_rates(self):
        """Run the e2e model's test() and append the error rates it exposes."""
        self.e2e_model.test()
        self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
        self.bit_error_rate.append(self.e2e_model.bit_error_rate)

    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Train the full bit-level model on random data.

        Args:
            epochs: number of training epochs.
            num_of_blocks: total number of blocks, split into train/test.
            batch_size: batch size forwarded to ``fit``.
            train_size: fraction of blocks used for training.
            lr: Adam learning rate.
        """
        self._fit_on_random_bits(epochs, num_of_blocks, batch_size, train_size, lr)

    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Alternate symbol-level and bit-level training, logging error rates.

        Each iteration trains the wrapped e2e model with its own ``train``, tests
        it, then trains this bit-level model. After the loop, the e2e model is
        tested once more. Error rates are appended after every e2e test.
        """
        for _ in range(iters):
            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
            self._record_error_rates()

            # A new optimizer is created on each iteration, so Adam state is
            # not carried across iterations.
            self._fit_on_random_bits(epochs, num_of_blocks, batch_size, train_size, lr)

        self._record_error_rates()


SAMPLING_FREQUENCY = 336e9
CARDINALITY = 32
SAMPLES_PER_SYMBOL = 32
MESSAGES_PER_BLOCK = 9
DISPERSION_FACTOR = -21.7 * 1e-24
FIBER_LENGTH = 50

if __name__ == '__main__':

    distances = [0, 10, 20, 30, 40, 50, 60]
    ser = []
    ber = []

    # Symbol rate in GBd and bit rate in Gbps.
    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
    bit_rate = math.log2(CARDINALITY) * baud_rate
    snr = None

    for d in distances:

        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                         dispersion_factor=DISPERSION_FACTOR,
                                         fiber_length=d)

        model = BitMappingModel(cardinality=CARDINALITY,
                                samples_per_symbol=SAMPLES_PER_SYMBOL,
                                messages_per_block=MESSAGES_PER_BLOCK,
                                channel=optical_channel)

        # The plot titles assume a single SNR for every distance; warn if the
        # e2e model reports a different one.
        if snr is None:
            snr = model.e2e_model.snr
        elif snr != model.e2e_model.snr:
            print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")

        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)

        ber.append(model.bit_error_rate[-1])
        ser.append(model.symbol_error_rate[-1])

        plt.plot(model.bit_error_rate, label='BER')
        plt.plot(model.symbol_error_rate, label='SER')
        plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
        plt.legend()
        plt.show()

    # Plot against the actual fiber lengths, use the label= kwargs for the
    # legend (previously the BER values were passed as legend labels), and
    # display the figure.
    plt.plot(distances, ber, label='BER')
    plt.plot(distances, ser, label='SER')
    plt.xlabel('Fiber length (km)')
    plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
    plt.legend()
    plt.show()