import math
import os

# Hide every GPU from TensorFlow. This is set before tensorflow is imported so
# that it is guaranteed to be in effect before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses
from scipy.signal import bessel, lfilter, filtfilt


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A keras layer that extracts the central message (symbol) in a block.

        The extraction is a matrix product with a fixed selection matrix: an
        identity block at the rows of the central symbol, zeros everywhere else.

        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()

        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)

        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i

        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        # Expects flattened inputs whose last dimension is
        # messages_per_block * samples_per_symbol (the row count of self.w).
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
        being applied every time the call function is called.

        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True forces noise to be added at inference time as well.
        return self.noise_layer.call(inputs, training=True)


class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 q_stddev=0.1):
        """
        This layer simulates the finite bandwidth of the hardware by means of an ideal (brick-wall) low pass filter
        applied in the frequency domain. In addition, artefacts caused by quantization are modelled by the addition
        of white gaussian noise of a given stddev.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(DigitizationLayer, self).__init__()

        self.noise_layer = layers.GaussianNoise(q_stddev)

        # FFT bin frequencies; bins with |f| above the cutoff are zeroed.
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        passband = (np.abs(freq) <= lpf_cutoff).astype(np.float64)

        self.lpf_multiplier = tf.convert_to_tensor(passband, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        # Casting complex -> float keeps only the real part.
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        noisy = self.noise_layer.call(real_t, training=True)
        return noisy


class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.03,
                 q_stddev=0.01):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
        ADC/DAC as well as additive white gaussian noise in optical communication channels.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super(OpticalChannel, self).__init__()

        self.noise_layer = layers.GaussianNoise(rx_stddev)
        # A single digitization layer instance is used for both the DAC and the
        # ADC stage in call(), so both share the same cutoff and noise level.
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()

        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        # Frequency-domain dispersion transfer function exp(j/2 * D * L * w^2).
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic Dispersion
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)

        # Squared-Law Detection
        pd_out = tf.square(tf.abs(disp_t))

        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)

        # Adding photo-diode receiver noise
        rx_signal = self.noise_layer.call(real_val, training=True)

        # ADC LPF and noise
        adc_out = self.digitization_layer(rx_signal)

        return adc_out


class ModulationModel(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        The autoencoder that aims to find an encoding of the input messages. It should be noted that a "block"
        consists of multiple "messages" to introduce memory into the simulation as this is essential for modelling
        inter-symbol interference. The autoencoder architecture was heavily influenced by IEEE 8433895.

        NOTE(review): call(), view_encoder() and view_sample_block() use self.encoder (and call() uses
        self.decoder), but neither is created in this class. Presumably they are assigned by a subclass or after
        construction; confirm before calling those methods.

        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality)
                            bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block. An even value is increased by
                                   one so that the block has a unique central message.
        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward
                        pass
        :raises TypeError: If channel is not a keras Layer
        """
        super(ModulationModel, self).__init__()

        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper; forced odd so a central message exists
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        # Channel Model Layer: flatten -> channel -> keep the central symbol.
        # Other methods rely on the channel sitting at index 1 of this Sequential.
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ])
        else:
            raise TypeError("Channel must be a subclass of keras.layers.layer!")

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        A method that generates blocks of uniformly random messages. This is utilized for generating the
        test/train data.

        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages transmitted
                              consecutively to model ISI. The central message in a block is the training label.
        :param return_vals: If true, the raw decimal values of the input sequence will be returned
        :return: If return_vals is False: (samples, messages). samples has shape
                 (num_of_blocks, messages_per_block * samples_per_symbol) and holds every message value repeated
                 samples_per_symbol times; messages has shape (num_of_blocks, messages_per_block).
                 If return_vals is True: (values, messages, labels). values has shape
                 (num_of_blocks, messages_per_block, 1) and labels holds the central message of every block,
                 shape (num_of_blocks,).
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))

        out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
        t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)

        mid_idx = int((self.messages_per_block - 1) / 2)

        if return_vals:
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            # out_arr is 2-D, so the label takes exactly two subscripts
            # (a third subscript raised IndexError).
            return out_val, out_arr, out_arr[:, mid_idx]

        return t_out_arr, out_arr

    def view_encoder(self):
        '''
        A method that plots the learnt encoder output for each distinct message, with a subplot for each symbol.

        Every message is placed at the centre of an otherwise all-zero one-hot block before encoding, and only the
        central encoded symbol is plotted. Requires self.encoder to be defined.
        '''
        # One-hot inputs: block idx carries message idx at its central position.
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        symbol_ids = np.arange(self.cardinality)
        messages[symbol_ids, mid_idx, symbol_ids] = 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]

        # Grid layout: the smallest power-of-two column count >= sqrt(cardinality),
        # and enough rows to fit every symbol (spare cells are hidden below).
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(math.ceil(self.cardinality / num_x))

        # squeeze=False keeps axs 2-D even when there is a single row.
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y), squeeze=False)
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        for sym_idx, ax in enumerate(axs.flat):
            if sym_idx >= self.cardinality:
                ax.set_visible(False)
                continue
            ax.plot(t, enc_messages[sym_idx], 'x')
            ax.set_title('Symbol {}'.format(str(sym_idx)))

        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))

        for ax in axs.flat:
            ax.label_outer()

        plt.show()

    def view_sample_block(self):
        '''
        Generates a random block of messages and encodes it. The encoded waveform is plotted together with the
        same waveform after a digitization layer's low pass filter, with quantization noise disabled.

        Requires self.encoder, and reads fs from the channel layer (index 1 of self.channel).
        '''
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)

        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)

        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)

        # Apply LPF
        lpf_out = lpf(flat_enc)

        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        # Plot the concatenated symbols before and after LPF, with symbol boundaries
        plt.figure(figsize=(2 * self.messages_per_block, 6))

        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')

        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def plot(self, inputs, signal):
        """
        Plot the first block of a received signal against the transmitted signal.

        The time axis is in ns, derived from the channel layer's sampling frequency (index 1 of self.channel).

        :param inputs: Transmitted samples; squared before plotting
        :param signal: Received samples; plotted as given
        """
        num_samples = self.messages_per_block * self.samples_per_symbol
        fs = self.channel.layers[1].fs
        t = np.divide(np.arange(num_samples), fs * 1e-9)

        plt.figure()
        plt.plot(t, signal.flatten()[:num_samples], label='Received Signal')
        inputs = np.square(inputs)
        plt.plot(t, inputs.flatten()[:num_samples], label='Transmitted Signal')
        plt.xlabel('Time (ns)')
        plt.ylabel('Power')
        plt.title('Effect of Chromatic Dispersion and Noise on Generated Signal')
        plt.legend()
        plt.show()

    def demodulate(self, validation, outputs):
        """
        Demodulate received samples with a threshold detector and return the symbol error ratio.

        Each received block is low-pass filtered (5th-order digital Bessel filter applied forward and backward
        with filtfilt). It is then square-rooted, averaged over the samples of each symbol, and sliced with
        decision thresholds at 0.5, 1.5, ..., cardinality - 1.5. Averages that are NaN are decided as symbol 0.

        :param validation: Transmitted symbol values, e.g. the second return value of generate_random_inputs,
                           shape (num_of_blocks, messages_per_block)
        :param outputs: Received samples with one row per block,
                        shape (num_of_blocks, messages_per_block * samples_per_symbol)
        :return: Fraction of symbols whose decision differs from validation
        """
        # Cutoff at 8x the symbol rate, normalised to Nyquist (fs / 2). Because
        # symbol_rate = fs / samples_per_symbol, the sampling frequency cancels.
        cutoff = 8 * 2 / self.samples_per_symbol
        b, a = bessel(5, cutoff, btype='low', analog=False)
        outputs_filt = filtfilt(b, a, outputs)
        # Negative filtered samples give NaN here; those are handled below.
        amplitude = np.sqrt(outputs_filt)

        # One mean per transmitted symbol, in transmission (row-major) order.
        average = amplitude.reshape(-1, self.samples_per_symbol).mean(axis=1)
        validation = np.asarray(validation).flatten()

        # right=True: x <= 0.5 -> 0, 0.5 < x <= 1.5 -> 1, ..., x > cardinality - 1.5 -> cardinality - 1
        thresholds = np.arange(self.cardinality - 1) + 0.5
        decisions = np.digitize(average, thresholds, right=True)
        decisions = np.where(np.isnan(average), 0, decisions)

        errors = np.count_nonzero(decisions != validation)
        return errors / len(validation)

    def call(self, inputs, training=None, mask=None):
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


