optical_channel.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. import matplotlib.pyplot as plt
  2. import defs
  3. import numpy as np
  4. import math
  5. from numpy.fft import fft, fftfreq, ifft
  6. from commpy.filters import rrcosfilter, rcosfilter, rectfilter
  class OpticalChannel(defs.Channel):
      """Time-domain model of an intensity-detected optical fibre link.

      ``forward`` pulse-shapes the amplitude column of the input symbols, adds
      Gaussian noise, applies chromatic dispersion in the frequency domain,
      square-law (photodiode) detects the field and samples one value per symbol.
      """

      def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
                   sqrt_out=False, show_graphs=False, **kwargs):
          """
          Optical Channel class constructor.

          :param noise_level: Noise level in dB
          :param dispersion: dispersion coefficient in ps^2/km
          :param symbol_rate: Symbol rate of modulated signal in Hz
          :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
          :param length: fibre length in km
          :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos'] (case and surrounding
              whitespace are ignored; any other value falls back to 'rect')
          :param sqrt_out: Take the root of the output to compensate for photodiode detection
          :param show_graphs: if graphs should be displayed or not
          :param kwargs: forwarded unchanged to ``defs.Channel.__init__``
          """
          super().__init__(**kwargs)
          # NOTE(review): 10 ** (dB / 10) is a power ratio, yet self.noise is used as the
          # AWGN standard deviation in forward(). If noise_level is meant as a noise
          # *power*, the std would be 10 ** (noise_level / 20). Confirm the intent.
          self.noise = 10 ** (noise_level / 10)
          self.dispersion = dispersion * 1e-24  # Converting from ps^2/km to s^2/km
          self.symbol_rate = symbol_rate
          self.symbol_period = 1 / self.symbol_rate
          self.sample_rate = sample_rate
          self.sample_period = 1 / self.sample_rate
          self.length = length
          self.pulse_shape = pulse_shape.strip().lower()
          self.sqrt_out = sqrt_out
          self.show_graphs = show_graphs

      def __samples_per_symbol(self):
          """Return the integer number of simulation samples per symbol.

          Ratios within float rounding error of an integer (e.g. 39.9999999) are snapped
          to that integer instead of being truncated down; genuinely fractional ratios
          are truncated, as before.

          :raises ValueError: if sample_rate is lower than symbol_rate
          """
          ratio = self.sample_rate / self.symbol_rate
          nearest = round(ratio)
          samples_per_symbol = int(nearest) if math.isclose(ratio, nearest) else int(ratio)
          if samples_per_symbol < 1:
              raise ValueError('sample_rate ({}) must be at least symbol_rate ({})'.format(
                  self.sample_rate, self.symbol_rate))
          return samples_per_symbol

      def __time_axis(self, num_samples):
          """Return the sampling instants 0, Ts, 2*Ts, ... (Ts = sample_period) of a signal.

          np.arange is used instead of np.linspace(0, n * Ts, n): linspace includes the
          end point, which spaces the samples n * Ts / (n - 1) apart instead of Ts and
          misaligns the axis with the decision instants computed in forward().
          """
          return np.arange(num_samples) * self.sample_period

      def __plot(self, x, y, title):
          """Plot y against x with a title, only when show_graphs is enabled.

          Complex data is plotted as its real part explicitly; matplotlib would discard
          the imaginary part anyway, but with a ComplexWarning.
          """
          if not self.show_graphs:
              return
          plt.plot(x, np.real(y))
          plt.title(title)
          plt.show()

      def __get_time_domain(self, symbol_vals):
          """Build the pulse-shaped time-domain signal from per-symbol values.

          One impulse per symbol (taken from column 0 of ``symbol_vals``) is placed at the
          start of each symbol slot and convolved with the filter selected by
          ``self.pulse_shape``. Side effects: sets ``self.samples_per_symbol``,
          ``self.filter_samples``, ``self.t_filter`` and ``self.h_filter``; forward()
          relies on the first two for the symbol decision instants.

          :param symbol_vals: 2-D array with one row per symbol
          :return: (t, val_t) -- sample instants in seconds and the shaped signal
          :raises ValueError: if sample_rate is lower than symbol_rate
          """
          samples_per_symbol = self.__samples_per_symbol()
          self.samples_per_symbol = samples_per_symbol
          symbol_impulse = np.zeros(symbol_vals.shape[0] * samples_per_symbol)
          # TODO: Implement Frequency/Phase Modulation
          # Every samples_per_symbol-th sample carries one symbol's amplitude.
          symbol_impulse[::samples_per_symbol] = symbol_vals[:, 0]
          # 0.8 is passed as CommPy's roll-off (alpha) argument.
          if self.pulse_shape == 'rrcos':
              self.filter_samples = 5 * samples_per_symbol
              self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
          elif self.pulse_shape == 'rcos':
              self.filter_samples = 5 * samples_per_symbol
              self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
          else:
              self.filter_samples = samples_per_symbol
              self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
          val_t = np.convolve(symbol_impulse, self.h_filter)
          return self.__time_axis(val_t.shape[0]), val_t

      def __time_to_frequency(self, values):
          """Return (f, spectrum): FFT of ``values`` and its bin frequencies in Hz."""
          val_f = fft(values)
          f = fftfreq(values.shape[-1])*self.sample_rate
          return f, val_f

      def __frequency_to_time(self, values):
          """Return (t, signal): inverse FFT of ``values`` and its sample instants in seconds."""
          val_t = ifft(values)
          return self.__time_axis(values.size), val_t

      def __apply_dispersion(self, values):
          """Apply chromatic dispersion to a sampled signal in the frequency domain.

          The spectrum is multiplied by exp(0.5j * D * L * (2*pi*f)^2) with
          D = self.dispersion (s^2/km) and L = self.length (km).

          :return: (t, val_t) with val_t complex
          """
          # Obtain fft
          f, val_f = self.__time_to_frequency(values)
          self.__plot(f, val_f, 'frequency domain (pre-distortion)')
          # Apply distortion
          dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
          self.__plot(f, dist_val_f, 'frequency domain (post-distortion)')
          # Inverse fft
          return self.__frequency_to_time(dist_val_f)

      def __photodiode_detection(self, values):
          """Square-law detection: return (t, |values|^2)."""
          return self.__time_axis(values.size), np.abs(values) ** 2

      def forward(self, values):
          """Propagate symbols through the channel and return the detected samples.

          :param values: 2-D array with one row per symbol and at least 3 columns.
              Column 0 is the transmitted amplitude; columns 1 and 2 are copied through
              unchanged (presumably phase/frequency -- see the TODOs).
          :return: float array shaped like ``values``. Column 0 holds the detected
              intensity at each symbol's decision instant (its square root if sqrt_out).
          :raises ValueError: if sample_rate is lower than symbol_rate
          """
          # Converting APF representation to time-series
          t, val_t = self.__get_time_domain(values)
          self.__plot(t, val_t, 'time domain (raw)')
          # Adding AWGN
          val_t += np.random.normal(0, 1, val_t.shape) * self.noise
          self.__plot(t, val_t, 'time domain (AWGN)')
          # Applying chromatic dispersion
          t, val_t = self.__apply_dispersion(val_t)
          self.__plot(t, val_t, 'time domain (post-distortion)')
          # Photodiode Detection
          t, val_t = self.__photodiode_detection(val_t)
          # Symbol Decisions: sample each symbol half a filter length after its impulse
          # (the centre of the shaped pulse). Integer arithmetic replaces the former
          # np.arange(..., dtype='int16'), which wrapped past 32767 samples and could
          # drift when the float step was not exactly integral.
          idx = self.filter_samples // 2 + self.samples_per_symbol * np.arange(values.shape[0])
          if self.show_graphs:
              self.__plot(t, val_t, 'time domain (post-detection)')
              t_decision = self.sample_period * idx
              plt.plot(t, val_t)
              for xc in t_decision:
                  plt.axvline(x=xc, color='r')
              plt.title('time domain (post-detection with decision times)')
              plt.show()
          # TODO: Implement Frequency/Phase Modulation
          out = np.zeros(values.shape)
          out[:, 0] = val_t[idx]
          out[:, 1] = values[:, 1]
          out[:, 2] = values[:, 2]
          if self.sqrt_out:
              out[:, 0] = np.sqrt(out[:, 0])
          return out
  if __name__ == '__main__':
      # Simple OOK modulation demo: random 0/1 amplitudes sent through 100 km of fibre.
      num_of_symbols = 100
      symbol_vals = np.zeros((num_of_symbols, 3))
      symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
      symbol_vals[:, 2] = 40e9
      channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
                               sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
      v = channel.forward(symbol_vals)
      # Threshold only the detected column (0). Thresholding the whole (N, 3) output
      # and comparing it with the (N,) transmitted bits does not broadcast.
      rx = (v[:, 0] > 0.5).astype(int)
      tru = np.sum(rx == symbol_vals[:, 0].astype(int))
      print("Accuracy: {}".format(tru / num_of_symbols))