import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses


class ExtractCentralMessage(layers.Layer):
    """Selects the samples of the central block from a serialized receive window."""

    def __init__(self, neighbouring_blocks, samples_per_symbol):
        super(ExtractCentralMessage, self).__init__()
        # Selection matrix: an identity block at the position of the central
        # symbol, zeros everywhere else.
        temp_w = np.zeros((neighbouring_blocks * samples_per_symbol, samples_per_symbol))
        i = np.identity(samples_per_symbol)
        begin = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
        end = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))
        temp_w[begin:end, :] = i
        self.w = tf.convert_to_tensor(temp_w, dtype=tf.float32)

    def call(self, inputs):
        return tf.matmul(inputs, self.w)


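# Sketch (assumption, not part of the original script): a quick check of the
# selection matrix built above. The helper name is hypothetical; call it
# manually to verify that only the central block's samples survive.
def _demo_extract_central_message():
    layer = ExtractCentralMessage(neighbouring_blocks=3, samples_per_symbol=2)
    window = tf.constant([[0., 1., 2., 3., 4., 5.]])  # one serialized window of 3 blocks
    print(layer(window).numpy())                      # expected output: [[2. 3.]]

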
class AwgnChannel(layers.Layer):
    """Serializes the encoded blocks and adds white Gaussian noise."""

    def __init__(self, stddev=0.1):
        super(AwgnChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)
        self.flatten_layer = layers.Flatten()

    def call(self, inputs):
        serialized = self.flatten_layer(inputs)
        # training=True so the noise is also applied outside of fit()
        return self.noise_layer(serialized, training=True)


class DigitizationLayer(layers.Layer):
    """Models a DAC/ADC stage: low-pass filtering (TODO) plus additive noise."""

    def __init__(self, stddev=0.1):
        super(DigitizationLayer, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)

    def call(self, inputs):
        # TODO:
        # Low-pass filter (convolution with filter h(t))
        return self.noise_layer(inputs, training=True)


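# Sketch (assumption): one possible realisation of the low-pass filter noted in
# the TODO above, as a fixed windowed-sinc FIR applied with tf.nn.conv1d. The
# helper name, tap count, and cutoff are illustrative, not values from the
# original code.
def _lowpass_fir(signal, num_taps=31, cutoff=0.25):
    """Filters a batch of 1-D signals of shape (batch, length); `cutoff` is
    normalised to the sampling rate (0.5 corresponds to Nyquist)."""
    n = np.arange(num_taps) - (num_taps - 1) / 2
    h = 2 * cutoff * np.sinc(2 * cutoff * n) * np.hamming(num_taps)  # windowed sinc
    h = h / np.sum(h)                                                # unity gain at DC
    kernel = tf.constant(h, dtype=tf.float32)[:, tf.newaxis, tf.newaxis]
    x = signal[:, :, tf.newaxis]                                     # conv1d expects (batch, width, channels)
    y = tf.nn.conv1d(x, kernel, stride=1, padding='SAME')
    return y[:, :, 0]

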
class OpticalChannel(layers.Layer):
    """Simplified IM/DD optical link: DAC, fibre dispersion (TODO), square-law
    photodetection with receiver noise, and ADC."""

    def __init__(self, fs, stddev=0.1):
        super(OpticalChannel, self).__init__()
        self.noise_layer = layers.GaussianNoise(stddev)
        self.digitization_layer = DigitizationLayer()
        self.flatten_layer = layers.Flatten()
        # Sampling frequency, needed once chromatic dispersion is implemented
        self.fs = fs

    def call(self, inputs):
        # Serialize the outputs of all blocks
        serialized = self.flatten_layer(inputs)
        # DAC: low-pass filtering and noise
        dac_out = self.digitization_layer(serialized)
        # TODO:
        # Chromatic dispersion (fft -> phase shift -> ifft)
        # Square-law detection at the photodiode
        pd_out = tf.square(tf.abs(dac_out))
        # Photodiode/receiver noise
        rx_signal = self.noise_layer(pd_out, training=True)
        # ADC: low-pass filtering and noise
        adc_out = self.digitization_layer(rx_signal)
        return adc_out


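# Sketch (assumption): one way to fill in the chromatic-dispersion TODO in
# OpticalChannel.call, applying the all-pass fibre response
# H(f) = exp(-j * 2 * pi^2 * beta2 * L * f^2) in the frequency domain. The helper
# name, the default beta2 (~ -21.7 ps^2/km, typical SSMF at 1550 nm), the fibre
# length, and the sign convention are illustrative, not taken from the original.
def _apply_chromatic_dispersion(signal, fs, fibre_length=50e3, beta2=-2.17e-26):
    """signal: real tensor of shape (batch, length); fs: sampling rate in Hz."""
    length = int(signal.shape[-1])
    f = np.fft.fftfreq(length, d=1.0 / fs)                    # frequency grid in Hz
    phase = -2.0 * np.pi ** 2 * beta2 * fibre_length * f ** 2
    h = tf.constant(np.exp(1j * phase).astype(np.complex64))  # all-pass dispersion response
    spectrum = tf.signal.fft(tf.cast(signal, tf.complex64))
    dispersed = tf.signal.ifft(spectrum * h)
    return tf.math.real(dispersed)

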
class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 neighbouring_blocks,
                 oversampling,
                 channel):
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in the paper: number of possible messages per block
        self.cardinality = cardinality
        # Labelled n in the paper: samples transmitted per block
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in the paper: forced to be odd so that a central block exists
        if neighbouring_blocks % 2 == 0:
            neighbouring_blocks += 1
        self.neighbouring_blocks = neighbouring_blocks
        # Oversampling rate
        self.oversampling = int(oversampling)
        # Channel model layer
        if isinstance(channel, layers.Layer):
            self.channel = channel
        else:
            raise TypeError("channel must be an instance of keras.layers.Layer!")
        # Encoding neural network: maps each block of one-hot messages to
        # samples_per_symbol transmit samples clipped to [0, 1]
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(self.neighbouring_blocks, self.cardinality)),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ])
        # Decoding neural network: extracts the central block from the received
        # window and outputs a probability over the M possible messages
        self.decoder = tf.keras.Sequential([
            ExtractCentralMessage(self.neighbouring_blocks, self.samples_per_symbol),
            layers.Dense(self.samples_per_symbol, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.cardinality, activation='softmax')
        ])

    def generate_random_inputs(self, num_of_blocks):
        # Draw random message indices and one-hot encode them, fixing the
        # category set to 0..M-1 so the encoding width is always `cardinality`
        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.neighbouring_blocks, 1))
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(categories=cat, handle_unknown='ignore', sparse_output=False)  # use sparse=False on scikit-learn < 1.2
        out = enc.fit_transform(rand_int)
        out_arr = np.reshape(out, (num_of_blocks, self.neighbouring_blocks, self.cardinality))
        # The label for each window is the one-hot message of its central block
        mid_idx = int((self.neighbouring_blocks - 1) / 2)
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, num_of_blocks=1e6, train_size=0.8):
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        self.compile(optimizer='adam',
                     loss=losses.BinaryCrossentropy())
        self.fit(x=X_train,
                 y=y_train,
                 epochs=1,
                 shuffle=True,
                 validation_data=(X_test, y_test))

    def view_encoder(self):
        # Place each of the M possible messages in the central block and return
        # the corresponding encoder output (the learned transmit waveform)
        messages = np.zeros((self.cardinality, self.neighbouring_blocks, self.cardinality))
        mid_idx = int((self.neighbouring_blocks - 1) / 2)
        for idx in range(self.cardinality):
            messages[idx, mid_idx, idx] = 1
        encoded = self.encoder(messages)
        return messages, encoded[:, mid_idx, :]

    def call(self, x):
        tx = self.encoder(x)    # one-hot message blocks -> transmit samples
        rx = self.channel(tx)   # channel model (e.g. AWGN or optical link)
        y = self.decoder(rx)    # received samples -> message probabilities
        return y


if __name__ == '__main__':
    tx_channel = AwgnChannel(stddev=0.1)
    model = EndToEndAutoencoder(cardinality=8,
                                samples_per_symbol=10,
                                neighbouring_blocks=5,
                                oversampling=4,
                                channel=tx_channel)
    model.train()
    messages, waveforms = model.view_encoder()
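    # Sketch (assumption, not part of the original script): visualise the learned
    # transmit waveform of each of the M messages, using the `waveforms` returned
    # above and the matplotlib import at the top of the file. Labels and figure
    # layout are illustrative choices.
    for m, waveform in enumerate(waveforms.numpy()):
        plt.plot(waveform, label=f'message {m}')
    plt.xlabel('sample index')
    plt.ylabel('encoder output amplitude')
    plt.title('Learned transmit waveforms')
    plt.legend()
    plt.show()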