""" Custom Keras Layers for general use """ import itertools from tensorflow.keras import layers import tensorflow as tf import numpy as np class AwgnChannel(layers.Layer): def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs): """ A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise being applied every time the call function is called. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit) """ super(AwgnChannel, self).__init__(**kwargs) if noise_dB is not None: # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0))) rx_stddev = 10 ** (noise_dB / 10.0) self.noise_layer = layers.GaussianNoise(rx_stddev) def call(self, inputs, **kwargs): return self.noise_layer.call(inputs, training=True) class ScaleAndOffset(layers.Layer): """ Scales and offsets a tensor """ def __init__(self, scale=1, offset=0, **kwargs): super(ScaleAndOffset, self).__init__(**kwargs) self.offset = offset self.scale = scale def call(self, inputs, **kwargs): return inputs * self.scale + self.offset class BitsToSymbol(layers.Layer): def __init__(self, cardinality, **kwargs): super().__init__(**kwargs) self.cardinality = cardinality n = int(np.log(self.cardinality, 2)) self.powers = tf.convert_to_tensor( np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32 ) def call(self, inputs, **kwargs): idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32) return tf.one_hot(idx, self.cardinality) class SymbolToBits(layers.Layer): def __init__(self, cardinality, **kwargs): super().__init__(**kwargs) n = int(np.log(cardinality, 2)) l = [list(i) for i in itertools.product([0, 1], repeat=n)] self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32)) def call(self, inputs, **kwargs): return tf.matmul(self.all_syms, inputs) class ExtractCentralMessage(layers.Layer): def __init__(self, messages_per_block, samples_per_symbol): """ A keras layer that extracts the central message(symbol) in a block. :param messages_per_block: Total number of messages in transmission block :param samples_per_symbol: Number of samples per transmitted symbol """ super(ExtractCentralMessage, self).__init__() temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol)) i = np.identity(samples_per_symbol) begin = int(samples_per_symbol * ((messages_per_block - 1) / 2)) end = int(samples_per_symbol * ((messages_per_block + 1) / 2)) temp_w[begin:end, :] = i self.samples_per_symbol = samples_per_symbol self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32) def call(self, inputs, **kwargs): return tf.matmul(inputs, self.w) class DigitizationLayer(layers.Layer): def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, sig_avg=0.5, enob=10): """ This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this, artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev. 
class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.

        :param messages_per_block: Total number of messages in the transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix: an identity block placed over the samples of the central message, zeros elsewhere
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        identity = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = identity
        self.samples_per_symbol = samples_per_symbol
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)


class DigitizationLayer(layers.Layer):
    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, sig_avg=0.5, enob=10):
        """
        This layer simulates the finite bandwidth of the hardware by means of a low-pass filter. In addition,
        artefacts caused by quantization are modelled by the addition of white Gaussian noise of a given stddev.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of the LPF modelling the finite bandwidth of the ADC/DAC
        :param sig_avg: Average signal amplitude
        :param enob: Effective number of bits of the ADC/DAC, used to set the quantization noise level
        """
        super(DigitizationLayer, self).__init__()
        # Quantization noise level estimated from the average signal amplitude and the effective number of bits
        stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
        self.noise_layer = layers.GaussianNoise(stddev)
        # Brick-wall low-pass filter: keep frequency bins within the cutoff, zero the rest
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Apply the low-pass filter in the frequency domain
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        # Add quantization noise
        noisy = self.noise_layer(real_t, training=True)
        return noisy


class OpticalChannel(layers.Layer):
    def __init__(self, fs, num_of_samples, dispersion_factor, fiber_length, lpf_cutoff=32e9, rx_stddev=0.01,
                 sig_avg=0.5, enob=10):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode detection and the finite
        bandwidth of the ADC/DAC, as well as additive white Gaussian noise in optical communication channels.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of the LPF modelling the finite bandwidth of the ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param sig_avg: Average signal amplitude
        :param enob: Effective number of bits of the ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.rx_stddev = rx_stddev
        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
        self.digitization_layer = DigitizationLayer(
            fs=fs, num_of_samples=num_of_samples, lpf_cutoff=lpf_cutoff, sig_avg=sig_avg, enob=enob
        )
        self.flatten_layer = layers.Flatten()
        self.fs = fs
        self.freq = tf.convert_to_tensor(
            np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
        # Frequency-domain dispersion operator exp(0.5j * dispersion_factor * fiber_length * (2*pi*f)^2)
        self.multiplier = tf.math.exp(
            0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and quantization noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion applied in the frequency domain
        complex_val = tf.cast(dac_out, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law photodiode detection
        pd_out = tf.square(tf.abs(disp_t))
        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Adding photodiode/receiver noise
        rx_signal = self.noise_layer(real_val, training=True)
        # ADC LPF and quantization noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
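

# Minimal usage sketch, not part of the original module: push a random block of samples through
# the OpticalChannel model and print the input/output shapes. The sampling rate, block length,
# dispersion factor and fibre length below are placeholder values chosen only to make the example
# self-contained; they are not taken from the original project.
if __name__ == "__main__":
    fs = 336e9                         # assumed sampling rate in Hz
    num_of_samples = 336               # assumed number of samples per transmission block
    channel = OpticalChannel(
        fs=fs,
        num_of_samples=num_of_samples,
        dispersion_factor=-21.7e-24,   # assumed dispersion factor in s^2/km
        fiber_length=50,               # assumed fiber length in km
    )
    tx = tf.random.uniform((8, num_of_samples))  # batch of 8 random waveforms in [0, 1)
    rx = channel(tx)
    print("tx shape:", tx.shape, "-> rx shape:", rx.shape)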