layers.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. """
  2. Custom Keras Layers for general use
  3. """
  4. import itertools
  5. from tensorflow.keras import layers
  6. import tensorflow as tf
  7. import numpy as np
  8. import math
  9. class AwgnChannel(layers.Layer):
  10. def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
  11. """
  12. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  13. being applied every time the call function is called.
  14. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  15. """
  16. super(AwgnChannel, self).__init__(**kwargs)
  17. if noise_dB is not None:
  18. # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
  19. rx_stddev = 10 ** (noise_dB / 10.0)
  20. self.noise_layer = layers.GaussianNoise(rx_stddev)
  21. def call(self, inputs, **kwargs):
  22. return self.noise_layer.call(inputs, training=True)
  23. class BitsToSymbols(layers.Layer):
  24. def __init__(self, cardinality, messages_per_block):
  25. super(BitsToSymbols, self).__init__()
  26. self.cardinality = cardinality
  27. self.messages_per_block = messages_per_block
  28. n = int(math.log(self.cardinality, 2))
  29. self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  30. def call(self, inputs, **kwargs):
  31. idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
  32. out = tf.one_hot(idx, self.cardinality)
  33. return layers.Reshape((self.messages_per_block, self.cardinality))(out)
  34. class SymbolsToBits(layers.Layer):
  35. def __init__(self, cardinality):
  36. super(SymbolsToBits, self).__init__()
  37. n = int(math.log(cardinality, 2))
  38. lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
  39. # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  40. self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  41. def call(self, inputs, **kwargs):
  42. return tf.matmul(inputs, self.all_syms)
  43. class ScaleAndOffset(layers.Layer):
  44. """
  45. Scales and offsets a tensor
  46. """
  47. def __init__(self, scale=1, offset=0, **kwargs):
  48. super(ScaleAndOffset, self).__init__(**kwargs)
  49. self.offset = offset
  50. self.scale = scale
  51. def call(self, inputs, **kwargs):
  52. return inputs * self.scale + self.offset
  53. class ExtractCentralMessage(layers.Layer):
  54. def __init__(self, messages_per_block, samples_per_symbol):
  55. """
  56. A keras layer that extracts the central message(symbol) in a block.
  57. :param messages_per_block: Total number of messages in transmission block
  58. :param samples_per_symbol: Number of samples per transmitted symbol
  59. """
  60. super(ExtractCentralMessage, self).__init__()
  61. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  62. i = np.identity(samples_per_symbol)
  63. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  64. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  65. temp_w[begin:end, :] = i
  66. self.samples_per_symbol = samples_per_symbol
  67. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  68. def call(self, inputs, **kwargs):
  69. return tf.matmul(inputs, self.w)
  70. class DigitizationLayer(layers.Layer):
  71. def __init__(self,
  72. fs,
  73. num_of_samples,
  74. lpf_cutoff=32e9,
  75. sig_avg=0.5,
  76. enob=10):
  77. """
  78. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  79. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  80. :param fs: Sampling frequency of the simulation in Hz
  81. :param num_of_samples: Total number of samples in the input
  82. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  83. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  84. """
  85. super(DigitizationLayer, self).__init__()
  86. stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
  87. self.noise_layer = layers.GaussianNoise(stddev)
  88. freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
  89. temp = np.ones(freq.shape)
  90. for idx, val in np.ndenumerate(freq):
  91. if np.abs(val) > lpf_cutoff:
  92. temp[idx] = 0
  93. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  94. def call(self, inputs, **kwargs):
  95. complex_in = tf.cast(inputs, dtype=tf.complex64)
  96. val_f = tf.signal.fft(complex_in)
  97. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  98. filtered_t = tf.signal.ifft(filtered_f)
  99. real_t = tf.cast(filtered_t, dtype=tf.float32)
  100. noisy = self.noise_layer.call(real_t, training=True)
  101. return noisy
  102. class OpticalChannel(layers.Layer):
  103. def __init__(self,
  104. fs,
  105. num_of_samples,
  106. dispersion_factor,
  107. fiber_length,
  108. fiber_length_stddev=0,
  109. lpf_cutoff=32e9,
  110. rx_stddev=0.01,
  111. sig_avg=0.5,
  112. enob=10):
  113. """
  114. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  115. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  116. :param fs: Sampling frequency of the simulation in Hz
  117. :param num_of_samples: Total number of samples in the input
  118. :param dispersion_factor: Dispersion factor in s^2/km
  119. :param fiber_length: Length of fiber to model in km
  120. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  121. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  122. :param sig_avg: Average signal amplitude
  123. """
  124. super(OpticalChannel, self).__init__()
  125. self.fs = fs
  126. self.num_of_samples = num_of_samples
  127. self.dispersion_factor = dispersion_factor
  128. self.fiber_length = tf.cast(fiber_length, dtype=tf.float32)
  129. self.fiber_length_stddev = tf.cast(fiber_length_stddev, dtype=tf.float32)
  130. self.lpf_cutoff = lpf_cutoff
  131. self.rx_stddev = rx_stddev
  132. self.sig_avg = sig_avg
  133. self.enob = enob
  134. self.noise_layer = layers.GaussianNoise(self.rx_stddev)
  135. self.fiber_length_noise = layers.GaussianNoise(self.fiber_length_stddev)
  136. self.digitization_layer = DigitizationLayer(fs=self.fs,
  137. num_of_samples=self.num_of_samples,
  138. lpf_cutoff=self.lpf_cutoff,
  139. sig_avg=self.sig_avg,
  140. enob=self.enob)
  141. self.flatten_layer = layers.Flatten()
  142. self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1/fs), dtype=tf.complex64)
  143. # self.multiplier = tf.math.exp(0.5j*self.dispersion_factor*self.fiber_length*tf.math.square(2*math.pi*self.freq))
  144. def call(self, inputs, **kwargs):
  145. # DAC LPF and noise
  146. dac_out = self.digitization_layer(inputs)
  147. # Chromatic Dispersion
  148. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  149. val_f = tf.signal.fft(complex_val)
  150. len = tf.cast(self.fiber_length_noise.call(self.fiber_length), dtype=tf.complex64)
  151. multiplier = tf.math.exp(0.5j*self.dispersion_factor*len*tf.math.square(2*math.pi*self.freq))
  152. disp_f = tf.math.multiply(val_f, multiplier)
  153. disp_t = tf.signal.ifft(disp_f)
  154. # Squared-Law Detection
  155. pd_out = tf.square(tf.abs(disp_t))
  156. # Casting back to floatx
  157. real_val = tf.cast(pd_out, dtype=tf.float32)
  158. # Adding photo-diode receiver noise
  159. rx_signal = self.noise_layer.call(real_val, training=True)
  160. # ADC LPF and noise
  161. adc_out = self.digitization_layer(rx_signal)
  162. return adc_out