optical_channel.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import matplotlib.pyplot as plt
  2. import defs
  3. import numpy as np
  4. import math
  5. from numpy.fft import fft, fftfreq, ifft
  6. from commpy.filters import rrcosfilter, rcosfilter, rectfilter
  7. class OpticalChannel(defs.Channel):
  8. def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
  9. sqrt_out=False, show_graphs=False, **kwargs):
  10. """
  11. :param noise_level: Noise level in dB
  12. :param dispersion: dispersion coefficient is ps^2/km
  13. :param symbol_rate: Symbol rate of modulated signal in Hz
  14. :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
  15. :param length: fibre length in km
  16. :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
  17. :param sqrt_out: Take the root of the out to compensate for photodiode detection
  18. :param show_graphs: if graphs should be displayed or not
  19. Optical Channel class constructor
  20. """
  21. super().__init__(**kwargs)
  22. self.noise = 10 ** (noise_level / 10)
  23. self.dispersion = dispersion * 1e-24 # Converting from ps^2/km to s^2/km
  24. self.symbol_rate = symbol_rate
  25. self.symbol_period = 1 / self.symbol_rate
  26. self.sample_rate = sample_rate
  27. self.sample_period = 1 / self.sample_rate
  28. self.length = length
  29. self.pulse_shape = pulse_shape.strip().lower()
  30. self.sqrt_out = sqrt_out
  31. self.show_graphs = show_graphs
  32. def __get_time_domain(self, symbol_vals):
  33. samples_per_symbol = int(self.sample_rate / self.symbol_rate)
  34. samples = int(symbol_vals.shape[0] * samples_per_symbol)
  35. symbol_impulse = np.zeros(samples)
  36. # TODO: Implement Frequency/Phase Modulation
  37. for i in range(symbol_vals.shape[0]):
  38. symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
  39. if self.pulse_shape == 'rrcos':
  40. self.filter_samples = 5 * samples_per_symbol
  41. self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
  42. elif self.pulse_shape == 'rcos':
  43. self.filter_samples = 5 * samples_per_symbol
  44. self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
  45. else:
  46. self.filter_samples = samples_per_symbol
  47. self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
  48. val_t = np.convolve(symbol_impulse, self.h_filter)
  49. t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])
  50. return t, val_t
  51. def __time_to_frequency(self, values):
  52. val_f = fft(values)
  53. f = fftfreq(values.shape[-1])*self.sample_rate
  54. return f, val_f
  55. def __frequency_to_time(self, values):
  56. val_t = ifft(values)
  57. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  58. return t, val_t
  59. def __apply_dispersion(self, values):
  60. # Obtain fft
  61. f, val_f = self.__time_to_frequency(values)
  62. if self.show_graphs:
  63. plt.plot(f, val_f)
  64. plt.title('frequency domain (pre-distortion)')
  65. plt.show()
  66. # Apply distortion
  67. dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
  68. if self.show_graphs:
  69. plt.plot(f, dist_val_f)
  70. plt.title('frequency domain (post-distortion)')
  71. plt.show()
  72. # Inverse fft
  73. t, val_t = self.__frequency_to_time(dist_val_f)
  74. return t, val_t
  75. def __photodiode_detection(self, values):
  76. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  77. val_t = np.power(np.absolute(values), 2)
  78. return t, val_t
  79. def __plot_eye(self, val_t, num_of_symbols=100):
  80. samples_per_symbol = int(self.symbol_period/self.sample_period)
  81. val_t_a = np.reshape(val_t[:(2*samples_per_symbol*num_of_symbols)], (-1, 2*samples_per_symbol))
  82. t = np.linspace(start=0, stop=self.symbol_period, num=2*samples_per_symbol)
  83. for sym in val_t_a:
  84. plt.plot(t, sym, color="blue")
  85. plt.show()
  86. pass
  87. # def eye_diagram(self, t, val_t, num_of_symbols=100):
  88. # symbol_width = int(len(val_t)/num_of_symbols)
  89. # time_scale = t[0:symbol_width]
  90. # counter = 0
  91. # l = 0
  92. # u = symbol_width
  93. # while counter < 100:
  94. # symbol = val_t[l:u]
  95. # plt.plot(time_scale, symbol)
  96. # counter += 1
  97. # l += symbol_width
  98. # u += symbol_width
  99. def forward(self, values):
  100. # Converting APF representation to time-series
  101. t, val_t = self.__get_time_domain(values)
  102. if self.show_graphs:
  103. plt.plot(t, val_t)
  104. plt.title('time domain (raw)')
  105. plt.show()
  106. # Adding AWGN
  107. val_t += np.random.normal(0, 1, val_t.shape) * self.noise
  108. if self.show_graphs:
  109. plt.plot(t, val_t)
  110. plt.title('time domain (AWGN)')
  111. plt.show()
  112. # Applying chromatic dispersion
  113. t, val_t = self.__apply_dispersion(val_t)
  114. if self.show_graphs:
  115. plt.plot(t, val_t)
  116. plt.title('time domain (post-distortion)')
  117. plt.show()
  118. # Photodiode Detection
  119. t, val_t = self.__photodiode_detection(val_t)
  120. # Symbol Decisions
  121. idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
  122. self.symbol_period/self.sample_period, dtype='int64')
  123. t_descision = self.sample_period * idx
  124. if self.show_graphs:
  125. self.__plot_eye(val_t)
  126. plt.plot(t, val_t)
  127. plt.title('time domain (post-detection)')
  128. plt.show()
  129. plt.plot(t, val_t)
  130. for xc in t_descision:
  131. plt.axvline(x=xc, color='r')
  132. plt.title('time domain (post-detection with decision times)')
  133. plt.show()
  134. # TODO: Implement Frequency/Phase Modulation
  135. out = np.zeros(values.shape)
  136. out[:, 0] = val_t[idx]
  137. out[:, 1] = values[:, 1]
  138. out[:, 2] = values[:, 2]
  139. if self.sqrt_out:
  140. out[:, 0] = np.sqrt(out[:, 0])
  141. return out
  142. if __name__ == '__main__':
  143. # Simple OOK modulation
  144. num_of_symbols = 1000
  145. symbol_vals = np.zeros((num_of_symbols, 3))
  146. symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
  147. symbol_vals[:, 2] = 0
  148. channel = OpticalChannel(noise_level=-20, dispersion=-21.7, symbol_rate=7e9,
  149. sample_rate=336e9, length=0, pulse_shape='rcos', show_graphs=True)
  150. v = channel.forward(symbol_vals)
  151. rx = (v > 0.5).astype(int)
  152. tru = np.sum(rx == symbol_vals[:, 0].astype(int))
  153. print("Accuracy: {}".format(tru/num_of_symbols))