| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- # End to end Autoencoder completely implemented in Keras/TF
- import keras
- import math
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- from keras import layers
class ExtractCentralMessage(keras.layers.Layer):
    """Keep only the samples that belong to the central block of a window.

    The layer right-multiplies its input by a constant 0/1 selection matrix
    of shape ``(neighbouring_blocks * samples_per_symbol, samples_per_symbol)``.
    The matrix is zero everywhere except for an identity sub-matrix that
    occupies the rows of the middle block. The product is therefore the
    ``samples_per_symbol`` entries of the central block.

    The matrix is stored as a plain constant tensor, not as a Keras weight,
    so it is not trained.

    NOTE(review): assumes ``inputs`` is (batch, neighbouring_blocks *
    samples_per_symbol), i.e. already serialized by the channel — confirm.
    """

    def __init__(self, neighbouring_blocks, samples_per_symbol):
        super(ExtractCentralMessage, self).__init__()
        # Row range of the middle block inside the serialized window.
        first_row = int(samples_per_symbol * ((neighbouring_blocks - 1) / 2))
        last_row = int(samples_per_symbol * ((neighbouring_blocks + 1) / 2))

        selector = np.zeros(
            (neighbouring_blocks * samples_per_symbol, samples_per_symbol)
        )
        selector[first_row:last_row, :] = np.eye(samples_per_symbol)

        self.w = tf.convert_to_tensor(selector, dtype=tf.float32)

    def call(self, inputs):
        # (batch, N*S) @ (N*S, S) -> (batch, S): picks out the central block.
        return tf.matmul(inputs, self.w)
class AwgnChannel(keras.layers.Layer):
    """Additive white Gaussian noise (AWGN) channel model.

    The layer serializes the transmitted samples of all blocks into one
    vector per example, flattening every axis except the batch axis. It then
    adds zero-mean Gaussian noise with standard deviation ``stddev``.

    Args:
        stddev: Standard deviation of the additive Gaussian noise.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, stddev=0.1, **kwargs):
        super(AwgnChannel, self).__init__(**kwargs)
        self.stddev = stddev
        # Flatten is a layer: build it once here and apply it in call().
        # Writing `layers.Flatten(inputs)` would only construct a new layer
        # (treating the tensor as `data_format`) and never flatten anything.
        self.flatten_layer = layers.Flatten()
        self.noise_layer = layers.GaussianNoise(stddev)

    def call(self, inputs):
        serialized = self.flatten_layer(inputs)
        # GaussianNoise adds noise only when `training` is truthy. A physical
        # channel is noisy at inference as well, so training=True is forced.
        # Calling the layer (not `.call`) keeps Keras' build/scoping logic.
        return self.noise_layer(serialized, training=True)
class OpticalChannel(keras.layers.Layer):
    """Simplified optical channel model.

    Processing steps:
      1. Serialize the transmitted samples of all blocks (flatten every axis
         except the batch axis).
      2. Square-law detection: |x|^2.
      3. Add zero-mean Gaussian noise with standard deviation ``stddev``.

    Args:
        fs: Stored as ``self.fs``, presumably the sampling frequency for the
            planned DAC/ADC modelling (see TODO in ``call``). It is not used
            yet, so it defaults to ``None``.
        stddev: Standard deviation of the additive Gaussian noise.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, fs=None, stddev=0.1, **kwargs):
        super(OpticalChannel, self).__init__(**kwargs)
        self.fs = fs
        self.stddev = stddev
        # Flatten must be instantiated and then applied to the tensor;
        # `layers.Flatten(inputs)` would merely construct a layer object.
        self.flatten_layer = layers.Flatten()
        self.noise_layer = layers.GaussianNoise(stddev)

    def call(self, inputs):
        # TODO:
        # Low-pass filter & digitization noise for DAC and ADC (probably as an external layer called twice)

        # Serializing outputs of all blocks
        serialized = self.flatten_layer(inputs)
        # Square-law detection (tf.abs keeps this valid for complex inputs too)
        squared = tf.square(tf.abs(serialized))
        # Receiver noise must be present at inference as well, hence the
        # forced training=True (GaussianNoise is otherwise inactive).
        noisy = self.noise_layer(squared, training=True)
        return noisy