def plot_output_graphs():
    """
    Script helper: send one batch of random blocks through the module-level optical_channel. Plot the received
    waveform before and after a 3rd-order Bessel low pass filter, each against the transmitted signal.

    Relies on globals defined under __main__: modulation_scheme, optical_channel, size, SAMPLING_FREQUENCY and
    SAMPLES_PER_SYMBOL.
    """
    inputs, _ = modulation_scheme.generate_random_inputs(num_of_blocks=size)
    outputs = optical_channel(inputs).numpy()

    b, a = bessel(3, 8 * (SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL) / (SAMPLING_FREQUENCY / 2),
                  btype='low', analog=False)
    outputs_filt = filtfilt(b, a, outputs)

    # Note: plot() squares its first argument before plotting it.
    tx_amplitude = np.sqrt(inputs)
    noisy = np.sqrt(outputs)
    filtered = np.sqrt(outputs_filt)

    modulation_scheme.plot(tx_amplitude, noisy)
    modulation_scheme.plot(tx_amplitude, filtered)


def plot_fibre_length_vs_ber():
    """
    Script helper: sweep the fibre length from 5 km to 70 km, simulate a fresh OpticalChannel for each length and
    plot the resulting error ratio from ModulationModel.demodulate on a log scale.

    Relies on globals defined under __main__: modulation_scheme, size, SAMPLING_FREQUENCY, MESSAGES_PER_BLOCK,
    SAMPLES_PER_SYMBOL and DISPERSION_FACTOR.
    """
    lengths = np.linspace(5, 70, 1000)
    ber = []

    for length in lengths:
        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                         dispersion_factor=DISPERSION_FACTOR,
                                         fiber_length=length)
        inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
        outputs = optical_channel(inputs).numpy()
        ber.append(modulation_scheme.demodulate(validation, outputs))

    plt.figure()
    plt.semilogy(lengths, ber, label='ber')
    plt.xlabel('Fibre Length (km)')
    plt.ylabel('BER')
    plt.title('Effect of Fibre Length on BER of 4PAM System')
    # The visibility flag is passed positionally: the old `b=` keyword was
    # renamed to `visible` and no longer exists in recent Matplotlib.
    # Show the major grid lines with dark grey lines
    plt.grid(True, which='major', color='#666666', linestyle='-')
    # Show the minor grid lines with very faint and almost transparent grey lines
    plt.minorticks_on()
    plt.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)
    plt.legend()
    plt.show()


if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 4
    SAMPLES_PER_SYMBOL = 128
    MESSAGES_PER_BLOCK = 9
    # s^2/km, i.e. -21.7 ps^2/km
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50

    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)

    modulation_scheme = ModulationModel(cardinality=CARDINALITY,
                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
                                        messages_per_block=MESSAGES_PER_BLOCK,
                                        channel=optical_channel)

    size = 1000

    # Alternative experiments (uncomment to run):
    # inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
    # outputs = optical_channel(inputs).numpy()
    # ber = modulation_scheme.demodulate(validation, outputs)
    # plot_output_graphs()

    plot_fibre_length_vs_ber()

    print("done")