layers.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. """
  2. Custom Keras Layers for general use
  3. """
  4. import itertools
  5. from tensorflow.keras import layers
  6. import tensorflow as tf
  7. import numpy as np
  8. class AwgnChannel(layers.Layer):
  9. def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
  10. """
  11. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  12. being applied every time the call function is called.
  13. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  14. """
  15. super(AwgnChannel, self).__init__(**kwargs)
  16. if noise_dB is not None:
  17. # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
  18. rx_stddev = 10 ** (noise_dB / 10.0)
  19. self.noise_layer = layers.GaussianNoise(rx_stddev)
  20. def call(self, inputs, **kwargs):
  21. return self.noise_layer.call(inputs, training=True)
  22. class BitsToSymbols(layers.Layer):
  23. def __init__(self, cardinality):
  24. super(BitsToSymbols, self).__init__()
  25. self.cardinality = cardinality
  26. n = int(tf.math.log(self.cardinality, 2))
  27. self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  28. def call(self, inputs, **kwargs):
  29. idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
  30. out = tf.one_hot(idx, self.cardinality)
  31. return layers.Reshape((9, 32))(out)
  32. class SymbolsToBits(layers.Layer):
  33. def __init__(self, cardinality):
  34. super(SymbolsToBits, self).__init__()
  35. n = int(tf.math.log(cardinality, 2))
  36. lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
  37. # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  38. self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  39. def call(self, inputs, **kwargs):
  40. return tf.matmul(inputs, self.all_syms)
  41. class ScaleAndOffset(layers.Layer):
  42. """
  43. Scales and offsets a tensor
  44. """
  45. def __init__(self, scale=1, offset=0, **kwargs):
  46. super(ScaleAndOffset, self).__init__(**kwargs)
  47. self.offset = offset
  48. self.scale = scale
  49. def call(self, inputs, **kwargs):
  50. return inputs * self.scale + self.offset
  51. class BitsToSymbol(layers.Layer):
  52. def __init__(self, cardinality, **kwargs):
  53. super().__init__(**kwargs)
  54. self.cardinality = cardinality
  55. n = int(np.log(self.cardinality, 2))
  56. self.powers = tf.convert_to_tensor(
  57. np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
  58. dtype=tf.float32
  59. )
  60. def call(self, inputs, **kwargs):
  61. idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
  62. return tf.one_hot(idx, self.cardinality)
  63. class SymbolToBits(layers.Layer):
  64. def __init__(self, cardinality, **kwargs):
  65. super().__init__(**kwargs)
  66. n = int(np.log(cardinality, 2))
  67. l = [list(i) for i in itertools.product([0, 1], repeat=n)]
  68. self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32))
  69. def call(self, inputs, **kwargs):
  70. return tf.matmul(self.all_syms, inputs)
  71. class ExtractCentralMessage(layers.Layer):
  72. def __init__(self, messages_per_block, samples_per_symbol):
  73. """
  74. A keras layer that extracts the central message(symbol) in a block.
  75. :param messages_per_block: Total number of messages in transmission block
  76. :param samples_per_symbol: Number of samples per transmitted symbol
  77. """
  78. super(ExtractCentralMessage, self).__init__()
  79. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  80. i = np.identity(samples_per_symbol)
  81. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  82. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  83. temp_w[begin:end, :] = i
  84. self.samples_per_symbol = samples_per_symbol
  85. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  86. def call(self, inputs, **kwargs):
  87. return tf.matmul(inputs, self.w)
  88. class DigitizationLayer(layers.Layer):
  89. def __init__(self,
  90. fs,
  91. num_of_samples,
  92. lpf_cutoff=32e9,
  93. sig_avg=0.5,
  94. enob=10):
  95. """
  96. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  97. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  98. :param fs: Sampling frequency of the simulation in Hz
  99. :param num_of_samples: Total number of samples in the input
  100. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  101. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  102. """
  103. super(DigitizationLayer, self).__init__()
  104. stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
  105. self.noise_layer = layers.GaussianNoise(stddev)
  106. freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
  107. temp = np.ones(freq.shape)
  108. for idx, val in np.ndenumerate(freq):
  109. if np.abs(val) > lpf_cutoff:
  110. temp[idx] = 0
  111. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  112. def call(self, inputs, **kwargs):
  113. complex_in = tf.cast(inputs, dtype=tf.complex64)
  114. val_f = tf.signal.fft(complex_in)
  115. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  116. filtered_t = tf.signal.ifft(filtered_f)
  117. real_t = tf.cast(filtered_t, dtype=tf.float32)
  118. noisy = self.noise_layer.call(real_t, training=True)
  119. return noisy
  120. class OpticalChannel(layers.Layer):
  121. def __init__(self,
  122. fs,
  123. num_of_samples,
  124. dispersion_factor,
  125. fiber_length,
  126. lpf_cutoff=32e9,
  127. rx_stddev=0.01,
  128. sig_avg=0.5,
  129. enob=10):
  130. """
  131. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  132. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  133. :param fs: Sampling frequency of the simulation in Hz
  134. :param num_of_samples: Total number of samples in the input
  135. :param dispersion_factor: Dispersion factor in s^2/km
  136. :param fiber_length: Length of fiber to model in km
  137. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  138. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  139. :param sig_avg: Average signal amplitude
  140. """
  141. super(OpticalChannel, self).__init__()
  142. self.rx_stddev = rx_stddev
  143. self.noise_layer = layers.GaussianNoise(self.rx_stddev)
  144. self.digitization_layer = DigitizationLayer(
  145. fs=fs,
  146. num_of_samples=num_of_samples,
  147. lpf_cutoff=lpf_cutoff,
  148. sig_avg=sig_avg,
  149. enob=enob
  150. )
  151. self.flatten_layer = layers.Flatten()
  152. self.fs = fs
  153. self.freq = tf.convert_to_tensor(
  154. np.fft.fftfreq(num_of_samples, d=1 / fs), dtype=tf.complex64)
  155. self.multiplier = tf.math.exp(
  156. 0.5j * dispersion_factor * fiber_length * tf.math.square(2 * np.pi * self.freq))
  157. def call(self, inputs, **kwargs):
  158. # DAC LPF and noise
  159. dac_out = self.digitization_layer(inputs)
  160. # Chromatic Dispersion
  161. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  162. val_f = tf.signal.fft(complex_val)
  163. disp_f = tf.math.multiply(val_f, self.multiplier)
  164. disp_t = tf.signal.ifft(disp_f)
  165. # Squared-Law Detection
  166. pd_out = tf.square(tf.abs(disp_t))
  167. # Casting back to floatx
  168. real_val = tf.cast(pd_out, dtype=tf.float32)
  169. # Adding photo-diode receiver noise
  170. rx_signal = self.noise_layer.call(real_val, training=True)
  171. # ADC LPF and noise
  172. adc_out = self.digitization_layer(rx_signal)
  173. return adc_out