| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413 |
- import math
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- from sklearn.preprocessing import OneHotEncoder
- from tensorflow.keras import layers, losses
- from scipy.signal import bessel, lfilter, filtfilt
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
class ExtractCentralMessage(layers.Layer):
    """A keras layer that extracts the central message (symbol) of a flattened block."""

    def __init__(self, messages_per_block, samples_per_symbol):
        """
        Build the constant selection matrix that picks out the central symbol.

        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super().__init__()
        # Rows of the flattened block that belong to the central symbol.
        first_row = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        last_row = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        # Zero everywhere except an identity sub-matrix on the central rows, so that
        # `inputs @ w` copies the central samples and discards all the others.
        selector = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        selector[first_row:last_row, :] = np.identity(samples_per_symbol)
        self.w = tf.convert_to_tensor(selector, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        """Return ``inputs`` multiplied by the selection matrix."""
        return tf.matmul(inputs, self.w)
class AwgnChannel(layers.Layer):
    """
    An additive white gaussian noise channel model.

    The keras GaussianNoise layer is utilized to prevent identical noise being applied
    every time the call function is called.
    """

    def __init__(self, rx_stddev=0.1):
        """
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super().__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        """Return ``inputs`` with gaussian noise added."""
        # training=True is passed explicitly so the noise layer is always asked to add noise.
        noisy = self.noise_layer.call(inputs, training=True)
        return noisy
class DigitizationLayer(layers.Layer):
    """
    Model of an ADC/DAC stage: an ideal low pass filter followed by additive gaussian noise.

    The finite bandwidth of the hardware is simulated by zeroing every FFT bin whose frequency
    magnitude exceeds the cutoff. Artefacts caused by quantization are modelled by the addition
    of white gaussian noise of a given stddev.
    """

    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, q_stddev=0.1):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super().__init__()
        self.noise_layer = layers.GaussianNoise(q_stddev)
        # Frequency of every FFT bin for an input of `num_of_samples` samples.
        bin_freqs = np.fft.fftfreq(num_of_samples, d=1 / fs)
        # Pass-band mask: 1 inside the pass-band, 0 for bins beyond the cutoff.
        passband = np.where(np.abs(bin_freqs) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(passband, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        """Low pass filter ``inputs`` in the frequency domain, then add gaussian noise."""
        spectrum = tf.signal.fft(tf.cast(inputs, dtype=tf.complex64))
        filtered_spectrum = tf.math.multiply(self.lpf_multiplier, spectrum)
        # Back to the time domain; the cast to float32 drops the imaginary part.
        filtered_signal = tf.cast(tf.signal.ifft(filtered_spectrum), dtype=tf.float32)
        return self.noise_layer.call(filtered_signal, training=True)
class OpticalChannel(layers.Layer):
    """
    A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite
    bandwidth of ADC/DAC as well as additive white gaussian noise in optical communication channels.
    """

    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 q_stddev=0.01):
        """
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        """
        super().__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)
        # A single digitization layer instance serves as both the DAC and the ADC stage in `call`.
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        bin_freqs = np.fft.fftfreq(num_of_samples, d=1 / fs)
        self.freq = tf.convert_to_tensor(bin_freqs, dtype=tf.complex128)
        # Frequency-domain dispersion term: exp(0.5j * dispersion_factor * fiber_length * (2*pi*f)^2)
        angular_freq_sq = tf.math.square(2 * math.pi * self.freq)
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * angular_freq_sq)

    def call(self, inputs, **kwargs):
        """Pass ``inputs`` through DAC, dispersive fiber, photodiode, receiver noise and ADC."""
        # DAC LPF and noise
        transmitted = self.digitization_layer(inputs)

        # Chromatic dispersion, applied as a multiplication in the frequency domain
        spectrum = tf.signal.fft(tf.cast(transmitted, dtype=tf.complex128))
        dispersed = tf.signal.ifft(tf.math.multiply(spectrum, self.multiplier))

        # Squared-law detection at the photodiode
        detected = tf.square(tf.abs(dispersed))

        # Back to float32 before adding the photodiode receiver noise
        received = self.noise_layer.call(tf.cast(detected, dtype=tf.float32), training=True)

        # ADC LPF and noise
        return self.digitization_layer(received)
