import numpy as np
import matplotlib.pyplot as plt
import scipy
from scipy import interpolate
import tensorflow as tf
from tensorflow.keras import layers, losses
import math


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol, **kwargs):
        """
        A keras layer that extracts the central message (symbol) in a block.

        Extraction is a fixed matrix multiplication: the weight matrix is all
        zeros except for an identity sub-block at the rows belonging to the
        central message, so only those samples reach the output.

        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``)
        """
        super(ExtractCentralMessage, self).__init__(**kwargs)

        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i

        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        # The last dimension of `inputs` must equal
        # messages_per_block * samples_per_symbol (rows of self.w); the output's
        # last dimension is samples_per_symbol.
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1, **kwargs):
        """
        An additive white gaussian noise channel model.

        A GaussianNoise layer is used so that fresh noise is drawn every time
        `call` runs, instead of the same noise being applied repeatedly.

        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``)
        """
        super(AwgnChannel, self).__init__(**kwargs)
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True forces noise to be added at inference time as well;
        # GaussianNoise is otherwise only active during training.
        return self.noise_layer.call(inputs, training=True)


class DigitizationLayer(layers.Layer):
    def __init__(self, fs, num_of_samples, lpf_cutoff=32e9, q_stddev=0.1, **kwargs):
        """
        This layer simulates the finite bandwidth of the hardware by means of an
        ideal (brick-wall) low pass filter applied in the frequency domain. In
        addition, artefacts caused by quantization are modelled by the addition
        of white gaussian noise of a given stddev.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input; must match
                               the length of the input's last dimension
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        :param kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``)
        """
        super(DigitizationLayer, self).__init__(**kwargs)

        self.noise_layer = layers.GaussianNoise(q_stddev)
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        # Brick-wall mask: keep every FFT bin with |f| <= cutoff, zero the rest.
        temp = np.where(np.abs(freq) > lpf_cutoff, 0.0, 1.0)
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        # tf.signal.fft transforms the innermost dimension.
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        # Casting complex -> float32 keeps only the real part.
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        noisy = self.noise_layer.call(real_t, training=True)
        return noisy


