import math
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.

        :param messages_per_block: Total number of messages in a transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix that keeps only the samples of the central message
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white Gaussian noise channel model. The GaussianNoise layer is
        used with training=True to prevent identical noise from being applied every
        time the call function is invoked.

        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        return self.noise_layer(inputs, training=True)


class DigitizationLayer(layers.Layer):
    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, q_stddev=0.1):
        """
        This layer simulates the finite bandwidth of the hardware by means of a
        low-pass filter. In addition to this, artefacts caused by quantization are
        modelled by the addition of white Gaussian noise of a given stddev.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of the LPF modelling finite bandwidth in the ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at the ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Brick-wall low-pass filter mask in the frequency domain
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.ones(freq.shape)
        for idx, val in np.ndenumerate(freq):
            if np.abs(val) > lpf_cutoff:
                temp[idx] = 0
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        noisy = self.noise_layer(real_t, training=True)
        return noisy
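
# Illustrative sketch (not part of the original model; the function name, sample
# count and tone frequencies are assumptions chosen for the example): with
# q_stddev=0 the DigitizationLayer reduces to its brick-wall LPF, so a tone on an
# FFT bin below the 32 GHz default cutoff passes unchanged while a tone on a bin
# above the cutoff is removed entirely.
def _demo_digitization_lpf(fs=336e9, num_of_samples=64):
    t = np.arange(num_of_samples) / fs
    bin_spacing = fs / num_of_samples  # 5.25 GHz for the defaults above
    low_tone = np.sin(2 * np.pi * 2 * bin_spacing * t)[np.newaxis, :].astype(np.float32)
    high_tone = np.sin(2 * np.pi * 20 * bin_spacing * t)[np.newaxis, :].astype(np.float32)
    lpf = DigitizationLayer(fs=fs, num_of_samples=num_of_samples, q_stddev=0)
    # The 10.5 GHz tone survives the filter; the 105 GHz tone is zeroed out.
    return lpf(low_tone), lpf(high_tone)
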
class OpticalChannel(layers.Layer):
    def __init__(self, fs, num_of_samples, dispersion_factor, fiber_length,
                 lpf_cutoff=32e9, rx_stddev=0.01, q_stddev=0.01):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode
        detection, the finite bandwidth of the ADC/DAC, as well as additive white
        Gaussian noise in optical communication channels.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of the LPF modelling finite bandwidth in the ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at the ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs),
                                         dtype=tf.complex128)
        # Frequency-domain chromatic dispersion operator exp(j*beta2/2*L*(2*pi*f)^2)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length
                                      * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law detection
        pd_out = tf.square(tf.abs(disp_t))
        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Adding photodiode receiver noise
        rx_signal = self.noise_layer(real_val, training=True)
        # ADC LPF and noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
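
# Illustrative sketch (an assumed sanity check, not used anywhere in this module;
# the sample count is arbitrary, the dispersion values mirror the main script):
# the dispersion operator exp(j*0.5*beta2*L*(2*pi*f)^2) only rotates the phase of
# each frequency bin, so its magnitude response is identically one. The amplitude
# distortion seen at the receiver therefore arises from square-law detection, not
# from the dispersion stage itself.
def _demo_dispersion_is_all_pass():
    channel = OpticalChannel(fs=336e9, num_of_samples=64,
                             dispersion_factor=-21.7e-24, fiber_length=50)
    return np.allclose(np.abs(channel.multiplier.numpy()), 1.0)  # expected: True
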
class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self, cardinality, samples_per_symbol, messages_per_block, channel):
        """
        The autoencoder that aims to find an encoding of the input messages. It
        should be noted that a "block" consists of multiple "messages" in order to
        introduce memory into the simulation, as this is essential for modelling
        inter-symbol interference. The autoencoder architecture was heavily
        influenced by IEEE 8433895.

        :param cardinality: Number of different messages. Chosen such that each
            message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of
            keras.layers.Layer with an implemented forward pass
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper
        self.cardinality = cardinality
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper. Forced to be odd so that a central message exists.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ])
        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates a list of one-hot encoded messages. This is
        utilized for generating the test/train data.

        :param num_of_blocks: Number of blocks to generate. A block contains
            multiple messages to be transmitted consecutively to model ISI. The
            central message in a block is returned as the label for training.
        :param return_vals: If true, the raw decimal values of the input sequence
            will also be returned
        """
        rand_int = np.random.randint(self.cardinality,
                                     size=(num_of_blocks * self.messages_per_block, 1))
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration of the loss
        function, optimizer etc. can be made in here.

        :param num_of_blocks: Number of blocks to generate for training. Analogous
            to the dataset size.
        :param batch_size: Number of samples to consider on each update iteration
            of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of the
            dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the
            algorithm converges
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False)
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test))

    def view_encoder(self):
        """
        A method that plots the learnt encoding of each distinct message. This is
        displayed as a figure with a subplot for each message.
        """
        # Generate inputs for the encoder: one block per message, with only the
        # central position populated
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        idx = 0
        for msg in messages:
            msg[mid_idx, idx] = 1
            idx += 1
        # Pass the input through the encoder and select the middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute the subplot grid layout
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generates a random block of input messages and encodes them. In addition
        to this, the encoder output is passed through a DigitizationLayer without
        any quantization noise to show the effect of the low-pass filtering.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        # Instantiate the LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply the LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
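
# Illustrative sketch (assumed usage only, not called by the training script below;
# the small hyperparameters and the choice of AwgnChannel are arbitrary): wires a
# tiny autoencoder around the simple AWGN channel and checks the shape flow
# encoder -> channel -> decoder, i.e. that a batch of one-hot blocks is mapped to
# one probability vector per block.
def _demo_forward_pass_shapes():
    model = EndToEndAutoencoder(cardinality=4, samples_per_symbol=8,
                                messages_per_block=3, channel=AwgnChannel())
    blocks, labels = model.generate_random_inputs(num_of_blocks=5)
    probs = model(blocks)
    return probs.shape  # expected: (5, 4), matching labels.shape
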
if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50

    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel)
    ae_model.train(num_of_blocks=1e6, batch_size=100)
    ae_model.view_encoder()
    ae_model.view_sample_block()
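    # Hedged addition (not in the original script; the 10,000-block sample size is
    # arbitrary): a rough symbol error rate estimate on freshly generated blocks,
    # using only the methods defined above.
    eval_blocks, eval_labels = ae_model.generate_random_inputs(num_of_blocks=10000)
    eval_probs = ae_model(eval_blocks)
    symbol_errors = np.argmax(eval_probs.numpy(), axis=1) != np.argmax(eval_labels, axis=1)
    print('Estimated symbol error rate: {:.4f}'.format(symbol_errors.mean()))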