import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    """Selects the samples belonging to the central block from the serialized received signal."""

    def __init__(self, neighbouring_blocks, samples_per_symbol):
        super(ExtractCentralMessage, self).__init__()
        # Constant selection matrix: an identity block at the position of the
        # central symbol, zeros everywhere else.
        temp_w = np.zeros((neighbouring_blocks * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
        end = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs):
        return tf.matmul(inputs, self.w)


class AwgnChannel(layers.Layer):
    """Serializes the transmitted blocks and adds white Gaussian noise."""

    def __init__(self, stddev=0.1):
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)
        self.flatten_layer = layers.Flatten()

    def call(self, inputs):
        serialized = self.flatten_layer(inputs)
        # training=True so the noise is also applied outside of training.
        return self.noise_layer(serialized, training=True)


class DigitizationLayer(layers.Layer):
    """Models a DAC/ADC stage: low-pass filtering (TODO) followed by additive noise."""

    def __init__(self, stddev=0.1):
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)

    def call(self, inputs):
        # TODO: low-pass filter (convolution with filter h(t))
        return self.noise_layer(inputs, training=True)


class OpticalChannel(layers.Layer):
    """IM/DD optical channel: DAC, fibre, square-law photodiode, ADC."""

    def __init__(self, fs, stddev=0.1):
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)
        self.digitization_layer = DigitizationLayer()
        self.flatten_layer = layers.Flatten()
        self.fs = fs

    def call(self, inputs):
        # Serialize the outputs of all blocks
        serialized = self.flatten_layer(inputs)
        # DAC: LPF and noise
        dac_out = self.digitization_layer(serialized)
        # TODO: chromatic dispersion (fft -> phase shift -> ifft);
        # a hedged sketch of this step follows this class.
        # Square-law detection
        pd_out = tf.square(tf.abs(dac_out))
        # Add photodiode receiver noise
        rx_signal = self.noise_layer(pd_out, training=True)
        # ADC: LPF and noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out
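

# The chromatic-dispersion step is left as a TODO in OpticalChannel.call above.
# The helper below is a hedged, standalone sketch of that fft -> phase shift ->
# ifft idea, not the implementation from the referenced paper: the fibre length
# and beta2 defaults are illustrative assumptions only.
def apply_chromatic_dispersion(signal, fs, fibre_length_km=20.0, beta2_ps2_per_km=-21.7):
    """Sketch: apply an all-pass quadratic phase shift in the frequency domain."""
    n = int(signal.shape[-1])
    # Baseband frequency grid (Hz) for an n-point FFT at sampling rate fs.
    omega = 2.0 * np.pi * np.fft.fftfreq(n, d=1.0 / fs)
    beta2 = beta2_ps2_per_km * 1e-27            # ps^2/km -> s^2/m
    length = fibre_length_km * 1e3              # km -> m
    # Sign convention depends on the Fourier-transform definition used.
    h = np.exp(1j * 0.5 * beta2 * length * omega ** 2).astype(np.complex64)
    spectrum = tf.signal.fft(tf.cast(signal, tf.complex64))
    # Output is complex; the square-law detector takes |.|^2 afterwards.
    return tf.signal.ifft(spectrum * tf.constant(h))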


class EndToEndAutoencoder(tf.keras.Model):

    def __init__(self, cardinality, samples_per_symbol, neighbouring_blocks, oversampling, channel):
        super(EndToEndAutoencoder, self).__init__()

        # Labelled M in the paper
        self.cardinality = cardinality
        # Labelled n in the paper
        self.samples_per_symbol = samples_per_symbol

        # Labelled N in the paper; forced to be odd so that a central block exists
        if neighbouring_blocks % 2 == 0:
            neighbouring_blocks += 1
        self.neighbouring_blocks = neighbouring_blocks

        # Oversampling rate
        self.oversampling = int(oversampling)

        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = channel
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")

        # Encoding neural network
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.neighbouring_blocks, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ])

        # Decoding neural network
        self.decoder = tf.keras.Sequential([
            ExtractCentralMessage(self.neighbouring_blocks, self.samples_per_symbol),
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks):
        """Draws random messages and returns them one-hot encoded, together with
        the one-hot label of the central block of each sequence."""
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.neighbouring_blocks, 1))
        # Pin the category range so the encoding width is always `cardinality`,
        # even if some message value happens not to be drawn.
        # (`sparse_output` replaces the removed `sparse` argument in recent scikit-learn.)
        enc = OneHotEncoder(categories=[np.arange(self.cardinality)], sparse_output=False)
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.neighbouring_blocks, self.cardinality))
        mid_idx = int((self.neighbouring_blocks - 1) / 2)
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, train_size=0.8):
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))

        self.compile(optimizer='adam', loss=losses.BinaryCrossentropy())
        self.fit(x=X_train, y=y_train,
                 batch_size=None,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test))

    def view_encoder(self):
        """Feeds each of the M possible central messages through the encoder and
        returns the waveform learned for the central block."""
        messages = np.zeros((self.cardinality, self.neighbouring_blocks, self.cardinality))
        mid_idx = int((self.neighbouring_blocks - 1) / 2)
        for idx, msg in enumerate(messages):
            msg[mid_idx, idx] = 1
        encoded = self.encoder(messages)
        return messages, encoded[:, mid_idx, :]

    def call(self, x):
        tx = self.encoder(x)
        rx = self.channel(tx)
        y = self.decoder(rx)
        return y


if __name__ == '__main__':
    tx_channel = AwgnChannel(stddev=0.1)
    model = EndToEndAutoencoder(cardinality=8,
                                samples_per_symbol=10,
                                neighbouring_blocks=5,
                                oversampling=4,
                                channel=tx_channel)
    model.train()
    _, encoded = model.view_encoder()
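    # Hedged addition (not in the original script): visualise the learned
    # transmit waveforms with the already-imported matplotlib. Assumes
    # view_encoder() returns one row of `samples_per_symbol` samples per message.
    for m in range(int(encoded.shape[0])):
        plt.plot(np.asarray(encoded[m]), label=f'message {m}')
    plt.xlabel('sample index within the central block')
    plt.ylabel('encoder output amplitude')
    plt.legend()
    plt.show()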