|
@@ -0,0 +1,342 @@
|
|
|
|
|
+import math
|
|
|
|
|
+
|
|
|
|
|
+import tensorflow as tf
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
|
+from sklearn.preprocessing import OneHotEncoder
|
|
|
|
|
+from tensorflow.keras import layers, losses
|
|
|
|
|
+import os
|
|
|
|
|
+
|
|
|
|
|
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ExtractCentralMessage(layers.Layer):
|
|
|
|
|
+ def __init__(self, messages_per_block, samples_per_symbol):
|
|
|
|
|
+ """
|
|
|
|
|
+ A keras layer that extracts the central message(symbol) in a block.
|
|
|
|
|
+
|
|
|
|
|
+ :param messages_per_block: Total number of messages in transmission block
|
|
|
|
|
+ :param samples_per_symbol: Number of samples per transmitted symbol
|
|
|
|
|
+ """
|
|
|
|
|
+ super(ExtractCentralMessage, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
|
|
|
|
|
+ i = np.identity(samples_per_symbol)
|
|
|
|
|
+ begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
|
|
|
|
|
+ end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
|
|
|
|
|
+ temp_w[begin:end, :] = i
|
|
|
|
|
+
|
|
|
|
|
+ self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ return tf.matmul(inputs, self.w)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class AwgnChannel(layers.Layer):
|
|
|
|
|
+ def __init__(self, rx_stddev=0.1):
|
|
|
|
|
+ """
|
|
|
|
|
+ A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
|
|
|
|
|
+ being applied every time the call function is called.
|
|
|
|
|
+
|
|
|
|
|
+ :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
|
|
|
|
|
+ """
|
|
|
|
|
+ super(AwgnChannel, self).__init__()
|
|
|
|
|
+ self.noise_layer = layers.GaussianNoise(rx_stddev)
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ return self.noise_layer.call(inputs, training=True)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class DigitizationLayer(layers.Layer):
|
|
|
|
|
+ def __init__(self,
|
|
|
|
|
+ fs,
|
|
|
|
|
+ num_of_samples,
|
|
|
|
|
+ lpf_cutoff=32e9,
|
|
|
|
|
+ q_stddev=0.1):
|
|
|
|
|
+ """
|
|
|
|
|
+ This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
|
|
|
|
|
+ artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
|
|
|
|
|
+
|
|
|
|
|
+ :param fs: Sampling frequency of the simulation in Hz
|
|
|
|
|
+ :param num_of_samples: Total number of samples in the input
|
|
|
|
|
+ :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
|
|
|
|
|
+ :param q_stddev: Standard deviation of quantization noise at ADC/DAC
|
|
|
|
|
+ """
|
|
|
|
|
+ super(DigitizationLayer, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ self.noise_layer = layers.GaussianNoise(q_stddev)
|
|
|
|
|
+ freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
|
|
|
|
|
+ temp = np.ones(freq.shape)
|
|
|
|
|
+
|
|
|
|
|
+ for idx, val in np.ndenumerate(freq):
|
|
|
|
|
+ if np.abs(val) > lpf_cutoff:
|
|
|
|
|
+ temp[idx] = 0
|
|
|
|
|
+
|
|
|
|
|
+ self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ complex_in = tf.cast(inputs, dtype=tf.complex64)
|
|
|
|
|
+ val_f = tf.signal.fft(complex_in)
|
|
|
|
|
+ filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
|
|
|
|
|
+ filtered_t = tf.signal.ifft(filtered_f)
|
|
|
|
|
+ real_t = tf.cast(filtered_t, dtype=tf.float32)
|
|
|
|
|
+ noisy = self.noise_layer.call(real_t, training=True)
|
|
|
|
|
+ return noisy
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class OpticalChannel(layers.Layer):
|
|
|
|
|
+ def __init__(self,
|
|
|
|
|
+ fs,
|
|
|
|
|
+ num_of_samples,
|
|
|
|
|
+ dispersion_factor,
|
|
|
|
|
+ fiber_length,
|
|
|
|
|
+ lpf_cutoff=32e9,
|
|
|
|
|
+ rx_stddev=0.01,
|
|
|
|
|
+ q_stddev=0.01):
|
|
|
|
|
+ """
|
|
|
|
|
+ A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
|
|
|
|
|
+ ADC/DAC as well as additive white gaussian noise in optical communication channels.
|
|
|
|
|
+
|
|
|
|
|
+ :param fs: Sampling frequency of the simulation in Hz
|
|
|
|
|
+ :param num_of_samples: Total number of samples in the input
|
|
|
|
|
+ :param dispersion_factor: Dispersion factor in s^2/km
|
|
|
|
|
+ :param fiber_length: Length of fiber to model in km
|
|
|
|
|
+ :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
|
|
|
|
|
+ :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
|
|
|
|
|
+ :param q_stddev: Standard deviation of quantization noise at ADC/DAC
|
|
|
|
|
+ """
|
|
|
|
|
+ super(OpticalChannel, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ self.noise_layer = layers.GaussianNoise(rx_stddev)
|
|
|
|
|
+ self.digitization_layer = DigitizationLayer(fs=fs,
|
|
|
|
|
+ num_of_samples=num_of_samples,
|
|
|
|
|
+ lpf_cutoff=lpf_cutoff,
|
|
|
|
|
+ q_stddev=q_stddev)
|
|
|
|
|
+ self.flatten_layer = layers.Flatten()
|
|
|
|
|
+
|
|
|
|
|
+ self.fs = fs
|
|
|
|
|
+ self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
|
|
|
|
|
+ self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, **kwargs):
|
|
|
|
|
+ # DAC LPF and noise
|
|
|
|
|
+ dac_out = self.digitization_layer(inputs)
|
|
|
|
|
+
|
|
|
|
|
+ # Chromatic Dispersion
|
|
|
|
|
+ complex_val = tf.cast(dac_out, dtype=tf.complex128)
|
|
|
|
|
+ val_f = tf.signal.fft(complex_val)
|
|
|
|
|
+ disp_f = tf.math.multiply(val_f, self.multiplier)
|
|
|
|
|
+ disp_t = tf.signal.ifft(disp_f)
|
|
|
|
|
+
|
|
|
|
|
+ # Squared-Law Detection
|
|
|
|
|
+ pd_out = tf.square(tf.abs(disp_t))
|
|
|
|
|
+
|
|
|
|
|
+ # Casting back to floatx
|
|
|
|
|
+ real_val = tf.cast(pd_out, dtype=tf.float32)
|
|
|
|
|
+
|
|
|
|
|
+ # Adding photo-diode receiver noise
|
|
|
|
|
+ rx_signal = self.noise_layer.call(real_val, training=True)
|
|
|
|
|
+
|
|
|
|
|
+ # ADC LPF and noise
|
|
|
|
|
+ adc_out = self.digitization_layer(rx_signal)
|
|
|
|
|
+
|
|
|
|
|
+ return adc_out
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class ModulationModel(tf.keras.Model):
|
|
|
|
|
+ def __init__(self,
|
|
|
|
|
+ cardinality,
|
|
|
|
|
+ samples_per_symbol,
|
|
|
|
|
+ messages_per_block,
|
|
|
|
|
+ channel):
|
|
|
|
|
+ """
|
|
|
|
|
+ The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
|
|
|
|
|
+ of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
|
|
|
|
|
+ interference. The autoencoder architecture was heavily influenced by IEEE 8433895.
|
|
|
|
|
+
|
|
|
|
|
+ :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
|
|
|
|
|
+ :param samples_per_symbol: Number of samples per transmitted symbol
|
|
|
|
|
+ :param messages_per_block: Total number of messages in transmission block
|
|
|
|
|
+ :param channel: Channel Layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
|
|
|
|
|
+ """
|
|
|
|
|
+ super(ModulationModel, self).__init__()
|
|
|
|
|
+
|
|
|
|
|
+ # Labelled M in paper
|
|
|
|
|
+ self.cardinality = cardinality
|
|
|
|
|
+ # Labelled n in paper
|
|
|
|
|
+ self.samples_per_symbol = samples_per_symbol
|
|
|
|
|
+ # Labelled N in paper
|
|
|
|
|
+ if messages_per_block % 2 == 0:
|
|
|
|
|
+ messages_per_block += 1
|
|
|
|
|
+ self.messages_per_block = messages_per_block
|
|
|
|
|
+ # Channel Model Layer
|
|
|
|
|
+ if isinstance(channel, layers.Layer):
|
|
|
|
|
+ self.channel = tf.keras.Sequential([
|
|
|
|
|
+ layers.Flatten(),
|
|
|
|
|
+ channel,
|
|
|
|
|
+ ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
|
|
|
|
|
+ ])
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise TypeError("Channel must be a subclass of keras.layers.layer!")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ def generate_random_inputs(self, num_of_blocks, return_vals=False):
|
|
|
|
|
+ """
|
|
|
|
|
+ A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
|
|
|
|
|
+
|
|
|
|
|
+ :param num_of_blocks: Number of blocks to generate. A block contains multiple messages to be transmitted in
|
|
|
|
|
+ consecutively to model ISI. The central message in a block is returned as the label for training.
|
|
|
|
|
+ :param return_vals: If true, the raw decimal values of the input sequence will be returned
|
|
|
|
|
+ """
|
|
|
|
|
+ rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
|
|
|
|
|
+
|
|
|
|
|
+ #cat = [np.arange(self.cardinality)]
|
|
|
|
|
+ #enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
|
|
|
|
|
+
|
|
|
|
|
+ #out = enc.fit_transform(rand_int)
|
|
|
|
|
+ #for symbol in rand_int:
|
|
|
|
|
+
|
|
|
|
|
+ out_arr = np.reshape(rand_int, (num_of_blocks, self.messages_per_block))
|
|
|
|
|
+ t_out_arr = np.repeat(out_arr, self.samples_per_symbol, axis=1)
|
|
|
|
|
+
|
|
|
|
|
+ mid_idx = int((self.messages_per_block - 1) / 2)
|
|
|
|
|
+
|
|
|
|
|
+ if return_vals:
|
|
|
|
|
+ out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
|
|
|
|
|
+ return out_val, out_arr, out_arr[:, mid_idx, :]
|
|
|
|
|
+
|
|
|
|
|
+ return t_out_arr, out_arr[:, mid_idx]
|
|
|
|
|
+
|
|
|
|
|
+ def view_encoder(self):
|
|
|
|
|
+ '''
|
|
|
|
|
+ A method that views the learnt encoder for each distint message. This is displayed as a plot with asubplot for
|
|
|
|
|
+ each image.
|
|
|
|
|
+ '''
|
|
|
|
|
+ # Generate inputs for encoder
|
|
|
|
|
+ messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
|
|
|
|
|
+
|
|
|
|
|
+ mid_idx = int((self.messages_per_block - 1) / 2)
|
|
|
|
|
+
|
|
|
|
|
+ idx = 0
|
|
|
|
|
+ for msg in messages:
|
|
|
|
|
+ msg[mid_idx, idx] = 1
|
|
|
|
|
+ idx += 1
|
|
|
|
|
+
|
|
|
|
|
+ # Pass input through encoder and select middle messages
|
|
|
|
|
+ encoded = self.encoder(messages)
|
|
|
|
|
+ enc_messages = encoded[:, mid_idx, :]
|
|
|
|
|
+
|
|
|
|
|
+ # Compute subplot grid layout
|
|
|
|
|
+ i = 0
|
|
|
|
|
+ while 2 ** i < self.cardinality ** 0.5:
|
|
|
|
|
+ i += 1
|
|
|
|
|
+
|
|
|
|
|
+ num_x = int(2 ** i)
|
|
|
|
|
+ num_y = int(self.cardinality / num_x)
|
|
|
|
|
+
|
|
|
|
|
+ # Plot all symbols
|
|
|
|
|
+ fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
|
|
|
|
|
+
|
|
|
|
|
+ t = np.arange(self.samples_per_symbol)
|
|
|
|
|
+ if isinstance(self.channel.layers[1], OpticalChannel):
|
|
|
|
|
+ t = t / self.channel.layers[1].fs
|
|
|
|
|
+
|
|
|
|
|
+ sym_idx = 0
|
|
|
|
|
+ for y in range(num_y):
|
|
|
|
|
+ for x in range(num_x):
|
|
|
|
|
+ axs[y, x].plot(t, enc_messages[sym_idx], 'x')
|
|
|
|
|
+ axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
|
|
|
|
|
+ sym_idx += 1
|
|
|
|
|
+
|
|
|
|
|
+ for ax in axs.flat:
|
|
|
|
|
+ ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
|
|
|
|
|
+
|
|
|
|
|
+ for ax in axs.flat:
|
|
|
|
|
+ ax.label_outer()
|
|
|
|
|
+
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+ def view_sample_block(self):
|
|
|
|
|
+ '''
|
|
|
|
|
+ Generates a random string of input message and encodes them. In addition to this, the output is passed through
|
|
|
|
|
+ digitization layer without any quantization noise for the low pass filtering.
|
|
|
|
|
+ '''
|
|
|
|
|
+ # Generate a random block of messages
|
|
|
|
|
+ val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
|
|
|
|
|
+
|
|
|
|
|
+ # Encode and flatten the messages
|
|
|
|
|
+ enc = self.encoder(inp)
|
|
|
|
|
+ flat_enc = layers.Flatten()(enc)
|
|
|
|
|
+
|
|
|
|
|
+ # Instantiate LPF layer
|
|
|
|
|
+ lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
|
|
|
|
|
+ num_of_samples=self.messages_per_block * self.samples_per_symbol,
|
|
|
|
|
+ q_stddev=0)
|
|
|
|
|
+
|
|
|
|
|
+ # Apply LPF
|
|
|
|
|
+ lpf_out = lpf(flat_enc)
|
|
|
|
|
+
|
|
|
|
|
+ # Time axis
|
|
|
|
|
+ t = np.arange(self.messages_per_block * self.samples_per_symbol)
|
|
|
|
|
+ if isinstance(self.channel.layers[1], OpticalChannel):
|
|
|
|
|
+ t = t / self.channel.layers[1].fs
|
|
|
|
|
+
|
|
|
|
|
+ # Plot the concatenated symbols before and after LPF
|
|
|
|
|
+ plt.figure(figsize=(2 * self.messages_per_block, 6))
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(1, self.messages_per_block):
|
|
|
|
|
+ plt.axvline(x=t[i * self.samples_per_symbol], color='black')
|
|
|
|
|
+
|
|
|
|
|
+ plt.plot(t, flat_enc.numpy().T, 'x')
|
|
|
|
|
+ plt.plot(t, lpf_out.numpy().T)
|
|
|
|
|
+ plt.ylim((0, 1))
|
|
|
|
|
+ plt.xlim((t.min(), t.max()))
|
|
|
|
|
+ plt.title(str(val[0, :, 0]))
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+ def plot(self, inputs, signal, size):
|
|
|
|
|
+ t = np.arange(self.messages_per_block * self.samples_per_symbol * 3)
|
|
|
|
|
+ plt.figure()
|
|
|
|
|
+ plt.plot(t, signal.flatten()[:self.messages_per_block * self.samples_per_symbol * 3])
|
|
|
|
|
+ plt.plot(t, inputs.flatten()[:self.messages_per_block * self.samples_per_symbol * 3])
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+
|
|
|
|
|
+ def call(self, inputs, training=None, mask=None):
|
|
|
|
|
+ tx = self.encoder(inputs)
|
|
|
|
|
+ rx = self.channel(tx)
|
|
|
|
|
+ outputs = self.decoder(rx)
|
|
|
|
|
+ return outputs
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == '__main__':
|
|
|
|
|
+ SAMPLING_FREQUENCY = 336e9
|
|
|
|
|
+ CARDINALITY = 4
|
|
|
|
|
+ SAMPLES_PER_SYMBOL = 24
|
|
|
|
|
+ MESSAGES_PER_BLOCK = 9
|
|
|
|
|
+ DISPERSION_FACTOR = -21.7 * 1e-24
|
|
|
|
|
+ FIBER_LENGTH = 50
|
|
|
|
|
+
|
|
|
|
|
+ optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
|
|
|
|
|
+ num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
|
|
|
|
|
+ dispersion_factor=DISPERSION_FACTOR,
|
|
|
|
|
+ fiber_length=FIBER_LENGTH)
|
|
|
|
|
+
|
|
|
|
|
+ #channel_output = optical_channel(input)
|
|
|
|
|
+
|
|
|
|
|
+ modulation_scheme = ModulationModel(cardinality=CARDINALITY,
|
|
|
|
|
+ samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
|
|
+ messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
|
|
+ channel=optical_channel)
|
|
|
|
|
+
|
|
|
|
|
+ size = 1000
|
|
|
|
|
+ inputs, validation = modulation_scheme.generate_random_inputs(num_of_blocks=size)
|
|
|
|
|
+ outputs = optical_channel(inputs).numpy()
|
|
|
|
|
+ modulation_scheme.plot(inputs, outputs, size)
|
|
|
|
|
+ print("done")
|
|
|
|
|
+"""
|
|
|
|
|
+ ae_model.view_encoder()
|
|
|
|
|
+ ae_model.view_sample_block()
|
|
|
|
|
+"""
|
|
|
|
|
+pass
|