custom_layers.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. from tensorflow.keras import layers
  2. import tensorflow as tf
  3. import math
  4. import numpy as np
  5. class ExtractCentralMessage(layers.Layer):
  6. def __init__(self, messages_per_block, samples_per_symbol):
  7. """
  8. A keras layer that extracts the central message(symbol) in a block.
  9. :param messages_per_block: Total number of messages in transmission block
  10. :param samples_per_symbol: Number of samples per transmitted symbol
  11. """
  12. super(ExtractCentralMessage, self).__init__()
  13. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  14. i = np.identity(samples_per_symbol)
  15. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  16. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  17. temp_w[begin:end, :] = i
  18. self.samples_per_symbol = samples_per_symbol
  19. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  20. def call(self, inputs, **kwargs):
  21. out = tf.matmul(inputs, self.w)
  22. return tf.reshape(out, shape=(1, 1, self.samples_per_symbol))
# TODO: this won't work with dense layers; it needs to move to a separate layer
  24. class AwgnChannel(layers.Layer):
  25. def __init__(self, rx_stddev=0.1):
  26. """
  27. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  28. being applied every time the call function is called.
  29. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  30. """
  31. super(AwgnChannel, self).__init__()
  32. self.noise_layer = layers.GaussianNoise(rx_stddev)
  33. def call(self, inputs, **kwargs):
  34. return self.noise_layer.call(inputs, training=True)
  35. class DigitizationLayer(layers.Layer):
  36. def __init__(self,
  37. fs,
  38. num_of_samples,
  39. lpf_cutoff=32e9,
  40. q_stddev=0.1):
  41. """
  42. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  43. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  44. :param fs: Sampling frequency of the simulation in Hz
  45. :param num_of_samples: Total number of samples in the input
  46. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  47. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  48. """
  49. super(DigitizationLayer, self).__init__()
  50. self.noise_layer = layers.GaussianNoise(q_stddev)
  51. freq = np.fft.fftfreq(num_of_samples, d=1/fs)
  52. temp = np.ones(freq.shape)
  53. for idx, val in np.ndenumerate(freq):
  54. if np.abs(val) > lpf_cutoff:
  55. temp[idx] = 0
  56. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  57. def call(self, inputs, **kwargs):
  58. complex_in = tf.cast(inputs, dtype=tf.complex64)
  59. val_f = tf.signal.fft(complex_in)
  60. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  61. filtered_t = tf.signal.ifft(filtered_f)
  62. real_t = tf.cast(filtered_t, dtype=tf.float32)
  63. noisy = self.noise_layer.call(real_t, training=True)
  64. return noisy
  65. class OpticalChannel(layers.Layer):
  66. def __init__(self,
  67. fs,
  68. num_of_samples,
  69. dispersion_factor,
  70. fiber_length,
  71. lpf_cutoff=32e9,
  72. rx_stddev=0.01,
  73. q_stddev=0.01):
  74. """
  75. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  76. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  77. :param fs: Sampling frequency of the simulation in Hz
  78. :param num_of_samples: Total number of samples in the input
  79. :param dispersion_factor: Dispersion factor in s^2/km
  80. :param fiber_length: Length of fiber to model in km
  81. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  82. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  83. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  84. """
  85. super(OpticalChannel, self).__init__()
  86. self.noise_layer = layers.GaussianNoise(rx_stddev)
  87. self.digitization_layer = DigitizationLayer(fs=fs,
  88. num_of_samples=num_of_samples,
  89. lpf_cutoff=lpf_cutoff,
  90. q_stddev=q_stddev)
  91. self.flatten_layer = layers.Flatten()
  92. self.fs = fs
  93. self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex64)
  94. self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
  95. def call(self, inputs, **kwargs):
  96. # DAC LPF and noise
  97. dac_out = self.digitization_layer(inputs)
  98. # Chromatic Dispersion
  99. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  100. val_f = tf.signal.fft(complex_val)
  101. disp_f = tf.math.multiply(val_f, self.multiplier)
  102. disp_t = tf.signal.ifft(disp_f)
  103. # Squared-Law Detection
  104. pd_out = tf.square(tf.abs(disp_t))
  105. # Casting back to floatx
  106. real_val = tf.cast(pd_out, dtype=tf.float32)
  107. # Adding photo-diode receiver noise
  108. rx_signal = self.noise_layer.call(real_val, training=True)
  109. # ADC LPF and noise
  110. adc_out = self.digitization_layer(rx_signal)
  111. return adc_out