class EndToEndAutoencoder(tf.keras.Model):
    """End-to-end autoencoder: encoder -> channel model -> decoder.

    The encoder maps a window of ``neighbouring_blocks`` messages, each a
    vector of length ``cardinality`` (presumably one-hot — confirm with the
    training code), to ``samples_per_symbol`` samples per block, clipped to
    [0, 1]. The channel layer serializes and corrupts those samples. The
    decoder recovers a softmax distribution over ``cardinality`` classes for
    the central message of the window.

    Args:
        cardinality: Number of distinct messages (encoder input width and
            decoder output size).
        samples_per_symbol: Number of transmitted samples per block.
        neighbouring_blocks: Number of blocks in a window. It is bumped to the
            next odd number so that a unique central block exists.
        oversampling: Oversampling rate. It is converted to ``int`` and
            stored; no layer uses it yet.
        channel_name: ``'awgn'`` or ``'optical'`` (case/whitespace
            insensitive).
        fs: Forwarded to ``OpticalChannel`` when ``channel_name`` is
            ``'optical'``; ignored otherwise.

    Raises:
        TypeError: If ``channel_name`` is not a supported channel model.
            (A ValueError would be more conventional; TypeError is kept so
            existing callers' exception handling still works.)
    """

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 neighbouring_blocks,
                 oversampling,
                 channel_name='awgn',
                 fs=None):
        # Keras requires the base Model to be initialised before any
        # sub-layer is assigned as an attribute.
        super(EndToEndAutoencoder, self).__init__()

        # Number of leading/following messages: force an odd count so
        # ExtractCentralMessage has a well-defined middle block.
        if neighbouring_blocks % 2 == 0:
            neighbouring_blocks += 1

        self.cardinality = cardinality
        self.samples_per_symbol = samples_per_symbol
        self.neighbouring_blocks = neighbouring_blocks
        # Oversampling rate
        self.oversampling = int(oversampling)
        # Transmission Channel : ['awgn', 'optical']
        self.channel_name = channel_name.strip().lower()

        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(neighbouring_blocks, cardinality)),
            layers.Dense(2 * cardinality, activation='relu'),
            layers.Dense(2 * cardinality, activation='relu'),
            layers.Dense(samples_per_symbol),
            # Clip transmitted samples to [0, 1].
            layers.ReLU(max_value=1.0),
        ])

        # Channel Model Layer
        if self.channel_name == 'optical':
            self.channel = OpticalChannel(fs=fs)
        elif self.channel_name == 'awgn':
            self.channel = AwgnChannel()
        else:
            raise TypeError(
                "{} is not an accepted Channel Model.".format(self.channel_name))

        # The receiver side must be stored as `decoder`: call() relies on it,
        # and reusing `self.encoder` here would overwrite the transmitter.
        self.decoder = tf.keras.Sequential([
            ExtractCentralMessage(neighbouring_blocks, samples_per_symbol),
            layers.Dense(samples_per_symbol, activation='relu'),
            layers.Dense(2 * cardinality, activation='relu'),
            layers.Dense(2 * cardinality, activation='relu'),
            layers.Dense(cardinality, activation='softmax'),
        ])

    def view_encoder(self):
        # TODO
        # Visualize encoder
        pass

    def call(self, x):
        """Run encoder, channel and decoder on a batch of message windows."""
        tx = self.encoder(x)
        rx = self.channel(tx)
        return self.decoder(rx)
def _run():
    """Script entry point; training/testing of the autoencoder is not written yet."""
    # TODO: build an EndToEndAutoencoder, train it and evaluate it here.
    return None


if __name__ == '__main__':
    _run()
|