custom_layers.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from tensorflow.keras import layers
  2. import tensorflow as tf
  3. import math
  4. import numpy as np
  5. import itertools
  6. class BitsToSymbols(layers.Layer):
  7. def __init__(self, cardinality):
  8. super(BitsToSymbols, self).__init__()
  9. self.cardinality = cardinality
  10. n = int(math.log(self.cardinality, 2))
  11. self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  12. def call(self, inputs, **kwargs):
  13. idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
  14. return tf.one_hot(idx, self.cardinality)
  15. class SymbolsToBits(layers.Layer):
  16. def __init__(self, cardinality):
  17. super(SymbolsToBits, self).__init__()
  18. n = int(math.log(cardinality, 2))
  19. lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
  20. # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  21. self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32))
  22. def call(self, inputs, **kwargs):
  23. return tf.matmul(self.all_syms, inputs)
  24. class ExtractCentralMessage(layers.Layer):
  25. def __init__(self, messages_per_block, samples_per_symbol):
  26. """
  27. A keras layer that extracts the central message(symbol) in a block.
  28. :param messages_per_block: Total number of messages in transmission block
  29. :param samples_per_symbol: Number of samples per transmitted symbol
  30. """
  31. super(ExtractCentralMessage, self).__init__()
  32. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  33. i = np.identity(samples_per_symbol)
  34. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  35. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  36. temp_w[begin:end, :] = i
  37. self.samples_per_symbol = samples_per_symbol
  38. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  39. def call(self, inputs, **kwargs):
  40. return tf.matmul(inputs, self.w)
  41. class AwgnChannel(layers.Layer):
  42. def __init__(self, rx_stddev=0.1):
  43. """
  44. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  45. being applied every time the call function is called.
  46. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  47. """
  48. super(AwgnChannel, self).__init__()
  49. self.noise_layer = layers.GaussianNoise(rx_stddev)
  50. def call(self, inputs, **kwargs):
  51. return self.noise_layer.call(inputs, training=True)
  52. class DigitizationLayer(layers.Layer):
  53. def __init__(self,
  54. fs,
  55. num_of_samples,
  56. lpf_cutoff=32e9,
  57. q_stddev=0.1):
  58. """
  59. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  60. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  61. :param fs: Sampling frequency of the simulation in Hz
  62. :param num_of_samples: Total number of samples in the input
  63. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  64. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  65. """
  66. super(DigitizationLayer, self).__init__()
  67. self.noise_layer = layers.GaussianNoise(q_stddev)
  68. freq = np.fft.fftfreq(num_of_samples, d=1/fs)
  69. temp = np.ones(freq.shape)
  70. for idx, val in np.ndenumerate(freq):
  71. if np.abs(val) > lpf_cutoff:
  72. temp[idx] = 0
  73. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  74. def call(self, inputs, **kwargs):
  75. complex_in = tf.cast(inputs, dtype=tf.complex64)
  76. val_f = tf.signal.fft(complex_in)
  77. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  78. filtered_t = tf.signal.ifft(filtered_f)
  79. real_t = tf.cast(filtered_t, dtype=tf.float32)
  80. noisy = self.noise_layer.call(real_t, training=True)
  81. return noisy
  82. class OpticalChannel(layers.Layer):
  83. def __init__(self,
  84. fs,
  85. num_of_samples,
  86. dispersion_factor,
  87. fiber_length,
  88. lpf_cutoff=32e9,
  89. rx_stddev=0.01,
  90. q_stddev=0.01):
  91. """
  92. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  93. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  94. :param fs: Sampling frequency of the simulation in Hz
  95. :param num_of_samples: Total number of samples in the input
  96. :param dispersion_factor: Dispersion factor in s^2/km
  97. :param fiber_length: Length of fiber to model in km
  98. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  99. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  100. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  101. """
  102. super(OpticalChannel, self).__init__()
  103. self.noise_layer = layers.GaussianNoise(rx_stddev)
  104. self.digitization_layer = DigitizationLayer(fs=fs,
  105. num_of_samples=num_of_samples,
  106. lpf_cutoff=lpf_cutoff,
  107. q_stddev=q_stddev)
  108. self.flatten_layer = layers.Flatten()
  109. self.fs = fs
  110. self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex64)
  111. self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
  112. def call(self, inputs, **kwargs):
  113. # DAC LPF and noise
  114. dac_out = self.digitization_layer(inputs)
  115. # Chromatic Dispersion
  116. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  117. val_f = tf.signal.fft(complex_val)
  118. disp_f = tf.math.multiply(val_f, self.multiplier)
  119. disp_t = tf.signal.ifft(disp_f)
  120. # Squared-Law Detection
  121. pd_out = tf.square(tf.abs(disp_t))
  122. # Casting back to floatx
  123. real_val = tf.cast(pd_out, dtype=tf.float32)
  124. # Adding photo-diode receiver noise
  125. rx_signal = self.noise_layer.call(real_val, training=True)
  126. # ADC LPF and noise
  127. adc_out = self.digitization_layer(rx_signal)
  128. return adc_out