import math
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
from tensorflow.keras import layers, losses
from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer


class EndToEndAutoencoder(tf.keras.Model):
    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel,
                 recurrent=False):
        """
        Autoencoder that aims to find an encoding of the input messages. Note that a "block" consists of multiple
        "messages"; this introduces memory into the simulation, which is essential for modelling inter-symbol
        interference. The autoencoder architecture was heavily influenced by IEEE 8433895. A shape sketch is given
        after this constructor.
        :param cardinality: Number of different messages. Chosen such that each message encodes log_2(cardinality) bits
        :param samples_per_symbol: Number of samples per transmitted symbol
        :param messages_per_block: Total number of messages in a transmission block
        :param channel: Channel layer object. Must be a subclass of keras.layers.Layer with an implemented forward pass
        :param recurrent: If True, stateful LSTM layers are used in the decoder instead of dense layers
        """
        super(EndToEndAutoencoder, self).__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper. Forced to be odd so that a unique central message exists in each block.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        # Channel Model Layer
        if isinstance(channel, layers.Layer):
            self.channel = tf.keras.Sequential([
                layers.Flatten(),
                channel,
                ExtractCentralMessage(self.messages_per_block, self.samples_per_symbol)
            ], name="channel_model")
        else:
            raise TypeError("Channel must be a subclass of keras.layers.Layer!")
        self.recurrent = recurrent
        if recurrent:
            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality), batch_size=1)
            # encoding_layers = [
            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
            #     layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
            # ]
            decoding_layers = [
                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True),
                layers.LSTM(2 * self.cardinality, activation='relu', return_sequences=True, stateful=True)
            ]
        else:
            input_layer = layers.Input(shape=(self.messages_per_block, self.cardinality))
            decoding_layers = [
                layers.Dense(2 * self.cardinality, activation='relu'),
                layers.Dense(2 * self.cardinality, activation='relu')
            ]
        # Encoding Neural Network
        self.encoder = tf.keras.Sequential([
            input_layer,
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(2 * self.cardinality, activation='relu'),
            layers.Dense(self.samples_per_symbol),
            layers.ReLU(max_value=1.0)
        ], name="encoding_model")
        # Decoding Neural Network
        self.decoder = tf.keras.Sequential([
            layers.Dense(self.samples_per_symbol, activation='relu'),
            *decoding_layers,
            layers.Dense(self.cardinality, activation='softmax')
        ], name="decoding_model")
    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """
        Generates a set of one-hot encoded messages. This is used to generate the train/test data; see the example
        after this method for the returned shapes.
        :param num_of_blocks: Number of blocks to generate. A block contains multiple messages that are transmitted
        consecutively to model ISI. The central message of each block is returned as the label for training.
        :param return_vals: If True, the raw decimal values of the input sequence are returned as well
        """
        cat = [np.arange(self.cardinality)]
        enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
        mid_idx = int((self.messages_per_block - 1) / 2)
        if self.recurrent and not return_vals:
            # Overlapping windows over one long random sequence preserve message ordering for the stateful layers
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks + self.messages_per_block - 1, 1))
            rand_enc = enc.fit_transform(rand_int)
            out = []
            for i in range(num_of_blocks):
                out.append(rand_enc[i:i + self.messages_per_block])
            out = np.array(out)
            return out, out[:, mid_idx, :]
        else:
            rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
            out = enc.fit_transform(rand_int)
            out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
            if return_vals:
                out_val = np.reshape(rand_int, (num_of_blocks, self.messages_per_block, 1))
                return out_val, out_arr, out_arr[:, mid_idx, :]
            return out_arr, out_arr[:, mid_idx, :]
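
    # Example (illustrative, not part of the original file): with cardinality=4 and messages_per_block=3,
    # generate_random_inputs(2) returns
    #   blocks of shape (2, 3, 4)  -- one-hot messages per block, and
    #   labels of shape (2, 4)     -- the one-hot central message of each block.
    # With return_vals=True the raw integer values of shape (2, 3, 1) are returned first.
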
    def train(self, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """
        Trains the autoencoder. Further configuration of the loss function, optimizer etc. can be made here.
        :param num_of_blocks: Number of blocks to generate for training. Analogous to the dataset size.
        :param batch_size: Number of samples considered in each update step of the optimization algorithm
        :param train_size: Float less than 1 giving the proportion of the dataset used for training
        :param lr: Learning rate of the optimizer, controlling the step size of each update
        """
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=losses.BinaryCrossentropy(),
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )
        shuffle = True
        if self.recurrent and batch_size is None:
            # If recurrent layers are present in the model then the training data is considered one block at a time
            # without shuffling. This preserves the ordering of the data.
            batch_size = 1
            shuffle = False
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=1,
                 shuffle=shuffle,
                 validation_data=(X_test, y_test)
                 )
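
    # Hedged note (not part of the original file): the compile() call above is one place where the
    # configuration mentioned in the docstring could be changed, e.g. pairing the softmax output
    # with categorical cross-entropy instead of binary cross-entropy:
    #   self.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    #                loss=losses.CategoricalCrossentropy(),
    #                metrics=['accuracy'])
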
    def view_encoder(self):
        """
        Plots the learnt encoding of each distinct message. This is displayed as a figure with a subplot for
        each message.
        """
        # Generate inputs for the encoder: one block per message, with the message placed in the central slot
        messages = np.zeros((self.cardinality, self.messages_per_block, self.cardinality))
        mid_idx = int((self.messages_per_block - 1) / 2)
        for idx, msg in enumerate(messages):
            msg[mid_idx, idx] = 1
        # Pass input through encoder and select middle messages
        encoded = self.encoder(messages)
        enc_messages = encoded[:, mid_idx, :]
        # Compute subplot grid layout: num_x is the smallest power of two >= sqrt(cardinality)
        i = 0
        while 2 ** i < self.cardinality ** 0.5:
            i += 1
        num_x = int(2 ** i)
        num_y = int(self.cardinality / num_x)
        # Plot all symbols
        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
        t = np.arange(self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        sym_idx = 0
        for y in range(num_y):
            for x in range(num_x):
                axs[y, x].plot(t, enc_messages[sym_idx], 'x')
                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
                sym_idx += 1
        for ax in axs.flat:
            ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
        for ax in axs.flat:
            ax.label_outer()
        plt.show()
    def view_sample_block(self):
        """
        Generates a random string of input messages and encodes them. The encoded output is additionally passed
        through the digitization layer, with quantization noise disabled, to show the effect of the low-pass filtering.
        """
        # Generate a random block of messages
        val, inp, _ = self.generate_random_inputs(num_of_blocks=1, return_vals=True)
        # Encode and flatten the messages
        enc = self.encoder(inp)
        flat_enc = layers.Flatten()(enc)
        # Instantiate LPF layer
        lpf = DigitizationLayer(fs=self.channel.layers[1].fs,
                                num_of_samples=self.messages_per_block * self.samples_per_symbol,
                                q_stddev=0)
        # Apply LPF
        lpf_out = lpf(flat_enc)
        # Time axis
        t = np.arange(self.messages_per_block * self.samples_per_symbol)
        if isinstance(self.channel.layers[1], OpticalChannel):
            t = t / self.channel.layers[1].fs
        # Plot the concatenated symbols before and after the LPF, with symbol boundaries marked
        plt.figure(figsize=(2 * self.messages_per_block, 6))
        for i in range(1, self.messages_per_block):
            plt.axvline(x=t[i * self.samples_per_symbol], color='black')
        plt.plot(t, flat_enc.numpy().T, 'x')
        plt.plot(t, lpf_out.numpy().T)
        plt.ylim((0, 1))
        plt.xlim((t.min(), t.max()))
        plt.title(str(val[0, :, 0]))
        plt.show()
    def call(self, inputs, training=None, mask=None):
        # Full forward pass: encode, transmit through the channel model, then decode
        tx = self.encoder(inputs)
        rx = self.channel(tx)
        outputs = self.decoder(rx)
        return outputs


if __name__ == '__main__':
    SAMPLING_FREQUENCY = 336e9
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 24
    MESSAGES_PER_BLOCK = 9
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH)
    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
                                   recurrent=True)
    ae_model.train(num_of_blocks=1e5)
    ae_model.view_encoder()
    ae_model.view_sample_block()
    ae_model.summary()
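
    # --- Hedged sketch, not part of the original file ---
    # Quick held-out evaluation of the trained model; assumes the metrics compiled in train()
    # ('accuracy') and that evaluate() accepts the block-level inputs produced by
    # generate_random_inputs().
    X_eval, y_eval = ae_model.generate_random_inputs(num_of_blocks=1000)
    eval_batch_size = 1 if ae_model.recurrent else None
    eval_loss, eval_acc = ae_model.evaluate(X_eval, y_eval, batch_size=eval_batch_size)
    print('Held-out message accuracy: {:.4f}'.format(eval_acc))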