import math

import matplotlib.pyplot as plt
import numpy as np
from commpy.filters import rrcosfilter, rcosfilter, rectfilter
from numpy.fft import fft, fftfreq, ifft

import defs


class OpticalChannel(defs.Channel):
    """Time-domain model of an optical fibre link.

    forward() performs these steps in order:
    1. Pulse-shape the symbol amplitudes.
    2. Add Gaussian noise.
    3. Apply chromatic dispersion as a phase-only filter in the frequency domain.
    4. Apply square-law (photodiode) detection.
    5. Sample the detected signal once per symbol.
    """

    # Roll-off factor passed to the raised-cosine / root-raised-cosine filters.
    PULSE_ROLLOFF = 0.8
    # Length of the raised-cosine / root-raised-cosine filters, in symbol periods.
    PULSE_SPAN_SYMBOLS = 5

    def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length,
                 pulse_shape='rect', sqrt_out=False, show_graphs=False, **kwargs):
        """
        Optical Channel class constructor.

        :param noise_level: Noise level in dB
        :param dispersion: dispersion coefficient in ps^2/km
        :param symbol_rate: Symbol rate of modulated signal in Hz
        :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
        :param length: fibre length in km
        :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']; any other value
            falls back to a rectangular pulse
        :param sqrt_out: Take the square root of the output to compensate for photodiode
            detection
        :param show_graphs: if graphs should be displayed or not
        :param kwargs: forwarded unchanged to defs.Channel
        """
        super().__init__(**kwargs)
        # dB -> linear; forward() uses this as the scale of unit-variance Gaussian noise.
        # NOTE(review): 10**(dB/10) is a power ratio but it scales an amplitude here;
        # confirm whether 10**(dB/20) was intended.
        self.noise = 10 ** (noise_level / 10)
        self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
        self.symbol_rate = symbol_rate
        self.symbol_period = 1 / self.symbol_rate
        self.sample_rate = sample_rate
        self.sample_period = 1 / self.sample_rate
        self.length = length
        self.pulse_shape = pulse_shape.strip().lower()
        self.sqrt_out = sqrt_out
        self.show_graphs = show_graphs

    def __samples_per_symbol(self):
        """Return the (truncated) integer number of simulation samples per symbol.

        :raises ValueError: if sample_rate is lower than symbol_rate (zero samples per symbol)
        """
        samples_per_symbol = int(self.sample_rate / self.symbol_rate)
        if samples_per_symbol < 1:
            raise ValueError(
                "sample_rate ({}) must be at least symbol_rate ({})".format(
                    self.sample_rate, self.symbol_rate))
        return samples_per_symbol

    def __time_axis(self, num_samples):
        """Return the sample instants k * sample_period (in s) for k = 0 .. num_samples - 1."""
        return np.arange(num_samples) * self.sample_period

    @staticmethod
    def __show_plot(x, y, title):
        """Plot y against x, set the title and display it with plt.show()."""
        plt.plot(x, y)
        plt.title(title)
        plt.show()

    @staticmethod
    def __show_spectrum(f, val_f, title):
        """Plot |val_f| against f, reordered by fftshift so frequencies ascend."""
        plt.plot(np.fft.fftshift(f), np.abs(np.fft.fftshift(val_f)))
        plt.title(title)
        plt.show()

    def __get_time_domain(self, symbol_vals):
        """Pulse-shape the symbol amplitudes into a sampled time-domain signal.

        The amplitudes in symbol_vals[:, 0] are placed as impulses every
        samples_per_symbol samples and convolved with the selected pulse filter.
        Side effect: sets self.filter_samples, self.t_filter and self.h_filter.
        forward() relies on filter_samples to locate the decision instants.

        :param symbol_vals: 2-D array whose column 0 holds the symbol amplitudes
        :return: (t, val_t) sample instants in s and the shaped real signal
        """
        samples_per_symbol = self.__samples_per_symbol()
        symbol_impulse = np.zeros(symbol_vals.shape[0] * samples_per_symbol)
        # TODO: Implement Frequency/Phase Modulation
        symbol_impulse[::samples_per_symbol] = symbol_vals[:, 0]

        if self.pulse_shape in ('rrcos', 'rcos'):
            shaping_filter = rrcosfilter if self.pulse_shape == 'rrcos' else rcosfilter
            self.filter_samples = self.PULSE_SPAN_SYMBOLS * samples_per_symbol
            self.t_filter, self.h_filter = shaping_filter(
                self.filter_samples, self.PULSE_ROLLOFF, self.symbol_period, self.sample_rate)
        else:
            # 'rect' and any unrecognised pulse_shape use a one-symbol rectangular pulse.
            self.filter_samples = samples_per_symbol
            self.t_filter, self.h_filter = rectfilter(
                self.filter_samples, self.symbol_period, self.sample_rate)

        val_t = np.convolve(symbol_impulse, self.h_filter)
        return self.__time_axis(val_t.shape[0]), val_t

    def __time_to_frequency(self, values):
        """Return (f, val_f): FFT bin frequencies in Hz (unshifted order) and the FFT."""
        val_f = fft(values)
        f = fftfreq(values.shape[-1]) * self.sample_rate
        return f, val_f

    def __frequency_to_time(self, values):
        """Return (t, val_t): sample instants in s and the complex inverse FFT."""
        val_t = ifft(values)
        return self.__time_axis(values.size), val_t

    def __apply_dispersion(self, values):
        """Apply chromatic dispersion to a time-domain signal.

        Multiplies the spectrum by exp(0.5j * dispersion * length * w^2), where
        w = 2*pi*f, and transforms back.

        :return: (t, val_t) with val_t complex
        """
        # Obtain fft
        f, val_f = self.__time_to_frequency(values)
        if self.show_graphs:
            self.__show_spectrum(f, val_f, 'frequency domain (pre-distortion)')

        # Apply distortion (phase-only, so |val_f| is unchanged)
        dist_val_f = val_f * np.exp(
            0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
        if self.show_graphs:
            self.__show_spectrum(f, dist_val_f, 'frequency domain (post-distortion)')

        # Inverse fft
        return self.__frequency_to_time(dist_val_f)

    def __photodiode_detection(self, values):
        """Square-law detection: return (t, |values|^2)."""
        val_t = np.power(np.absolute(values), 2)
        return self.__time_axis(values.size), val_t

    def forward(self, values):
        """Propagate symbols through the channel and sample the detected signal.

        :param values: 2-D array with at least 3 columns, one row per symbol.
            Column 0 holds the transmitted amplitudes; columns 1 and 2 are copied
            to the output unchanged.
        :return: array of the same shape as values. Column 0 is the detected
            intensity at each symbol's decision instant (its square root when
            sqrt_out is set); any columns beyond 2 are zero.
        """
        # Converting APF representation to time-series
        t, val_t = self.__get_time_domain(values)
        if self.show_graphs:
            self.__show_plot(t, val_t, 'time domain (raw)')

        # Adding AWGN
        val_t += np.random.normal(0, 1, val_t.shape) * self.noise
        if self.show_graphs:
            self.__show_plot(t, val_t, 'time domain (AWGN)')

        # Applying chromatic dispersion (the signal is complex from here on)
        t, val_t = self.__apply_dispersion(val_t)
        if self.show_graphs:
            self.__show_plot(t, np.abs(val_t), 'time domain (post-distortion)')

        # Photodiode Detection
        t, val_t = self.__photodiode_detection(val_t)

        # Symbol Decisions: one int64 index per symbol, at the impulse position
        # i * samples_per_symbol offset by filter_samples // 2. That offset equals
        # the previous start offset for even filter lengths.
        # NOTE(review): this assumes the filter's centre tap is the sampling
        # point; confirm against commpy's filter definitions.
        samples_per_symbol = self.__samples_per_symbol()
        idx = (np.arange(values.shape[0], dtype=np.int64) * samples_per_symbol
               + self.filter_samples // 2)
        t_decision = self.sample_period * idx
        if self.show_graphs:
            self.__show_plot(t, val_t, 'time domain (post-detection)')
            plt.plot(t, val_t)
            for xc in t_decision:
                plt.axvline(x=xc, color='r')
            plt.title('time domain (post-detection with decision times)')
            plt.show()

        # TODO: Implement Frequency/Phase Modulation
        out = np.zeros(values.shape)
        out[:, 0] = val_t[idx]
        out[:, 1] = values[:, 1]
        out[:, 2] = values[:, 2]

        if self.sqrt_out:
            out[:, 0] = np.sqrt(out[:, 0])

        return out


if __name__ == '__main__':
    # Simple OOK modulation
    num_of_symbols = 100
    symbol_vals = np.zeros((num_of_symbols, 3))

    symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
    symbol_vals[:, 2] = 40e9

    channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
                             sample_rate=400e9, length=100, pulse_shape='rcos',
                             show_graphs=True)
    v = channel.forward(symbol_vals)

    # Threshold only the detected-amplitude column. The full (N, 3) output does
    # not broadcast against the (N,) reference vector.
    rx = (v[:, 0] > 0.5).astype(int)
    tru = np.sum(rx == symbol_vals[:, 0].astype(int))
    print("Accuracy: {}".format(tru / num_of_symbols))