custom_layers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from tensorflow.keras import layers
  2. import tensorflow as tf
  3. import math
  4. import numpy as np
  5. import itertools
  6. class BitsToSymbols(layers.Layer):
  7. def __init__(self, cardinality):
  8. super(BitsToSymbols, self).__init__()
  9. self.cardinality = cardinality
  10. n = int(math.log(self.cardinality, 2))
  11. self.pows = tf.convert_to_tensor(np.power(2, np.linspace(n-1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  12. def call(self, inputs, **kwargs):
  13. idx = tf.cast(tf.tensordot(inputs, self.pows, axes=1), dtype=tf.int32)
  14. out = tf.one_hot(idx, self.cardinality)
  15. return layers.Reshape((9, 32))(out)
  16. class SymbolsToBits(layers.Layer):
  17. def __init__(self, cardinality):
  18. super(SymbolsToBits, self).__init__()
  19. n = int(math.log(cardinality, 2))
  20. lst = [list(i) for i in itertools.product([0, 1], repeat=n)]
  21. # self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  22. self.all_syms = tf.convert_to_tensor(np.asarray(lst), dtype=tf.float32)
  23. def call(self, inputs, **kwargs):
  24. return tf.matmul(inputs, self.all_syms)
  25. class ExtractCentralMessage(layers.Layer):
  26. def __init__(self, messages_per_block, samples_per_symbol):
  27. """
  28. A keras layer that extracts the central message(symbol) in a block.
  29. :param messages_per_block: Total number of messages in transmission block
  30. :param samples_per_symbol: Number of samples per transmitted symbol
  31. """
  32. super(ExtractCentralMessage, self).__init__()
  33. temp_w = np.zeros((messages_per_block * samples_per_symbol, samples_per_symbol))
  34. i = np.identity(samples_per_symbol)
  35. begin = int(samples_per_symbol * ((messages_per_block - 1) / 2))
  36. end = int(samples_per_symbol * ((messages_per_block + 1) / 2))
  37. temp_w[begin:end, :] = i
  38. self.samples_per_symbol = samples_per_symbol
  39. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  40. def call(self, inputs, **kwargs):
  41. return tf.matmul(inputs, self.w)
  42. class AwgnChannel(layers.Layer):
  43. def __init__(self, rx_stddev=0.1):
  44. """
  45. A additive white gaussian noise channel model. The GaussianNoise class is utilized to prevent identical noise
  46. being applied every time the call function is called.
  47. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  48. """
  49. super(AwgnChannel, self).__init__()
  50. self.noise_layer = layers.GaussianNoise(rx_stddev)
  51. def call(self, inputs, **kwargs):
  52. return self.noise_layer.call(inputs, training=True)
  53. class DigitizationLayer(layers.Layer):
  54. def __init__(self,
  55. fs,
  56. num_of_samples,
  57. lpf_cutoff=32e9,
  58. sig_avg=0.5,
  59. enob=10):
  60. """
  61. This layer simulated the finite bandwidth of the hardware by means of a low pass filter. In addition to this,
  62. artefacts casued by quantization is modelled by the addition of white gaussian noise of a given stddev.
  63. :param fs: Sampling frequency of the simulation in Hz
  64. :param num_of_samples: Total number of samples in the input
  65. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  66. :param q_stddev: Standard deviation of quantization noise at ADC/DAC
  67. """
  68. super(DigitizationLayer, self).__init__()
  69. stddev = 3*(sig_avg**2)*(10**((-6.02*enob + 1.76)/10))
  70. self.noise_layer = layers.GaussianNoise(stddev)
  71. freq = np.fft.fftfreq(num_of_samples, d=1/fs)
  72. temp = np.ones(freq.shape)
  73. for idx, val in np.ndenumerate(freq):
  74. if np.abs(val) > lpf_cutoff:
  75. temp[idx] = 0
  76. self.lpf_multiplier = tf.convert_to_tensor(temp, dtype=tf.complex64)
  77. def call(self, inputs, **kwargs):
  78. complex_in = tf.cast(inputs, dtype=tf.complex64)
  79. val_f = tf.signal.fft(complex_in)
  80. filtered_f = tf.math.multiply(self.lpf_multiplier, val_f)
  81. filtered_t = tf.signal.ifft(filtered_f)
  82. real_t = tf.cast(filtered_t, dtype=tf.float32)
  83. noisy = self.noise_layer.call(real_t, training=True)
  84. return noisy
  85. class OpticalChannel(layers.Layer):
  86. def __init__(self,
  87. fs,
  88. num_of_samples,
  89. dispersion_factor,
  90. fiber_length,
  91. lpf_cutoff=32e9,
  92. rx_stddev=0.01,
  93. sig_avg=0.5,
  94. enob=10):
  95. """
  96. A channel model that simulates chromatic dispersion, non-linear photodiode detection, finite bandwidth of
  97. ADC/DAC as well as additive white gaussian noise in optical communication channels.
  98. :param fs: Sampling frequency of the simulation in Hz
  99. :param num_of_samples: Total number of samples in the input
  100. :param dispersion_factor: Dispersion factor in s^2/km
  101. :param fiber_length: Length of fiber to model in km
  102. :param lpf_cutoff: Cutoff frequency of LPF modelling finite bandwidth in ADC/DAC
  103. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  104. :param sig_avg: Average signal amplitude
  105. """
  106. super(OpticalChannel, self).__init__()
  107. self.rx_stddev = rx_stddev
  108. self.noise_layer = layers.GaussianNoise(self.rx_stddev)
  109. self.digitization_layer = DigitizationLayer(fs=fs,
  110. num_of_samples=num_of_samples,
  111. lpf_cutoff=lpf_cutoff,
  112. sig_avg=sig_avg,
  113. enob=enob)
  114. self.flatten_layer = layers.Flatten()
  115. self.fs = fs
  116. self.freq = tf.convert_to_tensor(np.fft.fftfreq(num_of_samples, d=1/fs), dtype=tf.complex64)
  117. self.multiplier = tf.math.exp(0.5j*dispersion_factor*fiber_length*tf.math.square(2*math.pi*self.freq))
  118. def call(self, inputs, **kwargs):
  119. # DAC LPF and noise
  120. dac_out = self.digitization_layer(inputs)
  121. # Chromatic Dispersion
  122. complex_val = tf.cast(dac_out, dtype=tf.complex64)
  123. val_f = tf.signal.fft(complex_val)
  124. disp_f = tf.math.multiply(val_f, self.multiplier)
  125. disp_t = tf.signal.ifft(disp_f)
  126. # Squared-Law Detection
  127. pd_out = tf.square(tf.abs(disp_t))
  128. # Casting back to floatx
  129. real_val = tf.cast(pd_out, dtype=tf.float32)
  130. # Adding photo-diode receiver noise
  131. rx_signal = self.noise_layer.call(real_val, training=True)
  132. # ADC LPF and noise
  133. adc_out = self.digitization_layer(rx_signal)
  134. return adc_out