end_to_end.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import keras
  2. import tensorflow as tf
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.preprocessing import OneHotEncoder
  6. from keras import layers, losses
  7. class ExtractCentralMessage(layers.Layer):
  8. def __init__(self, neighbouring_blocks, samples_per_symbol):
  9. super(ExtractCentralMessage, self).__init__()
  10. temp_w = np.zeros((neighbouring_blocks * samples_per_symbol, samples_per_symbol))
  11. i = np.identity(samples_per_symbol)
  12. begin = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
  13. end = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))
  14. temp_w[begin:end, :] = i
  15. self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)
  16. def call(self, inputs):
  17. return tf.matmul(inputs, self.w)
  18. class AwgnChannel(layers.Layer):
  19. def __init__(self, stddev=0.1):
  20. super(AwgnChannel, self).__init__()
  21. self.noise_layer = layers.GaussianNoise(stddev)
  22. self.flatten_layer = layers.Flatten()
  23. def call(self, inputs):
  24. serialized = self.flatten_layer(inputs)
  25. return self.noise_layer.call(serialized, training=True)
  26. class DigitizationLayer(layers.Layer):
  27. def __init__(self, stddev=0.1):
  28. super(DigitizationLayer, self).__init__()
  29. self.noise_layer = layers.GaussianNoise(stddev)
  30. def call(self, inputs):
  31. # TODO:
  32. # Low-pass filter (convolution with filter h(t))
  33. return self.noise_layer.call(inputs, training=True)
  34. class OpticalChannel(layers.Layer):
  35. def __init__(self, fs, stddev=0.1):
  36. super(OpticalChannel, self).__init__()
  37. self.noise_layer = layers.GaussianNoise(stddev)
  38. self.digitization_layer = DigitizationLayer()
  39. self.flatten_layer = layers.Flatten()
  40. self.fs = fs
  41. def call(self, inputs):
  42. # Serializing outputs of all blocks
  43. serialized = self.flatten_layer(inputs)
  44. # DAC LPF and noise
  45. dac_out = self.digitization_layer(serialized)
  46. # TODO:
  47. # Chromatic Dispersion (fft -> phase shift -> ifft)
  48. # Squared-Law Detection
  49. pd_out = tf.square(tf.abs(dac_out))
  50. # Adding photo-diode receiver noise
  51. rx_signal = self.noise_layer.call(pd_out, training=True)
  52. # ADC LPF and noise
  53. adc_out = self.digitization_layer(rx_signal)
  54. return adc_out
  55. class EndToEndAutoencoder(tf.keras.Model):
  56. def __init__(self,
  57. cardinality,
  58. samples_per_symbol,
  59. neighbouring_blocks,
  60. oversampling,
  61. channel):
  62. super(EndToEndAutoencoder, self).__init__()
  63. # Labelled M in paper
  64. self.cardinality = cardinality
  65. # Labelled n in paper
  66. self.samples_per_symbol = samples_per_symbol
  67. # Labelled N in paper
  68. if neighbouring_blocks % 2 == 0:
  69. neighbouring_blocks += 1
  70. self.neighbouring_blocks = neighbouring_blocks
  71. # Oversampling rate
  72. self.oversampling = int(oversampling)
  73. # Channel Model Layer
  74. if isinstance(channel, layers.Layer):
  75. self.channel = channel
  76. else:
  77. raise TypeError("Channel must be a subclass of keras.layers.layer!")
  78. # Encoding Neural Network
  79. self.encoder = tf.keras.Sequential([
  80. layers.Input(shape=(self.neighbouring_blocks, self.cardinality)),
  81. layers.Dense(2 * self.cardinality, activation='relu'),
  82. layers.Dense(2 * self.cardinality, activation='relu'),
  83. layers.Dense(self.samples_per_symbol),
  84. layers.ReLU(max_value=1.0)
  85. ])
  86. # Decoding Neural Network
  87. self.decoder = tf.keras.Sequential([
  88. ExtractCentralMessage(self.neighbouring_blocks, self.samples_per_symbol),
  89. layers.Dense(self.samples_per_symbol, activation='relu'),
  90. layers.Dense(2 * self.cardinality, activation='relu'),
  91. layers.Dense(2 * self.cardinality, activation='relu'),
  92. layers.Dense(self.cardinality, activation='softmax')
  93. ])
  94. def generate_random_inputs(self, num_of_blocks):
  95. rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.neighbouring_blocks, 1))
  96. # cat = np.reshape(np.arange(self.cardinality), (1, -1))
  97. enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
  98. out = enc.fit_transform(rand_int)
  99. out_arr = np.reshape(out, (num_of_blocks, self.neighbouring_blocks, self.cardinality))
  100. mid_idx = int((self.neighbouring_blocks-1)/2)
  101. return out_arr, out_arr[:, mid_idx, :]
  102. def train(self, num_of_blocks=1e6, train_size=0.8):
  103. X_train, y_train = self.generate_random_inputs(int(num_of_blocks*train_size))
  104. X_test, y_test = self.generate_random_inputs(int(num_of_blocks*(1-train_size)))
  105. self.compile(optimizer='adam',
  106. loss=losses.BinaryCrossentropy(),
  107. metrics=None,
  108. loss_weights=None,
  109. weighted_metrics=None,
  110. run_eagerly=None
  111. )
  112. self.fit(x=X_train,
  113. y=y_train,
  114. batch_size=None,
  115. epochs=1,
  116. shuffle=True,
  117. validation_data=(X_test, y_test)
  118. )
  119. def view_encoder(self):
  120. messages = np.zeros((self.cardinality, self.neighbouring_blocks, self.cardinality))
  121. mid_idx = int((self.neighbouring_blocks-1)/2)
  122. idx = 0
  123. for msg in messages:
  124. msg[mid_idx, idx] = 1
  125. idx += 1
  126. encoded = self.encoder(messages)
  127. return messages, encoded[:, mid_idx, :]
  128. def call(self, x):
  129. tx = self.encoder(x)
  130. rx = self.channel(tx)
  131. y = self.decoder(rx)
  132. return y
  133. if __name__ == '__main__':
  134. tx_channel = AwgnChannel(stddev=0.1)
  135. model = EndToEndAutoencoder(cardinality=8,
  136. samples_per_symbol=10,
  137. neighbouring_blocks=5,
  138. oversampling=4,
  139. channel=tx_channel)
  140. model.train()
  141. model.view_encoder()
  142. pass