optical_channel.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import matplotlib.pyplot as plt
  2. import defs
  3. import numpy as np
  4. import math
  5. from scipy.fft import fft, ifft
  6. class OpticalChannel(defs.Channel):
  7. def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, show_graphs=False, **kwargs):
  8. """
  9. :param noise_level: Noise level in dB
  10. :param dispersion: dispersion coefficient is ps^2/km
  11. :param symbol_rate: Symbol rate of modulated signal in Hz
  12. :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
  13. :param length: fibre length in km
  14. :param show_graphs: if graphs should be displayed or not
  15. Optical Channel class constructor
  16. """
  17. super().__init__(**kwargs)
  18. self.noise = 10 ** (noise_level / 10)
  19. self.dispersion = dispersion # * 1e-24 # Converting from ps^2/km to s^2/km
  20. self.symbol_rate = symbol_rate
  21. self.symbol_period = 1 / self.symbol_rate
  22. self.sample_rate = sample_rate
  23. self.sample_period = 1 / self.sample_rate
  24. self.length = length
  25. self.show_graphs = show_graphs
  26. def __get_time_domain(self, symbol_vals):
  27. samples_per_symbol = int(self.sample_rate / self.symbol_rate)
  28. samples = int(symbol_vals.shape[0] * samples_per_symbol)
  29. symbol_vals_a = np.repeat(symbol_vals, repeats=samples_per_symbol, axis=0)
  30. t = np.linspace(start=0, stop=samples * self.sample_period, num=samples)
  31. val_t = symbol_vals_a[:, 0] * np.cos(2 * math.pi * symbol_vals_a[:, 2] * t + symbol_vals_a[:, 1])
  32. return t, val_t
  33. def __time_to_frequency(self, values):
  34. val_f = fft(values)
  35. f = np.linspace(0.0, 1 / (2 * self.sample_period), (values.size // 2))
  36. f_neg = -1 * np.flip(f)
  37. f = np.concatenate((f, f_neg), axis=0)
  38. return f, val_f
  39. def __frequency_to_time(self, values):
  40. val_t = ifft(values)
  41. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  42. return t, val_t
  43. def __apply_dispersion(self, values):
  44. # Obtain fft
  45. f, val_f = self.__time_to_frequency(values)
  46. if self.show_graphs:
  47. plt.plot(f, val_f)
  48. plt.title('frequency domain (pre-distortion)')
  49. plt.show()
  50. # Apply distortion
  51. dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
  52. if self.show_graphs:
  53. plt.plot(f, dist_val_f)
  54. plt.title('frequency domain (post-distortion)')
  55. plt.show()
  56. # Inverse fft
  57. t, val_t = self.__frequency_to_time(dist_val_f)
  58. return t, val_t
  59. def __photodiode_detection(self, values):
  60. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  61. val_t = np.power(np.absolute(values), 2)
  62. return t, val_t
  63. def forward(self, values):
  64. # Converting APF representation to time-series
  65. t, val_t = self.__get_time_domain(values)
  66. if self.show_graphs:
  67. plt.plot(t, val_t)
  68. plt.title('time domain (raw)')
  69. plt.show()
  70. # Adding AWGN
  71. val_t += np.random.normal(0, 1, val_t.shape) * self.noise
  72. if self.show_graphs:
  73. plt.plot(t, val_t)
  74. plt.title('time domain (AWGN)')
  75. plt.show()
  76. # Applying chromatic dispersion
  77. t, val_t = self.__apply_dispersion(val_t)
  78. if self.show_graphs:
  79. plt.plot(t, val_t)
  80. plt.title('time domain (post-distortion)')
  81. plt.show()
  82. # Photodiode Detection
  83. t, val_t = self.__photodiode_detection(val_t)
  84. if self.show_graphs:
  85. plt.plot(t, val_t)
  86. plt.title('time domain (post-detection)')
  87. plt.show()
  88. return t, val_t
  89. if __name__ == '__main__':
  90. # Simple OOK modulation
  91. num_of_symbols = 10
  92. symbol_vals = np.zeros((num_of_symbols, 3))
  93. symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
  94. symbol_vals[:, 2] = 10e6
  95. channel = OpticalChannel(noise_level=-20, dispersion=-21.7, symbol_rate=100e3,
  96. sample_rate=500e6, length=100, show_graphs=True)
  97. time, v = channel.forward(symbol_vals)