custom_layers.py

import itertools
import math

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class BitsToSymbols(layers.Layer):
    def __init__(self, cardinality, messages_per_block):
        """
        Maps groups of bits to one-hot symbol vectors.
        :param cardinality: Size of the symbol alphabet (must be a power of two)
        :param messages_per_block: Total number of messages in a transmission block
        """
        super(BitsToSymbols, self).__init__()
        self.cardinality = cardinality
        self.messages_per_block = messages_per_block
        n = int(math.log(self.cardinality, 2))
        # Powers of two used to interpret each group of n bits as an integer symbol index (MSB first).
        self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)

    def call(self, inputs, **kwargs):
        idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
        out = tf.one_hot(idx, self.cardinality)
        return layers.Reshape((self.messages_per_block, self.cardinality))(out)
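
# Illustrative usage sketch for BitsToSymbols, assuming a 4-ary alphabet (2 bits
# per symbol) and 3 messages per block, so each input row carries 3 groups of 2 bits:
#
#     bits = tf.constant([[[0., 1.], [1., 0.], [1., 1.]]])    # shape (1, 3, 2)
#     b2s = BitsToSymbols(cardinality=4, messages_per_block=3)
#     symbols = b2s(bits)                                      # shape (1, 3, 4)
#     # rows are one-hot vectors for the symbol indices 1, 2 and 3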


class SymbolsToBits(layers.Layer):
    def __init__(self, cardinality):
        """
        Maps symbol probability (or one-hot) vectors back to (soft) bit values.
        :param cardinality: Size of the symbol alphabet (must be a power of two)
        """
        super(SymbolsToBits, self).__init__()
        n = int(math.log(cardinality, 2))
        # Lookup table holding every possible bit pattern, one row per symbol index (MSB first).
        lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
        self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.all_syms)
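
# Illustrative usage sketch for SymbolsToBits: it inverts BitsToSymbols when fed
# exact one-hot vectors, and yields soft bit estimates when fed symbol
# probabilities (e.g. a softmax output). Values assume cardinality=4:
#
#     s2b = SymbolsToBits(cardinality=4)
#     probs = tf.constant([[[0.1, 0.7, 0.1, 0.1]]])            # shape (1, 1, 4)
#     soft_bits = s2b(probs)                                    # ~[[[0.2, 0.8]]]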


class ExtractCentralMessage(layers.Layer):
    def __init__(self, messages_per_block, samples_per_symbol):
        """
        A Keras layer that extracts the central message (symbol) in a block.
        :param messages_per_block: Total number of messages in transmission block
        :param samples_per_symbol: Number of samples per transmitted symbol
        """
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix: an identity block placed over the rows spanned by the central symbol.
        temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
        end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
        temp_w[begin:end, :] = i
        self.samples_per_symbol = samples_per_symbol
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w)
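
# Illustrative usage sketch for ExtractCentralMessage, assuming 3 messages of
# 4 samples each; only the 4 samples of the middle message survive:
#
#     block = tf.reshape(tf.range(12, dtype=tf.float32), (1, 12))
#     ecm = ExtractCentralMessage(messages_per_block=3, samples_per_symbol=4)
#     central = ecm(block)              # shape (1, 4), values [4., 5., 6., 7.]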


class AwgnChannel(layers.Layer):
    def __init__(self, rx_stddev=0.1):
        """
        An additive white Gaussian noise channel model. The GaussianNoise layer is used to prevent identical
        noise from being applied every time the call function is invoked.
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        """
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(rx_stddev)

    def call(self, inputs, **kwargs):
        # training=True forces the noise to be applied outside of training as well.
        return self.noise_layer.call(inputs, training=True)
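
# Illustrative usage sketch for AwgnChannel: noise is injected even outside of
# training, so repeated calls on the same input give different outputs.
#
#     channel = AwgnChannel(rx_stddev=0.1)
#     x = tf.zeros((1, 1000))
#     y = channel(x)                    # tf.math.reduce_std(y) should be close to 0.1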


class DigitizationLayer(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 lpf_cutoff=32e9,
                 sig_avg=0.5,
                 enob=10):
        """
        This layer simulates the finite bandwidth of the hardware by means of a low pass filter. In addition,
        artefacts caused by quantization are modelled by the addition of white Gaussian noise whose level is
        derived from the effective number of bits.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param sig_avg: Average signal amplitude, used to scale the quantization noise
        :param enob: Effective number of bits of the ADC/DAC
        """
        super(DigitizationLayer, self).__init__()
        # Quantization noise level derived from the ENOB of the converter.
        stddev = 3 * (sig_avg ** 2) * (10 ** ((-6.02 * enob + 1.76) / 10))
        self.noise_layer = layers.GaussianNoise(stddev)
        # Ideal brick-wall low pass filter defined in the frequency domain.
        freq = np.fft.fftfreq(num_of_samples, d=1 / fs)
        temp = np.ones(freq.shape)
        for idx, val in np.ndenumerate(freq):
            if np.abs(val) > lpf_cutoff:
                temp[idx] = 0
        self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)

    def call(self, inputs, **kwargs):
        # Filter in the frequency domain, then add quantization noise in the time domain.
        complex_in = tf.cast(inputs, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_in)
        filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
        filtered_t = tf.signal.ifft(filtered_f)
        real_t = tf.cast(filtered_t, dtype=tf.float32)
        noisy = self.noise_layer.call(real_t, training=True)
        return noisy
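
# Illustrative usage sketch for DigitizationLayer, assuming a 100 GSa/s simulation
# and the default 32 GHz brick-wall cutoff:
#
#     dig = DigitizationLayer(fs=100e9, num_of_samples=256)
#     x = tf.random.normal((1, 256))
#     y = dig(x)                        # band-limited copy of x with quantization noise added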


class OpticalChannel(layers.Layer):
    def __init__(self,
                 fs,
                 num_of_samples,
                 dispersion_factor,
                 fiber_length,
                 lpf_cutoff=32e9,
                 rx_stddev=0.01,
                 sig_avg=0.5,
                 enob=10):
        """
        A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth
        of the ADC/DAC, as well as additive white Gaussian noise in optical communication channels.
        :param fs: Sampling frequency of the simulation in Hz
        :param num_of_samples: Total number of samples in the input
        :param dispersion_factor: Dispersion factor in s^2/km
        :param fiber_length: Length of fiber to model in km
        :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
        :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
        :param sig_avg: Average signal amplitude
        :param enob: Effective number of bits of the ADC/DAC
        """
        super(OpticalChannel, self).__init__()
        self.fs = fs
        self.num_of_samples = num_of_samples
        self.dispersion_factor = dispersion_factor
        self.fiber_length = fiber_length
        self.lpf_cutoff = lpf_cutoff
        self.rx_stddev = rx_stddev
        self.sig_avg = sig_avg
        self.enob = enob
        self.noise_layer = layers.GaussianNoise(self.rx_stddev)
        self.digitization_layer = DigitizationLayer(fs=self.fs,
                                                    num_of_samples=self.num_of_samples,
                                                    lpf_cutoff=self.lpf_cutoff,
                                                    sig_avg=self.sig_avg,
                                                    enob=self.enob)
        self.flatten_layer = layers.Flatten()
        # Frequency-domain chromatic dispersion operator: exp(0.5j * beta2 * L * (2 * pi * f)^2).
        self.freq = tf.convert_to_tensor(np.fft.fftfreq(self.num_of_samples, d=1 / fs), dtype=tf.complex64)
        self.multiplier = tf.math.exp(0.5j * self.dispersion_factor * self.fiber_length
                                      * tf.math.square(2 * math.pi * self.freq))

    def call(self, inputs, **kwargs):
        # DAC LPF and quantization noise
        dac_out = self.digitization_layer(inputs)
        # Chromatic dispersion applied in the frequency domain
        complex_val = tf.cast(dac_out, dtype=tf.complex64)
        val_f = tf.signal.fft(complex_val)
        disp_f = tf.math.multiply(val_f, self.multiplier)
        disp_t = tf.signal.ifft(disp_f)
        # Square-law photodiode detection
        pd_out = tf.square(tf.abs(disp_t))
        # Casting back to floatx
        real_val = tf.cast(pd_out, dtype=tf.float32)
        # Adding photodiode receiver noise
        rx_signal = self.noise_layer.call(real_val, training=True)
        # ADC LPF and quantization noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
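
# Illustrative usage sketch for OpticalChannel, pushing a random waveform through
# 50 km of fiber. The dispersion factor below is an assumed illustrative value
# (roughly -2.17e-23 s^2/km, i.e. about 17 ps/nm/km at 1550 nm), not a value
# prescribed by this module:
#
#     channel = OpticalChannel(fs=100e9,
#                              num_of_samples=256,
#                              dispersion_factor=-2.17e-23,
#                              fiber_length=50)
#     tx = tf.random.uniform((1, 256), minval=0.0, maxval=1.0)
#     rx = channel(tx)                  # dispersed, square-law-detected, noisy waveform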