class ModulationModel(tf.keras.Model):
    """
    Block based modulation model built around a channel layer.

    The autoencoder that aims to find a encoding of the input messages. It should be noted that a
    "block" consists of multiple "messages" to introduce memory into the simulation as this is
    essential for modelling inter-symbol interference. The autoencoder architecture was heavily
    influenced by IEEE 8433895.

    NOTE(review): `call`, `view_encoder` and `view_sample_block` use `self.encoder` and
    `self.decoder`, which are not assigned anywhere in this class. They have to be attached to the
    instance before those methods can be used.
    """

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        :param cardinality: Number of different messages. Chosen such that each message encodes
            log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in transmission block. An even value is
            incremented by one so that a block always has a central message.
        :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an
            implemented forward pass
        :raises TypeError: If `channel` is not a keras layer
        """
        super().__init__()
        if not isinstance(channel, layers.Layer):
            raise TypeError("Channel must be a subclass of keras.layers.layer!")

        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper; forced to be odd
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        # Channel model: flatten the block, apply the channel, keep the central message only
        self.channel = tf.keras.Sequential([
            layers.Flatten(),
            channel,
            ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
        ])

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        Generate random blocks of integer messages. This is utilized for generating the test/train data.

        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be
            transmitted consecutively to model ISI.
        :param return_vals: If true, the raw values of the input sequence are returned as well
        :return: With `return_vals` false, a tuple `(samples, messages)` where `messages` has shape
            (num_of_blocks, messages_per_block) and `samples` repeats every message
            `samples_per_symbol` times along axis 1. With `return_vals` true, a tuple
            `(values, messages, central)` where `values` is `messages` with a trailing axis of
            length 1 and `central` holds the central message of every block.
        """
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
        out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))

        if return_vals:
            mid_idx = (self.messages_per_block - 1) // 2
            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
            # out_arr is two-dimensional, so the central message is selected with two indices.
            return out_val, out_arr, out_arr[:, mid_idx]

        t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
        return t_out_arr, out_arr

    def view_encoder(self):
        """
        View the learnt encoding of each distinct message, displayed as a figure with one subplot
        per message.
        """
        # Generate inputs for encoder: one block per message, with only the central message set
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = (self.messages_per_block - 1) // 2
        for symbol in range(self.cardinality):
            messages[symbol, mid_idx, symbol] = 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]

        # Compute subplot grid layout: the column count is the smallest power of two that is
        # at least sqrt(cardinality)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)

        # squeeze=False keeps `axs` two-dimensional even when the grid has a single row or column
        _, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y), squeeze=False)

        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        # Plot all symbols; `axs.flat` walks the grid row by row
        for sym_idx, ax in enumerate(axs.flat):
            ax.plot(t, enc_messages[sym_idx], 'x')
            ax.set_title('Symbol {}'.format(sym_idx))
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()

    def view_sample_block(self):
        """
        Generate a random block of input messages and encode it. In addition to this, the output is
        passed through a digitization layer without any quantization noise for the low pass filtering.
        Both signals are plotted.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)

        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)

        # Instantiate and apply the LPF layer (q_stddev=0 disables the quantization noise)
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        lpf_out = lpf(flat_enc)

        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs

        # Plot the concatenated symbols before and after LPF, with a line at each symbol boundary
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()

    def plot(self, inputs, signal):
        """
        Plot the first three blocks worth of samples of the received signal together with the
        squared transmitted signal.

        The time axis is derived from the module level constant SAMPLING_FREQUENCY.

        :param inputs: Transmitted samples; they are squared before plotting
        :param signal: Received samples
        """
        num_samples = self.messages_per_block * self.samples_per_symbol * 3
        # Sample index divided by the sampling frequency in GHz gives the time in ns
        t = np.divide(np.arange(num_samples), SAMPLING_FREQUENCY * 1e-9)

        plt.figure()
        plt.plot(t, signal.flatten()[:num_samples], label='Received Signal')
        plt.plot(t, np.square(inputs).flatten()[:num_samples], label='Transmitted Signal')
        plt.xlabel('Time (ns)')
        plt.ylabel('Power')
        plt.title('Effect of Chromatic Dispersion and Noise on Generated Signal')
        plt.legend()
        plt.show()

    def demodulate(self, validation, outputs):
        """
        Demodulate received samples with fixed decision thresholds and compare against the
        transmitted messages.

        The samples are low pass filtered, square-rooted, averaged over each symbol period and then
        mapped to the levels 0..3 using the thresholds 0.5, 1.5 and 2.5.

        :param validation: Transmitted messages, one integer per symbol
        :param outputs: Received samples; the total number of samples must be a multiple of
            samples_per_symbol
        :return: Fraction of symbols whose decision differs from `validation`
        """
        # Normalised cutoff of ten times the symbol rate relative to the Nyquist frequency:
        # 10 * (fs / samples_per_symbol) / (fs / 2). The sampling frequency cancels out.
        normalised_cutoff = 20 / self.samples_per_symbol
        b, a = bessel(5, normalised_cutoff, btype='low', analog=False)
        filtered = filtfilt(b, a, outputs)

        # Undo the squared-law detection; negative samples turn into NaN here
        amplitude = np.sqrt(filtered)

        # Average over each symbol period; works for any number of blocks
        symbol_means = amplitude.reshape(-1, self.samples_per_symbol).mean(axis=1)

        # right=True gives the intervals (-inf, 0.5], (0.5, 1.5], (1.5, 2.5], (2.5, inf)
        decisions = np.digitize(symbol_means, [0.5, 1.5, 2.5], right=True)
        # NaN averages are decided as level 0
        decisions[np.isnan(symbol_means)] = 0

        validation = np.asarray(validation).flatten()
        errors = np.count_nonzero(validation != decisions[:len(validation)])
        return errors / len(validation)

    def call(self, inputs, training=None, mask=None):
        """Forward pass: encode the inputs, apply the channel and decode the result."""
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs
def plot_output_graphs():
    """
    Send one batch of random symbols through the optical channel and plot the received signal,
    first as received and then after Bessel low pass filtering.

    Relies on the module level objects `modulation_scheme`, `optical_channel` and `size` as well as
    the constants SAMPLING_FREQUENCY and SAMPLES_PER_SYMBOL.
    """
    tx_samples, _ = modulation_scheme.generate_random_inputs(num_of_blocks=size)
    rx_samples = optical_channel(tx_samples).numpy()

    # 5th order Bessel low pass at ten times the symbol rate, normalised to the Nyquist frequency
    symbol_rate = SAMPLING_FREQUENCY / SAMPLES_PER_SYMBOL
    nyquist = SAMPLING_FREQUENCY / 2
    b, a = bessel(5, 10 * symbol_rate / nyquist, btype='low', analog=False)
    rx_filtered = filtfilt(b, a, rx_samples)

    modulation_scheme.plot(tx_samples, rx_samples)
    modulation_scheme.plot(tx_samples, rx_filtered)
def plot_fibre_length_vs_ber():
    """
    Sweep the fibre length from 5 km to 50 km in 200 steps and plot the resulting error rate on a
    logarithmic axis.

    For every length a fresh OpticalChannel is built, random symbols are sent through it and the
    error rate is measured with `modulation_scheme.demodulate`. Relies on the module level objects
    `modulation_scheme` and `size` as well as the constants SAMPLING_FREQUENCY, MESSAGES_PER_BLOCK,
    SAMPLES_PER_SYMBOL and DISPERSION_FACTOR.
    """
    lengths = np.linspace(5, 50, 200)
    ber = []
    for length in lengths:
        # Named `channel` so that the module level `optical_channel` is not shadowed
        channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                 num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                 dispersion_factor=DISPERSION_FACTOR,
                                 fiber_length=length)
        inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
        outputs = channel(inputs).numpy()
        ber.append(modulation_scheme.demodulate(validation, outputs))

    plt.figure()
    plt.semilogy(lengths, ber, label='ber')
    plt.xlabel('Fibre Length (km)')
    plt.ylabel('BER')
    plt.title('Effect of Fibre Length on BER of 4PAM System')
    # The on/off flag is passed positionally: its keyword was renamed from `b` to `visible` in
    # newer Matplotlib releases, so the positional form is the one that works on all versions.
    # Show the major grid lines with dark grey lines
    plt.grid(True, which='major', color='#666666', linestyle='-')
    # Show the minor grid lines with very faint and almost transparent grey lines
    plt.minorticks_on()
    plt.grid(True, which='minor', color='#999999', linestyle='-', alpha=0.2)
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Simulation parameters. These are module level names that the plotting helpers read.
    SAMPLING_FREQUENCY = 336e9  # Hz
    CARDINALITY = 4
    SAMPLES_PER_SYMBOL = 64
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24  # s^2/km
    FIBER_LENGTH = 30  # km

    # Channel and modulation model shared by the plotting helpers
    optical_channel = OpticalChannel(
        fs=SAMPLING_FREQUENCY,
        num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
        dispersion_factor=DISPERSION_FACTOR,
        fiber_length=FIBER_LENGTH,
    )
    modulation_scheme = ModulationModel(
        cardinality=CARDINALITY,
        samples_per_symbol=SAMPLES_PER_SYMBOL,
        messages_per_block=MESSAGES_PER_BLOCK,
        channel=optical_channel,
    )

    # Number of blocks generated per simulation run
    size = 1000

    plot_fibre_length_vs_ber()
    print("done")
|