end_to_end.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # End to end Autoencoder completely implemented in Keras/TF
  2. import keras
  3. import math
  4. import tensorflow as tf
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from keras import layers
  8. class ExtractCentralMessage(keras.layers.Layer):
  9. def __init__(self, neighbouring_blocks, samples_per_symbol):
  10. super(ExtractCentralMessage, self).__init__()
  11. temp_w = np.zeros((neighbouring_blocks * samples_per_symbol, samples_per_symbol))
  12. i = np.identity(samples_per_symbol)
  13. begin = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
  14. end = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))
  15. temp_w[begin:end, :] = i
  16. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  17. def call(self, inputs):
  18. return tf.matmul(inputs, self.w)
  19. class AwgnChannel(keras.layers.Layer):
  20. def __init__(self, stddev=0.1):
  21. super(AwgnChannel, self).__init__()
  22. self.stddev = stddev
  23. self.noise_layer = layers.GaussianNoise(stddev)
  24. def call(self, inputs):
  25. serialized = layers.Flatten(inputs)
  26. return self.noise_layer.call(serialized, training=True)
  27. class OpticalChannel(keras.layers.Layer):
  28. def __init__(self, fs, stddev=0.1):
  29. super(OpticalChannel, self).__init__()
  30. self.noise_layer = layers.GaussianNoise(stddev)
  31. self.fs = fs
  32. def call(self, inputs):
  33. # TODO:
  34. # Low-pass filter & digitization noise for DAC and ADC (probably as an external layer called twice)
  35. # Serializing outputs of all blocks
  36. serialized = layers.Flatten(inputs)
  37. # Squared-Law Detection
  38. squared = tf.square(tf.abs(serialized))
  39. # Adding gaussian noise
  40. noisy = self.noise_layer.call(squared, training=True)
  41. return noisy
  42. class EndToEndAutoencoder(tf.keras.Model):
  43. def __init__(self,
  44. cardinality,
  45. samples_per_symbol,
  46. neighbouring_blocks,
  47. oversampling,
  48. channel_name='awgn'):
  49. # Number of leading/following messages
  50. if neighbouring_blocks % 2 == 0:
  51. neighbouring_blocks += 1
  52. # Oversampling rate
  53. oversampling = int(oversampling)
  54. # Transmission Channel : ['awgn', 'optical']
  55. channel_name = channel_name.strip().lower()
  56. self.encoder = tf.keras.Sequential([
  57. layers.Input(shape=(neighbouring_blocks, cardinality)),
  58. layers.Dense(2 * cardinality, activation='relu'),
  59. layers.Dense(2 * cardinality, activation='relu'),
  60. layers.Dense(samples_per_symbol),
  61. layers.ReLU(max_value=1.0)
  62. ])
  63. # Channel Model Layer
  64. if channel_name == 'optical':
  65. self.channel = OpticalChannel()
  66. elif channel_name == 'awgn':
  67. self.channel = AwgnChannel()
  68. else:
  69. raise TypeError("{} is not an accepted Channel Model.".format(channel_name))
  70. self.encoder = tf.keras.Sequential([
  71. ExtractCentralMessage(neighbouring_blocks, samples_per_symbol),
  72. layers.Dense(samples_per_symbol, activation='relu'),
  73. layers.Dense(2 * cardinality, activation='relu'),
  74. layers.Dense(2 * cardinality, activation='relu'),
  75. layers.Dense(cardinality, activation='softmax')
  76. ])
  77. def view_encoder(self):
  78. # TODO
  79. # Visualize encoder
  80. pass
  81. def call(self, x):
  82. tx = self.encoder(x)
  83. rx = self.channel(tx)
  84. y = self.decoder(rx)
  85. return y
  86. if __name__ == '__main__':
  87. # TODO
  88. # training/testing of autoencoder
  89. pass