import itertools
import math

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class BitsToSymbols(layers.Layer):
    def __init__(self, cardinality, messages_per_block=9):
        """
        A keras layer that maps groups of bits to one-hot encoded symbols.

        The last axis of the input is contracted with the place values
        [2^(n-1), ..., 2, 1] (most significant bit first), and the resulting
        integer selects the hot index of a one-hot vector of length `cardinality`.

        :param cardinality: Number of distinct symbols (length of the one-hot vectors)
        :param messages_per_block: Number of symbols per transmission block; the output
            is reshaped to (batch, messages_per_block, cardinality). Defaults to 9.
        """
        super(BitsToSymbols, self).__init__()
        self.cardinality = cardinality
        self.messages_per_block = messages_per_block
        # math.log2 is documented as more accurate than math.log(x, 2).
        n = int(math.log2(self.cardinality))
        # Column vector of place values, shape (n, 1).
        self.pows = tf.convert_to_tensor(
            np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32
        )

    def call(self, inputs, **kwargs):
        # Contract the bit axis with the place values to get the symbol index.
        # The float -> int32 cast truncates, so inputs presumably hold exact 0/1 values.
        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
        out = tf.one_hot(idx, self.cardinality)
        # Reshape with tf.reshape (no new layer per call), keeping the batch dimension and
        # deriving the rest of the shape from the constructor arguments.
        return tf.reshape(
            out, [tf.shape(out)[0], self.messages_per_block, self.cardinality]
        )


class SymbolsToBits(layers.Layer):
    def __init__(self, cardinality):
        """
        A keras layer that maps (soft) one-hot symbol vectors back to bits.

        Row k of the lookup table holds the n-bit binary representation of k (most
        significant bit first), so a one-hot input selects its bit pattern, and a
        probability vector yields the expected value of each bit.

        :param cardinality: Number of distinct symbols (length of the input vectors)
        """
        super(SymbolsToBits, self).__init__()
        n = int(math.log2(cardinality))
        # itertools.product([0, 1], repeat=n) yields bit patterns in lexicographic order,
        # i.e. row k is the binary representation of k, matching BitsToSymbols.
        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.all_syms)


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A keras layer that extracts the central message (symbol) in a block.

        Extraction is a matmul with a selection matrix of shape
        (messages_per_block * samples_per_symbol, samples_per_symbol) that holds an
        identity block at the rows of the central symbol.

        :param messages_per_block: Total number of messages in transmission block.
            When this is even, the lower of the two central messages is taken.
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Floor division keeps the window on a symbol boundary. A float expression such as
        # int(sps * (m - 1) / 2) starts mid-symbol when messages_per_block is even.
        begin = samples_per_symbol * ((messages_per_block - 1) // 2)
        end = begin + samples_per_symbol
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        temp_w[begin:end, :] = np.identity(samples_per_symbol)
        self.samples_per_symbol = samples_per_symbol
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white gaussian noise channel model.

        The noise comes from a keras GaussianNoise layer, so fresh noise is drawn on
        every call instead of one fixed noise realisation being reused.

        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # GaussianNoise is only active in training mode; training=True forces the noise
        # to be applied at inference time too.
        return self.noise_layer(inputs, training=True)


class DigitizationLayer(layers.Layer):
    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, sig_avg=0.5, enob=10):
        """
        This layer simulates the finite bandwidth of the hardware by means of a low pass
        filter. In addition, artefacts caused by quantization are modelled by the addition
        of white gaussian noise whose scale is derived from sig_avg and enob.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input (length of the last
            axis, along which the FFT is taken)
        :param lpf_cutoff: Cutoff frequency in Hz of the LPF modelling finite bandwidth in
            the ADC/DAC
        :param sig_avg: Average signal amplitude, used to scale the quantization noise
        :param enob: Effective number of bits of the ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        # NOTE(review): sig_avg**2 times a dB-to-linear ratio has units of power
        # (variance), yet it is passed to GaussianNoise as a standard deviation. Confirm
        # whether a square root is intended; kept as-is to preserve existing results.
        stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
        self.noise_layer = layers.GaussianNoise(stddev)
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        # Ideal (brick-wall) LPF: keep bins with |f| <= cutoff, zero the rest.
        temp = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Filter in the frequency domain (tf.signal.fft acts on the last axis).
        val_f = tf.signal.fft(tf.cast(inputs, dtype=tf.complex64))
        filtered_t = tf.signal.ifft(tf.math.multiply(self.lpf_multiplier, val_f))
        # Keep only the real part (the same value tf.cast complex -> float32 returns).
        real_t = tf.math.real(filtered_t)
        # training=True: add the quantization noise at inference time as well.
        return self.noise_layer(real_t, training=True)


class OpticalChannel(layers.Layer):
    def __init__(self, fs, num_of_samples, dispersion_factor, fiber_length,
                 lpf_cutoff=32e9, rx_stddev=0.01, sig_avg=0.5, enob=10):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode
        detection, finite bandwidth of ADC/DAC as well as additive white gaussian noise
        in optical communication channels.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param sig_avg: Average signal amplitude
        :param enob: Effective number of bits of the ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.rx_stddev = rx_stddev
        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
        # A single instance models both the DAC (before the fiber) and the ADC (after
        # detection), so both stages share the same cutoff and noise parameters.
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    sig_avg=sig_avg,
                                                    enob=enob)
        # NOTE(review): not used by call(); retained so the attribute remains available.
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs),
                                         dtype=tf.complex64)
        # Frequency response exp(0.5j * D * L * omega^2), with omega = 2*pi*f.
        self.multiplier = tf.math.exp(
            0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq)
        )

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic dispersion, applied in the frequency domain
        val_f = tf.signal.fft(tf.cast(dac_out, dtype=tf.complex64))
        disp_t = tf.signal.ifft(tf.math.multiply(val_f, self.multiplier))

        # Square-law detection: the photodiode output is |E|^2
        pd_out = tf.square(tf.abs(disp_t))

        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)

        # Adding photo-diode receiver noise (training=True keeps it active at inference)
        rx_signal = self.noise_layer(real_val, training=True)

        # ADC LPF and noise
        return self.digitization_layer(rx_signal)