optical_channel.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import matplotlib.pyplot as plt
  2. import defs
  3. import numpy as np
  4. import math
  5. from numpy.fft import fft, fftfreq, ifft
  6. from commpy.filters import rrcosfilter, rcosfilter, rectfilter
  class OpticalChannel(defs.Channel):
      """Sampled time-domain model of an intensity-detected optical fibre link.

      ``forward`` pulse-shapes a table of symbols into a waveform, adds Gaussian
      noise, applies chromatic dispersion in the frequency domain, takes the
      squared magnitude (photodiode) and samples the result once per symbol.
      """

      def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
                   show_graphs=False, **kwargs):
          """
          Optical Channel class constructor

          :param noise_level: Noise level in dB
          :param dispersion: dispersion coefficient in ps^2/km
          :param symbol_rate: Symbol rate of modulated signal in Hz
          :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
          :param length: fibre length in km
          :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']; any other
              value falls back to the rectangular pulse
          :param show_graphs: if graphs should be displayed or not
          :param kwargs: forwarded unchanged to the base-class constructor
          """
          super().__init__(**kwargs)
          self.noise = 10 ** (noise_level / 10)  # dB -> linear ratio
          self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
          self.symbol_rate = symbol_rate
          self.symbol_period = 1 / self.symbol_rate
          self.sample_rate = sample_rate
          self.sample_period = 1 / self.sample_rate
          self.length = length
          self.pulse_shape = pulse_shape.strip().lower()
          self.show_graphs = show_graphs
          # Filled in by __get_time_domain; forward() needs filter_samples to
          # place the decision instants.
          self.filter_samples = None
          self.t_filter = None
          self.h_filter = None

      def __samples_per_symbol(self):
          """
          :return: whole number of simulation samples per symbol (truncated)
          :raises ValueError: if the sample rate is lower than the symbol rate
          """
          samples_per_symbol = int(self.sample_rate / self.symbol_rate)
          if samples_per_symbol < 1:
              raise ValueError('sample_rate must be at least symbol_rate '
                               '(got sample_rate={}, symbol_rate={})'.format(self.sample_rate, self.symbol_rate))
          return samples_per_symbol

      def __time_axis(self, num_samples):
          """
          :param num_samples: number of samples in the waveform
          :return: sample instants 0, dt, 2*dt, ... with dt = sample_period
          """
          # arange keeps the spacing at exactly one sample period; linspace with
          # its default inclusive endpoint would stretch it to N*dt/(N-1).
          return np.arange(num_samples) * self.sample_period

      def __show(self, x, y, title, markers=None):
          """
          Plot ``y`` against ``x`` when graphs are enabled; otherwise do nothing.

          :param x: abscissa values
          :param y: ordinate values; the magnitude is plotted for complex input
          :param title: figure title
          :param markers: optional x positions to mark with red vertical lines
          """
          if not self.show_graphs:
              return
          if np.iscomplexobj(y):
              # matplotlib would otherwise drop the imaginary part.
              y = np.absolute(y)
          plt.plot(x, y)
          if markers is not None:
              for xc in markers:
                  plt.axvline(x=xc, color='r')
          plt.title(title)
          plt.show()

      def __get_time_domain(self, symbol_vals):
          """
          Build the pulse-shaped waveform for a table of symbols.

          :param symbol_vals: 2-D array with one row per symbol; only column 0
              is read here and used as the pulse amplitude
          :return: (t, val_t) time axis and shaped waveform
          """
          samples_per_symbol = self.__samples_per_symbol()
          samples = int(symbol_vals.shape[0] * samples_per_symbol)
          # Impulse train: one weighted impulse at the start of every symbol slot.
          symbol_impulse = np.zeros(samples)
          symbol_impulse[::samples_per_symbol] = symbol_vals[:, 0]
          if self.pulse_shape == 'rrcos':
              self.filter_samples = 5 * samples_per_symbol
              self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
          elif self.pulse_shape == 'rcos':
              self.filter_samples = 5 * samples_per_symbol
              self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
          else:
              self.filter_samples = samples_per_symbol
              self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
          val_t = np.convolve(symbol_impulse, self.h_filter)
          return self.__time_axis(val_t.shape[0]), val_t

      def __time_to_frequency(self, values):
          """
          :param values: time-domain samples
          :return: (f, val_f) frequency bins in Hz (FFT order) and the spectrum
          """
          val_f = fft(values)
          f = fftfreq(values.shape[-1])*self.sample_rate
          return f, val_f

      def __frequency_to_time(self, values):
          """
          :param values: spectrum in FFT order
          :return: (t, val_t) time axis and complex time-domain samples
          """
          val_t = ifft(values)
          return self.__time_axis(values.size), val_t

      def __apply_dispersion(self, values):
          """
          Apply the fibre's quadratic spectral phase to a waveform.

          :param values: time-domain samples
          :return: (t, val_t) time axis and dispersed (complex) waveform
          """
          # Obtain fft
          f, val_f = self.__time_to_frequency(values)
          # fftfreq returns bins in FFT order; sort them so the plotted line does
          # not wrap from the highest positive to the lowest negative frequency.
          order = np.argsort(f)
          self.__show(f[order], val_f[order], 'frequency domain (pre-distortion)')
          # Apply distortion
          dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
          self.__show(f[order], dist_val_f[order], 'frequency domain (post-distortion)')
          # Inverse fft
          return self.__frequency_to_time(dist_val_f)

      def __photodiode_detection(self, values):
          """
          :param values: (possibly complex) field samples
          :return: (t, val_t) time axis and squared magnitude of every sample
          """
          val_t = np.power(np.absolute(values), 2)
          return self.__time_axis(values.size), val_t

      def forward(self, values):
          """
          Pass a block of symbols through the channel.

          :param values: 2-D array with one row per symbol; column 0 is used as
              the symbol amplitude (NOTE(review): presumably an
              amplitude/phase/frequency table -- confirm against callers)
          :return: 1-D array of detected intensities, one per decision instant
          :raises ValueError: if the sample rate is lower than the symbol rate
          """
          # Converting APF representation to time-series
          t, val_t = self.__get_time_domain(values)
          self.__show(t, val_t, 'time domain (raw)')
          # Adding AWGN
          # NOTE(review): self.noise is 10**(dB/10) and is used directly as the
          # amplitude scale of the noise -- confirm a square root is not intended.
          val_t += np.random.normal(0, 1, val_t.shape) * self.noise
          self.__show(t, val_t, 'time domain (AWGN)')
          # Applying chromatic dispersion
          t, val_t = self.__apply_dispersion(val_t)
          self.__show(t, val_t, 'time domain (post-distortion)')
          # Photodiode Detection
          t, val_t = self.__photodiode_detection(val_t)
          # Symbol Decisions
          # Step by the same integer samples-per-symbol used to build the pulse
          # train, and index with a platform-sized integer so long waveforms do
          # not overflow.
          half_filter = self.filter_samples / 2
          idx = np.arange(half_filter, t.shape[0] - half_filter, self.__samples_per_symbol()).astype(np.intp)
          t_descision = self.sample_period * idx
          self.__show(t, val_t, 'time domain (post-detection)')
          self.__show(t, val_t, 'time domain (post-detection with decision times)', markers=t_descision)
          return val_t[idx]
  if __name__ == '__main__':
      # Demo: send random on/off symbols through the channel and report how
      # many are recovered by a fixed 0.5 threshold on the detected samples.
      symbol_count = 100
      tx_table = np.zeros((symbol_count, 3))
      tx_table[:, 0] = np.random.randint(2, size=tx_table.shape[0])
      tx_table[:, 2] = 40e9

      link = OpticalChannel(
          noise_level=-10,
          dispersion=-21.7,
          symbol_rate=10e9,
          sample_rate=400e9,
          length=100,
          pulse_shape='rcos',
          show_graphs=True,
      )
      detected = link.forward(tx_table)

      sent_bits = tx_table[:, 0].astype(int)
      received_bits = (detected > 0.5).astype(int)
      hits = np.sum(received_bits == sent_bits)
      print("Accuracy: {}".format(hits / symbol_count))