@@ -6,9 +6,7 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
-from tensorflow.keras import backend as K
 from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
-import itertools


 class EndToEndAutoencoder(tf.keras.Model):
@@ -17,7 +15,7 @@ class EndToEndAutoencoder(tf.keras.Model):
                  samples_per_symbol,
                  messages_per_block,
                  channel,
-                 bit_mapping=False):
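+                 # when True, train() optimises the model's custom cost (self.cost)
+                 # instead of categorical cross-entropy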
+                 custom_loss_fn=False):
         """
         The autoencoder that aims to find a encoding of the input messages. It should be noted that a "block" consists
         of multiple "messages" to introduce memory into the simulation as this is essential for modelling inter-symbol
@@ -37,7 +35,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         # Labelled n in paper
         self.samples_per_symbol = samples_per_symbol

-        # Labelled N in paper
+        # Labelled N in paper; incremented if even to ensure an odd value (a single central message)
         if messages_per_block % 2 == 0:
             messages_per_block += 1
         self.messages_per_block = messages_per_block
@@ -50,10 +48,10 @@ class EndToEndAutoencoder(tf.keras.Model):
                 ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
             ], name="channel_model")
         else:
-            raise TypeError("Channel must be a subclass of keras.layers.layer!")
+            raise TypeError("Channel must be a subclass of tensorflow.keras.layers.Layer!")

-        # Boolean identifying if bit mapping is to be learnt
-        self.bit_mapping = bit_mapping
+        # Boolean identifying if the custom loss function is to be used
+        self.custom_loss_fn = custom_loss_fn

         # other parameters/metrics
         self.symbol_error_rate = None
@@ -64,46 +62,24 @@ class EndToEndAutoencoder(tf.keras.Model):
         leaky_relu_alpha = 0
         relu_clip_val = 1.0

-        # Layer configuration for the case when bit mapping is to be learnt
-        if self.bit_mapping:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.bits_per_symbol)),
-                BitsToSymbols(self.cardinality),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                # layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                # layers.Dense(2 * self.cardinality),
-                # layers.LeakyReLU(alpha=0.01),
-                layers.Dense(self.bits_per_symbol, activation='sigmoid')
-            ]
-
-        # layer configuration for the case when only symbol mapping is to be learnt
-        else:
-            encoding_layers = [
-                layers.Input(shape=(self.messages_per_block, self.cardinality)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
-                layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
-                layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
-                # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
-                # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
-            ]
-            decoding_layers = [
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(2 * self.cardinality),
-                layers.LeakyReLU(alpha=leaky_relu_alpha),
-                layers.Dense(self.cardinality, activation='softmax')
-            ]
+        # Layer configuration: learn the mapping from one-hot messages to waveform samples
+        encoding_layers = [
+            layers.Input(shape=(self.messages_per_block, self.cardinality)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(2 * self.cardinality)),
+            layers.TimeDistributed(layers.LeakyReLU(alpha=leaky_relu_alpha)),
+            layers.TimeDistributed(layers.Dense(self.samples_per_symbol, activation='sigmoid')),
+            # layers.TimeDistributed(layers.Dense(self.samples_per_symbol)),
+            # layers.TimeDistributed(layers.ReLU(max_value=relu_clip_val))
+        ]
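+        # Decoder head: dense layers on the received central message, ending in
+        # a softmax over the `cardinality` candidate messages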
+        decoding_layers = [
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(2 * self.cardinality),
+            layers.LeakyReLU(alpha=leaky_relu_alpha),
+            layers.Dense(self.cardinality, activation='softmax')
+        ]

         # Encoding Neural Network
         self.encoder = tf.keras.Sequential([
@@ -141,26 +117,15 @@ class EndToEndAutoencoder(tf.keras.Model):

         mid_idx = int((self.messages_per_block - 1) / 2)

-        if self.bit_mapping:
-            rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
-
-            out = rand_int
-
-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
-
-            if return_vals:
-                return out_arr, out_arr, out_arr[:, mid_idx, :]
-
-        else:
-            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
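+        # Draw one uniform random symbol index per message; one-hot encoded below
+        # into blocks of shape (num_of_blocks, messages_per_block, cardinality)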
+        rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))

-            out = enc.fit_transform(rand_int)
+        out = enc.fit_transform(rand_int)

-            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
+        out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))

-            if return_vals:
-                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
-                return out_val, out_arr, out_arr[:, mid_idx, :]
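+        # Optionally return the integer symbol labels alongside the one-hot blocks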
+        if return_vals:
+            out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
+            return out_val, out_arr, out_arr[:, mid_idx, :]

         return out_arr, out_arr[:, mid_idx, :]

@@ -178,21 +143,10 @@ class EndToEndAutoencoder(tf.keras.Model):

         opt = tf.keras.optimizers.Adam(learning_rate=lr)

-        # TODO: Investigate different optimizers (with different learning rates and other parameters)
-        # SGD
-        # RMSprop
-        # Adam
-        # Adadelta
-        # Adagrad
-        # Adamax
-        # Nadam
-        # Ftrl
-
-        if self.bit_mapping:
-            loss_fn = losses.BinaryCrossentropy()
-        else:
-            # loss_fn = losses.CategoricalCrossentropy()
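+        # Select the training loss: the model's own cost function when requested,
+        # otherwise categorical cross-entropy on the one-hot targets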
+        if self.custom_loss_fn:
             loss_fn = self.cost
+        else:
+            loss_fn = losses.CategoricalCrossentropy()

         self.compile(optimizer=opt,
                      loss=loss_fn,
@@ -202,13 +156,13 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )

-        history = self.fit(x=X_train,
-                           y=y_train,
-                           batch_size=batch_size,
-                           epochs=epochs,
-                           shuffle=True,
-                           validation_data=(X_test, y_test)
-                           )
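+        # Fit on shuffled training blocks, validating against the held-out set each epoch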
+        self.fit(x=X_train,
+                 y=y_train,
+                 batch_size=batch_size,
+                 epochs=epochs,
+                 shuffle=True,
+                 validation_data=(X_test, y_test)
+                 )

     def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
@@ -278,23 +232,13 @@ class EndToEndAutoencoder(tf.keras.Model):

         mid_idx = int((self.messages_per_block - 1) / 2)

-        if self.bit_mapping:
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.bits_per_symbol))
-            lst = [list(i) for i in itertools.product([0, 1], repeat=self.bits_per_symbol)]
-
-            idx = 0
-            for msg in messages:
-                msg[mid_idx] = lst[idx]
-                idx += 1
+        # Generate inputs for encoder
+        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))

-        else:
-            # Generate inputs for encoder
-            messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
-
-            idx = 0
-            for msg in messages:
-                msg[mid_idx, idx] = 1
-                idx += 1
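+        # Block i carries the i-th one-hot message in its central slot, so the
+        # encoder output at mid_idx gives the learnt waveform for each symbol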
+        idx = 0
+        for msg in messages:
+            msg[mid_idx, idx] = 1
+            idx += 1

         # Pass input through encoder and select middle messages
         encoded = self.encoder(messages)
@@ -387,7 +331,6 @@ DISPERSION_FACTOR = -21.7 * 1e-24
 FIBER_LENGTH = 50
 FIBER_LENGTH_STDDEV = 5

-
 if __name__ == '__main__':

     stddevs = [0, 1, 5, 10]
@@ -408,7 +351,7 @@ if __name__ == '__main__':
                                        samples_per_symbol=SAMPLES_PER_SYMBOL,
                                        messages_per_block=MESSAGES_PER_BLOCK,
                                        channel=optical_channel,
-                                       bit_mapping=False)
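+                                       # train with the model's custom cost (see train())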
+                                       custom_loss_fn=True)

     print(ae_model.snr)