layers.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. """
  2. Custom Keras Layers for general use
  3. """
  4. import itertools
  5. from tensorflow.keras import layers
  6. import tensorflow as tf
  7. import numpy as np
  8. import math
  9. class AwgnChannel(layers.Layer):
  10. def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
  11. """
  12. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  13. being applied every time the call function is called.
  14. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  15. """
  16. super(AwgnChannel, self).__init__(**kwargs)
  17. if noise_dB is not None:
  18. # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
  19. rx_stddev = 10 ** (noise_dB / 10.0)
  20. self.noise_layer = layers.GaussianNoise(rx_stddev)
  21. def call(self, inputs, **kwargs):
  22. return self.noise_layer.call(inputs, training=True)
  23. class BitsToSymbols(layers.Layer):
  24. def __init__(self, cardinality):
  25. super(BitsToSymbols, self).__init__()
  26. self.cardinality = cardinality
  27. n = int(math.log(self.cardinality, 2))
  28. self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  29. def call(self, inputs, **kwargs):
  30. idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
  31. out = tf.one_hot(idx, self.cardinality)
  32. return layers.Reshape((9, 32))(out)
  33. class SymbolsToBits(layers.Layer):
  34. def __init__(self, cardinality):
  35. super(SymbolsToBits, self).__init__()
  36. n = int(math.log(cardinality, 2))
  37. lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
  38. # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  39. self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  40. def call(self, inputs, **kwargs):
  41. return tf.matmul(inputs, self.all_syms)
  42. class ScaleAndOffset(layers.Layer):
  43. """
  44. Scales and offsets a tensor
  45. """
  46. def __init__(self, scale=1, offset=0, **kwargs):
  47. super(ScaleAndOffset, self).__init__(**kwargs)
  48. self.offset = offset
  49. self.scale = scale
  50. def call(self, inputs, **kwargs):
  51. return inputs * self.scale + self.offset
  52. class ExtractCentralMessage(layers.Layer):
  53. def __init__(self, messages_per_block, samples_per_symbol):
  54. """
  55. A keras layer that extracts the central message(symbol) in a block.
  56. :param messages_per_block: Total number of messages in transmission block
  57. :param samples_per_symbol: Number of samples per transmitted symbol
  58. """
  59. super(ExtractCentralMessage, self).__init__()
  60. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  61. i = np.identity(samples_per_symbol)
  62. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  63. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  64. temp_w[begin:end, :] = i
  65. self.samples_per_symbol = samples_per_symbol
  66. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  67. def call(self, inputs, **kwargs):
  68. return tf.matmul(inputs, self.w)
  69. class DigitizationLayer(layers.Layer):
  70. def __init__(self,
  71. fs,
  72. num_of_samples,
  73. lpf_cutoff=32e9,
  74. sig_avg=0.5,
  75. enob=10):
  76. """
  77. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  78. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  79. :param fs: Sampling frequency of the simulation in Hz
  80. :param num_of_samples: Total number of samples in the input
  81. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  82. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  83. """
  84. super(DigitizationLayer, self).__init__()
  85. stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
  86. self.noise_layer = layers.GaussianNoise(stddev)
  87. freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
  88. temp = np.ones(freq.shape)
  89. for idx, val in np.ndenumerate(freq):
  90. if np.abs(val) > lpf_cutoff:
  91. temp[idx] = 0
  92. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  93. def call(self, inputs, **kwargs):
  94. complex_in = tf.cast(inputs, dtype=tf.complex64)
  95. val_f = tf.signal.fft(complex_in)
  96. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  97. filtered_t = tf.signal.ifft(filtered_f)
  98. real_t = tf.cast(filtered_t, dtype=tf.float32)
  99. noisy = self.noise_layer.call(real_t, training=True)
  100. return noisy
  101. class OpticalChannel(layers.Layer):
  102. def __init__(self,
  103. fs,
  104. num_of_samples,
  105. dispersion_factor,
  106. fiber_length,
  107. lpf_cutoff=32e9,
  108. rx_stddev=0.01,
  109. sig_avg=0.5,
  110. enob=10):
  111. """
  112. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  113. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  114. :param fs: Sampling frequency of the simulation in Hz
  115. :param num_of_samples: Total number of samples in the input
  116. :param dispersion_factor: Dispersion factor in s^2/km
  117. :param fiber_length: Length of fiber to model in km
  118. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  119. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  120. :param sig_avg: Average signal amplitude
  121. """
  122. super(OpticalChannel, self).__init__()
  123. self.rx_stddev = rx_stddev
  124. self.noise_layer = layers.GaussianNoise(self.rx_stddev)
  125. self.digitization_layer = DigitizationLayer(
  126. fs=fs,
  127. num_of_samples=num_of_samples,
  128. lpf_cutoff=lpf_cutoff,
  129. sig_avg=sig_avg,
  130. enob=enob
  131. )
  132. self.flatten_layer = layers.Flatten()
  133. self.fs = fs
  134. self.freq = tf.convert_to_tensor(
  135. np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
  136. self.multiplier = tf.math.exp(
  137. 0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))
  138. def call(self, inputs, **kwargs):
  139. # DAC LPF and noise
  140. dac_out = self.digitization_layer(inputs)
  141. # Chromatic Dispersion
  142. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  143. val_f = tf.signal.fft(complex_val)
  144. disp_f = tf.math.multiply(val_f, self.multiplier)
  145. disp_t = tf.signal.ifft(disp_f)
  146. # Squared-Law Detection
  147. pd_out = tf.square(tf.abs(disp_t))
  148. # Casting back to floatx
  149. real_val = tf.cast(pd_out, dtype=tf.float32)
  150. # Adding photo-diode receiver noise
  151. rx_signal = self.noise_layer.call(real_val, training=True)
  152. # ADC LPF and noise
  153. adc_out = self.digitization_layer(rx_signal)
  154. return adc_out