@@ -3,12 +3,14 @@ import math
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses
from tensorflow.keras import backend as K
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
import itertools

+
class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
@@ -53,27 +55,29 @@ class EndToEndAutoencoder(tf.keras.Model):
        # Boolean identifying if bit mapping is to be learnt
        self.bit_mapping = bit_mapping

+        # Model Hyper-parameters
+        leaky_relu_alpha = 0
+        relu_clip_val = 1.0
+
        # Layer configuration for the case when bit mapping is to be learnt
        if self.bit_mapping:
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
                BitsToSymbols(self.cardinality),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                # layers.Dense(2 * self.cardinality),
                # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.Dense(self.bits_per_symbol, activation='sigmoid')
            ]

        # layer configuration for the case when only symbol mapping is to be learnt
@@ -81,21 +85,19 @@ class EndToEndAutoencoder(tf.keras.Model):
            encoding_layers = [
                layers.Input(shape=(self.messages_per_block, self.cardinality)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=0.01)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=1.0))
+                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
            ]
            decoding_layers = [
                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.cardinality),
-                layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid'),
+                layers.LeakyReLU(alpha=leaky_relu_alpha),
+                layers.Dense(self.cardinality, activation='softmax')
            ]

        # Encoding Neural Network
@@ -130,7 +132,7 @@ class EndToEndAutoencoder(tf.keras.Model):
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))

            if return_vals:
-                #TODO
+                return out_arr, out_arr, out_arr[:, mid_idx, :]

        else:
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
@@ -145,7 +147,7 @@ class EndToEndAutoencoder(tf.keras.Model):

            return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-2):
+    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.

@@ -159,8 +161,23 @@

        opt = tf.keras.optimizers.Adam(learning_rate=lr)

+        # TODO: Investigate different optimizers (with different learning rates and other parameters)
+        # SGD
+        # RMSprop
+        # Adam
+        # Adadelta
+        # Adagrad
+        # Adamax
+        # Nadam
+        # Ftrl
+
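+        # Choose the loss to match the label format: per-bit sigmoid outputs vs. one-hot symbol outputs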
+        if self.bit_mapping:
+            loss_fn = losses.BinaryCrossentropy()
+        else:
+            loss_fn = losses.CategoricalCrossentropy()
+
        self.compile(optimizer=opt,
-                     loss=losses.BinaryCrossentropy(),
+                     loss=loss_fn,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
@@ -168,33 +185,65 @@
                     )

        history = self.fit(x=X_train,
-                           y=y_train,
-                           batch_size=batch_size,
-                           epochs=1,
-                           shuffle=True,
-                           validation_data=(X_test, y_test)
-                           )
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )

        plt.plot(history.history['accuracy'])
        plt.plot(history.history['val_accuracy'])
        plt.show()


+    def test(self, num_of_blocks=1e3):
+        X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
+
+        y_out = self.call(X_test)
+
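+        # Hard symbol decisions: take the most likely symbol from the network output and from the one-hot labels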
+        y_pred = tf.argmax(y_out, axis=1)
+        y_true = tf.argmax(y_test, axis=1)
+
+        symbol_error_rate = 1 - accuracy_score(y_true, y_pred)
+
+        lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
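+        # Map the symbol decisions back to bit patterns so the bit error rate can be measured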
+        bits_pred = SymbolsToBits(self.cardinality)(tf.one_hot(y_pred, self.cardinality)).numpy().flatten()
+        bits_true = SymbolsToBits(self.cardinality)(y_test).numpy().flatten()
+
+        bit_error_rate = 1 - accuracy_score(bits_true, bits_pred)
+
+        print("SYMBOL ERROR RATE: {}".format(symbol_error_rate))
+        print("BIT ERROR RATE: {}".format(bit_error_rate))
+
+        pass
+
    def view_encoder(self):
        '''
        A method that views the learnt encoder for each distinct message. This is displayed as a plot with a subplot for
        each message/symbol.
        '''

-        # Generate inputs for encoder
-        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
        mid_idx = int((self.messages_per_block - 1) / 2)

-        idx = 0
-        for msg in messages:
-            msg[mid_idx, idx] = 1
-            idx += 1
+        if self.bit_mapping:
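+            # One input block per symbol, with the central message set to each possible bit pattern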
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
+            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx] = lst[idx]
+                idx += 1
+
+        else:
+            # Generate inputs for encoder
+            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
+
+            idx = 0
+            for msg in messages:
+                msg[mid_idx, idx] = 1
+                idx += 1

        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
@@ -218,7 +267,7 @@ class EndToEndAutoencoder(tf.keras.Model):
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
+                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1

@@ -242,11 +291,12 @@ class EndToEndAutoencoder(tf.keras.Model):
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
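+        # Also run the flattened block through the channel layer so its response can be plotted alongside the LPF output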
+        chan_out = self.channel.layers[1](flat_enc)

        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
-                                q_stddev=0)
+                                sig_avg=0)

        # Apply LPF
        lpf_out = lpf(flat_enc)
@@ -264,6 +314,7 @@ class EndToEndAutoencoder(tf.keras.Model):

        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
+        plt.plot(t, chan_out.numpy().flatten())
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
@@ -276,16 +327,15 @@ class EndToEndAutoencoder(tf.keras.Model):
        outputs = self.decoder(rx)
        return outputs

+
SAMPLING_FREQUENCY = 336e9
CARDINALITY = 32
-SAMPLES_PER_SYMBOL = 24
+SAMPLES_PER_SYMBOL = 32
MESSAGES_PER_BLOCK = 9
DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 50
-
+FIBER_LENGTH = 0

if __name__ == '__main__':
-
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
@@ -295,11 +345,12 @@ if __name__ == '__main__':
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
-                                   bit_mapping=True)
+                                   bit_mapping=False)

-    ae_model.train(num_of_blocks=1e6)
-    # ae_model.view_encoder()
-    # ae_model.view_sample_block()
+    ae_model.train(num_of_blocks=1e5, epochs=5)
+    ae_model.test()
+    ae_model.view_encoder()
+    ae_model.view_sample_block()
    # ae_model.summary()
    ae_model.encoder.summary()
    ae_model.channel.summary()