"""End-to-end autoencoder trained over a simulated optical fibre channel."""

import math

import keras
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections as matcoll
from sklearn.preprocessing import OneHotEncoder
from keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    """Select the samples of the central message from a flattened block."""

    def __init__(self, messages_per_block, samples_per_symbol):
        """
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super().__init__()

        # Selection matrix: zeros everywhere except an identity sub-block on
        # the rows that hold the central message, so that `inputs @ w` keeps
        # only those samples.
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        selector = np.zeros(
            (messages_per_block * samples_per_symbol, samples_per_symbol)
        )
        selector[begin:end, :] = np.identity(samples_per_symbol)

        self.w = tf.convert_to_tensor(selector, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        """Return ``inputs @ w``, i.e. the central message's samples."""
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    """Channel that adds zero-mean Gaussian noise to its input."""

    def __init__(self, rx_stddev=0.1):
        """
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super().__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        """Add noise to ``inputs``; noise is applied even outside training."""
        # training=True is forced because the noise models the channel itself
        # rather than a regulariser. The sub-layer is invoked via __call__
        # (not .call) so Keras' own layer bookkeeping runs.
        return self.noise_layer(inputs, training=True)


class DigitizationLayer(layers.Layer):
    """Ideal (brick-wall) low-pass filter followed by additive Gaussian noise."""

    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, q_stddev=0.1):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super().__init__()

        self.noise_layer = layers.GaussianNoise(q_stddev)

        # Frequency-domain mask: 1 inside the pass band, 0 for |f| > cutoff.
        freq = np.fft.fftfreq(num_of_samples, d=1/fs)
        mask = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)

        self.lpf_multiplier = tf.convert_to_tensor(mask, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        """Filter ``inputs`` in the frequency domain, then add noise."""
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        # Casting complex -> float keeps only the real part.
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        return self.noise_layer(real_t, training=True)


class OpticalChannel(layers.Layer):
    """DAC -> chromatic dispersion -> square-law detection -> noise -> ADC."""

    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super().__init__()

        self.noise_layer = layers.GaussianNoise(rx_stddev)
        # The same digitization layer instance is used for both DAC and ADC.
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()

        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs),
                                         dtype=tf.complex128)
        # Frequency-domain dispersion term exp(0.5j * beta * L * (2*pi*f)^2).
        self.multiplier = tf.math.exp(
            0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq)
        )

    def call(self, inputs, **kwargs):
        """Propagate ``inputs`` through the simulated optical link."""
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic Dispersion
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)

        # Squared-Law Detection
        pd_out = tf.square(tf.abs(disp_t))

        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)

        # Adding photo-diode receiver noise
        rx_signal = self.noise_layer(real_val, training=True)

        # ADC LPF and noise
        adc_out = self.digitization_layer(rx_signal)

        return adc_out


class EndToEndAutoencoder(tf.keras.Model):
    """Encoder -> channel -> decoder model trained on random message blocks."""

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        :param cardinality: Number of different messages. Chosen such that each
            message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block.
            An even value is incremented by one so the block has a central message.
        :param channel: Channel Layer object. Must be a subclass of
            keras.layers.Layer with an implemented forward pass
        :raises TypeError: If ``channel`` is not a ``keras.layers.Layer``
        """
        super().__init__()

        if not isinstance(channel, layers.Layer):
            raise TypeError("Channel must be a subclass of keras.layers.layer!")

        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        # Channel Model Layer: flatten the block, pass it through the channel
        # and keep only the central message.
        self.channel = tf.keras.Sequential([
            layers.Flatten(),
            channel,
            ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
        ])

        # Encoding Neural Network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.messages_per_block, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            # Clip the transmitted samples to [0, 1].
            layers.ReLU(max_value=1.0)
        ])

        # Decoding Neural Network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        :param num_of_blocks: Number of blocks to generate. A block contains
            multiple messages to be transmitted consecutively to model ISI.
            The central message in a block is returned as the label for training.
        :param return_vals: If true, the raw decimal values of the input
            sequence will be returned
        :return: ``(blocks, labels)`` or, when ``return_vals`` is true,
            ``(values, blocks, labels)``. ``blocks`` has shape
            ``(num_of_blocks, messages_per_block, cardinality)`` and ``labels``
            is the one-hot central message of each block.
        """
        rand_int = np.random.randint(
            self.cardinality, size=(num_of_blocks * self.messages_per_block, 1)
        )

        # One-hot encode by indexing rows of the identity matrix. This avoids
        # OneHotEncoder's `sparse` keyword, which newer scikit-learn releases
        # no longer accept, and yields the same float64 encoding.
        out = np.eye(self.cardinality)[rand_int[:, 0]]
        out_arr = np.reshape(
            out, (num_of_blocks, self.messages_per_block, self.cardinality)
        )

        mid_idx = int((self.messages_per_block-1)/2)
        labels = out_arr[:, mid_idx, :]

        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            return out_val, out_arr, labels

        return out_arr, labels

    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        :param num_of_blocks: Number of blocks to generate for training.
            Analogous to the dataset size.
        :param batch_size: Number of samples to consider on each update
            iteration of the optimization algorithm
        :param train_size: Float less than 1 representing the proportion of
            the dataset to use for training
        :param lr: The learning rate of the optimizer. Defines how quickly the
            algorithm converges
        """
        # Derive the validation count from the remainder so that float
        # rounding in (1 - train_size) cannot drop a block.
        total_blocks = int(num_of_blocks)
        num_train = int(num_of_blocks*train_size)
        num_test = total_blocks - num_train

        X_train, y_train = self.generate_random_inputs(num_train)
        X_test, y_test = self.generate_random_inputs(num_test)

        opt = keras.optimizers.Adam(learning_rate=lr)

        # The decoder ends in a softmax over `cardinality` classes and the
        # labels are one-hot, so categorical cross-entropy is the matching loss.
        self.compile(optimizer=opt,
                     loss=losses.CategoricalCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )

        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test)
                 )

    def view_encoder(self):
        """Plot the encoder output for every message, one subplot per symbol."""
        # Generate inputs for encoder: block k carries message k in its
        # central slot and nothing (all zeros) elsewhere.
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block-1)/2)
        symbols = np.arange(self.cardinality)
        messages[symbols, mid_idx, symbols] = 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = np.asarray(encoded[:, mid_idx, :])

        # Compute subplot grid layout: the column count is the smallest power
        # of two >= sqrt(cardinality); rows are rounded up so no symbol is lost.
        i = 0
        while 2**i < self.cardinality**0.5:
            i += 1
        num_x = int(2**i)
        num_y = int(math.ceil(self.cardinality / num_x))

        # Plot all symbols. squeeze=False keeps `axs` two-dimensional even
        # when the grid collapses to a single row or column.
        fig, axs = plt.subplots(num_y, num_x,
                                figsize=(2.5*num_x, 2*num_y),
                                squeeze=False)

        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t/self.channel.layers[1].fs

        for sym_idx, ax in enumerate(axs.flat):
            if sym_idx >= self.cardinality:
                # Grid cell with no symbol to show.
                ax.set_visible(False)
                continue
            ax.plot(t, enc_messages[sym_idx], 'x')
            ax.set_title('Symbol {}'.format(str(sym_idx)))
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))

        for ax in axs.flat:
            ax.label_outer()

        plt.show()

    def view_sample_block(self):
        """Plot one random encoded block, before and after low-pass filtering."""
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)

        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)

        num_samples = self.messages_per_block*self.samples_per_symbol
        channel_layer = self.channel.layers[1]
        is_optical = isinstance(channel_layer, OpticalChannel)

        # Time axis: seconds for an optical channel, sample index otherwise.
        t = np.arange(num_samples)
        lpf_out = None
        if is_optical:
            t = t / channel_layer.fs
            # Noise-free LPF, only available when the channel exposes a
            # sampling frequency.
            lpf = DigitizationLayer(fs=channel_layer.fs,
                                    num_of_samples=num_samples,
                                    q_stddev=0)
            lpf_out = lpf(flat_enc)

        # Plot the concatenated symbols before and after LPF
        plt.figure(figsize=(2*self.messages_per_block, 6))

        # Vertical lines mark the boundaries between consecutive symbols.
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i*self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        if lpf_out is not None:
            plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def call(self, inputs, training=None, mask=None):
        """Encode ``inputs``, pass them through the channel and decode."""
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


if __name__ == '__main__':

    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50

    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK*SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)

    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel)

    ae_model.train(num_of_blocks=1e6, batch_size=100)
    ae_model.view_encoder()
    ae_model.view_sample_block()