optical_channel.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import matplotlib.pyplot as plt
  2. import defs
  3. import numpy as np
  4. import math
  5. from numpy.fft import fft, fftfreq, ifft
  6. from commpy.filters import rrcosfilter, rcosfilter, rectfilter
  7. from models import basic
  8. class OpticalChannel(defs.Channel):
  9. def __init__(self, noise_level, dispersion, symbol_rate, sample_rate, length, pulse_shape='rect',
  10. sqrt_out=False, show_graphs=False, **kwargs):
  11. """
  12. :param noise_level: Noise level in dB
  13. :param dispersion: dispersion coefficient is ps^2/km
  14. :param symbol_rate: Symbol rate of modulated signal in Hz
  15. :param sample_rate: Sample rate of time-domain model (time steps in simulation) in Hz
  16. :param length: fibre length in km
  17. :param pulse_shape: pulse shape -> ['rect', 'rcos', 'rrcos']
  18. :param sqrt_out: Take the root of the out to compensate for photodiode detection
  19. :param show_graphs: if graphs should be displayed or not
  20. Optical Channel class constructor
  21. """
  22. super().__init__(**kwargs)
  23. self.noise = 10 ** (noise_level / 10)
  24. self.dispersion = dispersion * 1e-24 # Converting from ps^2/km to s^2/km
  25. self.symbol_rate = symbol_rate
  26. self.symbol_period = 1 / self.symbol_rate
  27. self.sample_rate = sample_rate
  28. self.sample_period = 1 / self.sample_rate
  29. self.length = length
  30. self.pulse_shape = pulse_shape.strip().lower()
  31. self.sqrt_out = sqrt_out
  32. self.show_graphs = show_graphs
  33. def __get_time_domain(self, symbol_vals):
  34. samples_per_symbol = int(self.sample_rate / self.symbol_rate)
  35. samples = int(symbol_vals.shape[0] * samples_per_symbol)
  36. symbol_impulse = np.zeros(samples)
  37. # TODO: Implement Frequency/Phase Modulation
  38. for i in range(symbol_vals.shape[0]):
  39. symbol_impulse[i*samples_per_symbol] = symbol_vals[i, 0]
  40. if self.pulse_shape == 'rrcos':
  41. self.filter_samples = 5 * samples_per_symbol
  42. self.t_filter, self.h_filter = rrcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
  43. elif self.pulse_shape == 'rcos':
  44. self.filter_samples = 5 * samples_per_symbol
  45. self.t_filter, self.h_filter = rcosfilter(self.filter_samples, 0.8, self.symbol_period, self.sample_rate)
  46. else:
  47. self.filter_samples = samples_per_symbol
  48. self.t_filter, self.h_filter = rectfilter(self.filter_samples, self.symbol_period, self.sample_rate)
  49. val_t = np.convolve(symbol_impulse, self.h_filter)
  50. t = np.linspace(start=0, stop=val_t.shape[0] * self.sample_period, num=val_t.shape[0])
  51. return t, val_t
  52. def __time_to_frequency(self, values):
  53. val_f = fft(values)
  54. f = fftfreq(values.shape[-1])*self.sample_rate
  55. return f, val_f
  56. def __frequency_to_time(self, values):
  57. val_t = ifft(values)
  58. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  59. return t, val_t
  60. def __apply_dispersion(self, values):
  61. # Obtain fft
  62. f, val_f = self.__time_to_frequency(values)
  63. if self.show_graphs:
  64. plt.plot(f, val_f)
  65. plt.title('frequency domain (pre-distortion)')
  66. plt.show()
  67. # Apply distortion
  68. dist_val_f = val_f * np.exp(0.5j * self.dispersion * self.length * np.power(2 * math.pi * f, 2))
  69. if self.show_graphs:
  70. plt.plot(f, dist_val_f)
  71. plt.title('frequency domain (post-distortion)')
  72. plt.show()
  73. # Inverse fft
  74. t, val_t = self.__frequency_to_time(dist_val_f)
  75. return t, val_t
  76. def __photodiode_detection(self, values):
  77. t = np.linspace(start=0, stop=values.size * self.sample_period, num=values.size)
  78. val_t = np.power(np.absolute(values), 2)
  79. return t, val_t
  80. def forward(self, values):
  81. if hasattr(values, 'apf'):
  82. values = values.apf
  83. # Converting APF representation to time-series
  84. t, val_t = self.__get_time_domain(values)
  85. if self.show_graphs:
  86. plt.plot(t, val_t)
  87. plt.title('time domain (raw)')
  88. plt.show()
  89. # Adding AWGN
  90. val_t += np.random.normal(0, 1, val_t.shape) * self.noise
  91. if self.show_graphs:
  92. plt.plot(t, val_t)
  93. plt.title('time domain (AWGN)')
  94. plt.show()
  95. # Applying chromatic dispersion
  96. t, val_t = self.__apply_dispersion(val_t)
  97. if self.show_graphs:
  98. plt.plot(t, val_t)
  99. plt.title('time domain (post-distortion)')
  100. plt.show()
  101. # Photodiode Detection
  102. t, val_t = self.__photodiode_detection(val_t)
  103. # Symbol Decisions
  104. idx = np.arange(self.filter_samples/2, t.shape[0] - (self.filter_samples/2),
  105. self.symbol_period/self.sample_period, dtype='int64')
  106. t_descision = self.sample_period * idx
  107. if self.show_graphs:
  108. plt.plot(t, val_t)
  109. plt.title('time domain (post-detection)')
  110. plt.show()
  111. plt.plot(t, val_t)
  112. for xc in t_descision:
  113. plt.axvline(x=xc, color='r')
  114. plt.title('time domain (post-detection with decision times)')
  115. plt.show()
  116. # TODO: Implement Frequency/Phase Modulation
  117. out = np.zeros(values.shape)
  118. out[:, 0] = val_t[idx]
  119. out[:, 1] = values[:, 1]
  120. out[:, 2] = values[:, 2]
  121. if self.sqrt_out:
  122. out[:, 0] = np.sqrt(out[:, 0])
  123. return basic.RFSignal(out)
  124. if __name__ == '__main__':
  125. # Simple OOK modulation
  126. num_of_symbols = 100
  127. symbol_vals = np.zeros((num_of_symbols, 3))
  128. symbol_vals[:, 0] = np.random.randint(2, size=symbol_vals.shape[0])
  129. symbol_vals[:, 2] = 40e9
  130. channel = OpticalChannel(noise_level=-10, dispersion=-21.7, symbol_rate=10e9,
  131. sample_rate=400e9, length=100, pulse_shape='rcos', show_graphs=True)
  132. v = channel.forward(symbol_vals)
  133. rx = (v > 0.5).astype(int)
  134. tru = np.sum(rx == symbol_vals[:, 0].astype(int))
  135. print("Accuracy: {}".format(tru/num_of_symbols))