class OpticalChannel(layers.Layer):
    def __init__(self, fs, num_of_samples, dispersion_factor, fiber_length, lpf_cutoff=32e9, rx_stddev=0.03,
                 q_stddev=0.01, **kwargs):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode
        detection, finite bandwidth of ADC/DAC as well as additive white gaussian
        noise in optical communication channels.

        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param q_stddev: Standard deviation of quantization noise at ADC/DAC
        :param kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``)
        """
        super(OpticalChannel, self).__init__(**kwargs)

        self.noise_layer = layers.GaussianNoise(rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=fs,
                                                    num_of_samples=num_of_samples,
                                                    lpf_cutoff=lpf_cutoff,
                                                    q_stddev=q_stddev)
        self.flatten_layer = layers.Flatten()

        self.fs = fs
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex128)
        # Dispersion transfer function exp(0.5j * D * L * w^2) with w = 2*pi*f.
        self.multiplier = tf.math.exp(0.5j * dispersion_factor * fiber_length * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and noise
        dac_out = self.digitization_layer(inputs)

        # Chromatic Dispersion (applied as a multiplication in the frequency domain)
        complex_val = tf.cast(dac_out, dtype=tf.complex128)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)

        # Squared-Law Detection
        pd_out = tf.square(tf.abs(disp_t))

        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)

        # Adding photo-diode receiver noise
        rx_signal = self.noise_layer.call(real_val, training=True)

        # ADC LPF and noise (the same DigitizationLayer instance as the DAC stage)
        adc_out = self.digitization_layer(rx_signal)

        return adc_out


SAMPLING_FREQUENCY = 336e9
CARDINALITY = 4
SAMPLES_PER_SYMBOL = 128
MESSAGES_PER_BLOCK = 9
DISPERSION_FACTOR = -21.7 * 1e-24
FIBER_LENGTH = 50

K = 64  # number of OFDM subcarriers
CP = K // 4  # length of the cyclic prefix: 25% of the block

# The channel processes one OFDM symbol including its cyclic prefix, i.e.
# K + CP samples (80 for K=64); derived here so it tracks changes to K.
optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                 num_of_samples=K + CP,
                                 dispersion_factor=DISPERSION_FACTOR,
                                 fiber_length=FIBER_LENGTH)

P = 8  # number of pilot carriers per OFDM block
pilotValue = 3 + 3j  # The known value each pilot transmits

allCarriers = np.arange(K)  # indices of all subcarriers ([0, 1, ... K-1])

pilotCarriers = allCarriers[::K // P]  # Pilots are every (K/P)th carrier.
# For convenience of channel estimation, let's make the last carrier also be a pilot
pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])
P = P + 1

# data carriers are all remaining carriers
dataCarriers = np.delete(allCarriers, pilotCarriers)

print("allCarriers: %s" % allCarriers)
print("pilotCarriers: %s" % pilotCarriers)
print("dataCarriers: %s" % dataCarriers)

# Each plot below gets its own figure so unrelated plots do not pile up.
plt.figure()
plt.plot(pilotCarriers, np.zeros_like(pilotCarriers), 'bo', label='pilot')
plt.plot(dataCarriers, np.zeros_like(dataCarriers), 'ro', label='data')
plt.legend(fontsize=10)

mu = 4  # bits per symbol (i.e. 16QAM)
payloadBits_per_OFDM = len(dataCarriers) * mu  # number of payload bits per OFDM symbol

mapping_table = {
    (0, 0, 0, 0): -3 - 3j,
    (0, 0, 0, 1): -3 - 1j,
    (0, 0, 1, 0): -3 + 3j,
    (0, 0, 1, 1): -3 + 1j,
    (0, 1, 0, 0): -1 - 3j,
    (0, 1, 0, 1): -1 - 1j,
    (0, 1, 1, 0): -1 + 3j,
    (0, 1, 1, 1): -1 + 1j,
    (1, 0, 0, 0): 3 - 3j,
    (1, 0, 0, 1): 3 - 1j,
    (1, 0, 1, 0): 3 + 3j,
    (1, 0, 1, 1): 3 + 1j,
    (1, 1, 0, 0): 1 - 3j,
    (1, 1, 0, 1): 1 - 1j,
    (1, 1, 1, 0): 1 + 3j,
    (1, 1, 1, 1): 1 + 1j
}

# The same constellation keyed by the integer value of the bit tuple
# (b3 b2 b1 b0 read as an unsigned binary number, e.g. (1, 0, 0, 1) -> 9).
# Derived from mapping_table so the two tables cannot drift apart.
mapping_table_dec = {
    int("".join(str(bit) for bit in key), 2): point
    for key, point in mapping_table.items()
}

plt.figure()
for b3 in [0, 1]:
    for b2 in [0, 1]:
        for b1 in [0, 1]:
            for b0 in [0, 1]:
                B = (b3, b2, b1, b0)
                Q = mapping_table[B]
                plt.plot(Q.real, Q.imag, 'bo')
                plt.text(Q.real, Q.imag + 0.2, "".join(str(x) for x in B), ha='center')

demapping_table = {v: k for k, v in mapping_table.items()}

# TODO: replace with our channel
channelResponse = np.array([1, 0, 0.3 + 0.3j])  # the impulse response of the wireless channel
H_exact = np.fft.fft(channelResponse, K)
plt.figure()
plt.plot(allCarriers, abs(H_exact))

SNRdb = 25  # signal to noise-ratio in dB at the receiver

# TODO: water filling / gradient descent methods for optimising the symbol
# mapping, instead of 16-QAM
bits = np.random.binomial(n=1, p=0.5, size=(payloadBits_per_OFDM,))
print("Bits count: ", len(bits))
print("First 20 bits: ", bits[:20])
print("Mean of bits (should be around 0.5): ", np.mean(bits))


def SP(bits):
    """Serial-to-parallel: reshape the bit stream into one row of `mu` bits per data carrier."""
    return bits.reshape((len(dataCarriers), mu))


bits_SP = SP(bits)
print("First 5 bit groups")
print(bits_SP[:5, :])


def generate_random_inputs(num_of_blocks, return_vals=False):
    """
    Generate random constellation indices for test/train data.

    :param num_of_blocks: Number of random symbol indices to draw.
    :param return_vals: Currently unused; kept for backward compatibility.
    :return: Integer array of shape (num_of_blocks, 1) with values drawn
             uniformly from [0, len(mapping_table_dec)), i.e. [0, 16) for 16-QAM.
    """
    return np.random.randint(len(mapping_table_dec), size=(num_of_blocks, 1))


bits_SP1 = generate_random_inputs(num_of_blocks=1000).astype('uint8')


def Mapping(bits):
    """Map each row of `mu` bits to its 16-QAM constellation point via mapping_table."""
    return np.array([mapping_table[tuple(b)] for b in bits])


def Mapping_dec(bits):
    """
    Map integer symbol indices to 16-QAM constellation points.

    :param bits: Array of integer indices in [0, 16) of any shape, e.g. the
                 (N, 1) output of generate_random_inputs; it is flattened.
    :return: 1-D complex array with one constellation point per index.
    """
    # mapping_table_dec is keyed by plain ints; looking up tuple(row) such as
    # (3,) raised KeyError, so look up int(value) instead.
    return np.array([mapping_table_dec[int(v)] for v in np.ravel(bits)])


QAM = Mapping(bits_SP)
QAM1 = Mapping_dec(bits_SP1)
print("First 5 QAM symbols and bits:")
print(bits_SP[:5, :])
print(QAM[:5])


def OFDM_symbol(QAM_payload):
    """Build one frequency-domain OFDM symbol of K carriers: pilots plus QAM payload."""
    symbol = np.zeros(K, dtype=complex)  # the overall K subcarriers
    symbol[pilotCarriers] = pilotValue  # allocate the pilot subcarriers
    symbol[dataCarriers] = QAM_payload  # allocate the data subcarriers
    return symbol


OFDM_data = OFDM_symbol(QAM)
print("Number of OFDM carriers in frequency domain: ", len(OFDM_data))


def IDFT(OFDM_data):
    """Convert the frequency-domain OFDM symbol to the time domain."""
    return np.fft.ifft(OFDM_data)


OFDM_time = IDFT(OFDM_data)
print("Number of OFDM samples in time-domain before CP: ", len(OFDM_time))


def addCP(OFDM_time):
    """Prepend the cyclic prefix (a copy of the last CP samples)."""
    cp = OFDM_time[-CP:]  # take the last CP samples ...
    return np.hstack([cp, OFDM_time])  # ... and add them to the beginning


OFDM_withCP = addCP(OFDM_time)
print("Number of OFDM samples in time domain with CP: ", len(OFDM_withCP))


def channel(signal):
    """
    Convolve with channelResponse and add complex white Gaussian noise at SNRdb.

    The result is len(signal) + len(channelResponse) - 1 samples long
    (full np.convolve output).
    """
    convolved = np.convolve(signal, channelResponse)
    signal_power = np.mean(np.abs(convolved) ** 2)
    sigma2 = signal_power * 10 ** (-SNRdb / 10)  # calculate noise power based on signal power and SNR

    print("RX Signal power: %.4f. Noise power: %.4f" % (signal_power, sigma2))

    # Generate complex noise with given variance (split evenly between I and Q)
    noise = np.sqrt(sigma2 / 2) * (np.random.randn(*convolved.shape) + 1j * np.random.randn(*convolved.shape))
    return convolved + noise


OFDM_TX = OFDM_withCP
OFDM_RX = channel(OFDM_TX)
# TODO: try the optical channel model instead:
# OFDM_RX1 = optical_channel(OFDM_TX).numpy()

plt.figure(figsize=(8, 2))
plt.plot(abs(OFDM_TX), label='TX signal')
plt.plot(abs(OFDM_RX), label='RX signal')
plt.legend(fontsize=10)
plt.xlabel('Time')
plt.ylabel('$|x(t)|$')
plt.grid(True)


def removeCP(signal):
    """Drop the cyclic prefix and keep the K samples of the OFDM symbol."""
    return signal[CP:(CP + K)]


OFDM_RX_noCP = removeCP(OFDM_RX)


def DFT(OFDM_RX):
    """Convert the received time-domain OFDM symbol back to the frequency domain."""
    return np.fft.fft(OFDM_RX)


OFDM_demod = DFT(OFDM_RX_noCP)


def channelEstimate(OFDM_demod):
    """
    Estimate the channel response on all K carriers from the received pilots.

    Magnitude and phase are linearly interpolated between pilot carriers.
    The pilot phases are unwrapped first so a jump across +/-pi does not
    produce a spurious phase sweep between two pilots. The exact and
    estimated |H| are plotted on the current figure.

    :return: Complex array of length K with the estimated channel response.
    """
    pilots = OFDM_demod[pilotCarriers]  # extract the pilot values from the RX signal
    Hest_at_pilots = pilots / pilotValue  # divide by the transmitted pilot values

    # Perform interpolation between the pilot carriers to get an estimate
    # of the channel in the data carriers. Here, we interpolate absolute value and phase
    # separately
    Hest_abs = interpolate.interp1d(pilotCarriers, abs(Hest_at_pilots), kind='linear')(allCarriers)
    pilot_phase = np.unwrap(np.angle(Hest_at_pilots))
    Hest_phase = interpolate.interp1d(pilotCarriers, pilot_phase, kind='linear')(allCarriers)
    Hest = Hest_abs * np.exp(1j * Hest_phase)

    plt.plot(allCarriers, abs(H_exact), label='Correct Channel')
    plt.stem(pilotCarriers, abs(Hest_at_pilots), label='Pilot estimates')
    plt.plot(allCarriers, abs(Hest), label='Estimated channel via interpolation')
    plt.grid(True)
    plt.xlabel('Carrier index')
    plt.ylabel('$|H(f)|$')
    plt.legend(fontsize=10)
    plt.ylim(0, 2)

    return Hest


plt.figure()
Hest = channelEstimate(OFDM_demod)


def equalize(OFDM_demod, Hest):
    """Zero-forcing equalization: divide each carrier by its channel estimate."""
    return OFDM_demod / Hest


equalized_Hest = equalize(OFDM_demod, Hest)


def get_payload(equalized):
    """Select the data carriers from the equalized OFDM symbol."""
    return equalized[dataCarriers]


QAM_est = get_payload(equalized_Hest)
plt.figure()
plt.plot(QAM_est.real, QAM_est.imag, 'bo')


def Demapping(QAM):
    """
    Hard-decide each received point to the nearest constellation point.

    :return: (bit groups as a 2-D array with one row per point, hard-decided points)
    """
    # array of possible constellation points
    constellation = np.array(list(demapping_table))

    # calculate distance of each RX point to each possible point
    dists = abs(QAM.reshape((-1, 1)) - constellation.reshape((1, -1)))

    # for each element in QAM, choose the index in constellation
    # that belongs to the nearest constellation point
    const_index = dists.argmin(axis=1)

    # get back the real constellation point
    hardDecision = constellation[const_index]

    # transform the constellation point into the bit groups
    return np.vstack([demapping_table[C] for C in hardDecision]), hardDecision


PS_est, hardDecision = Demapping(QAM_est)
plt.figure()
for qam, hard in zip(QAM_est, hardDecision):
    plt.plot([qam.real, hard.real], [qam.imag, hard.imag], 'b-o')
    plt.plot(hardDecision.real, hardDecision.imag, 'ro')


def PS(bits):
    """Parallel-to-serial: flatten the bit groups back into a 1-D bit stream."""
    return bits.reshape((-1,))


bits_est = PS(PS_est)

print("Obtained Bit error rate: ", np.sum(abs(bits - bits_est)) / len(bits))