layers.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. """
  2. Custom Keras Layers for general use
  3. """
  4. import itertools
  5. from tensorflow.keras import layers
  6. import tensorflow as tf
  7. import numpy as np
  8. class AwgnChannel(layers.Layer):
  9. def __init__(self, rx_stddev=0.1, noise_dB=None, **kwargs):
  10. """
  11. :param rx_stddev: Standard deviation of receiver noise (due to e.g. TIA circuit)
  12. """
  13. super(AwgnChannel, self).__init__(**kwargs)
  14. if noise_dB is not None:
  15. # rx_stddev = np.sqrt(1 / (20 ** (noise_dB / 10.0)))
  16. rx_stddev = 10 ** (noise_dB / 10.0)
  17. self.noise_layer = layers.GaussianNoise(rx_stddev)
  18. def call(self, inputs, **kwargs):
  19. return self.noise_layer.call(inputs, training=True)
  20. class ScaleAndOffset(layers.Layer):
  21. """
  22. Scales and offsets a tensor
  23. """
  24. def __init__(self, scale=1, offset=0, **kwargs):
  25. super(ScaleAndOffset, self).__init__(**kwargs)
  26. self.offset = offset
  27. self.scale = scale
  28. def call(self, inputs, **kwargs):
  29. return inputs * self.scale + self.offset
  30. class BitsToSymbol(layers.Layer):
  31. def __init__(self, cardinality, **kwargs):
  32. super().__init__(**kwargs)
  33. self.cardinality = cardinality
  34. n = int(np.log(self.cardinality, 2))
  35. self.powers = tf.convert_to_tensor(
  36. np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1),
  37. dtype=tf.float32
  38. )
  39. def call(self, inputs, **kwargs):
  40. idx = tf.cast(tf.tensordot(inputs, self.powers, axes=1), dtype=tf.int32)
  41. return tf.one_hot(idx, self.cardinality)
  42. class SymbolToBits(layers.Layer):
  43. def __init__(self, cardinality, **kwargs):
  44. super().__init__(**kwargs)
  45. n = int(np.log(cardinality, 2))
  46. l = [list(i) for i in itertools.product([0, 1], repeat=n)]
  47. self.all_syms = tf.transpose(tf.convert_to_tensor(np.asarray(l), dtype=tf.float32))
  48. def call(self, inputs, **kwargs):
  49. return tf.matmul(self.all_syms, inputs)