| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import tensorflow as tf
- from tensorflow.keras import layers, losses
- from models.custom_layers import ExtractCentralMessage, OpticalChannel
- from models.end_to_end import EndToEndAutoencoder
- from models.custom_layers import BitsToSymbols, SymbolsToBits
- import numpy as np
- import math
class BitMappingModel(tf.keras.Model):
    """Bit-level wrapper around the symbol-level end-to-end autoencoder.

    Maps input bit vectors to symbols, pushes them through
    ``EndToEndAutoencoder`` (which includes the optical channel), and maps
    the received symbols back to bits, so the model can be trained directly
    on bits with a binary cross-entropy loss.
    """

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """
        Args:
            cardinality: Constellation size, labelled M in the paper.
                Assumed to be a power of two so bits map evenly to symbols.
            samples_per_symbol: Samples per symbol, labelled n in the paper.
            messages_per_block: Messages per block, labelled N in the paper.
                Forced odd so the block has a unique central message.
            channel: Channel layer forwarded to the end-to-end model.
        """
        super().__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        # round() guards against float truncation (math.log(x, 2) can return
        # e.g. 4.999999... for an exact power of two).
        self.bits_per_symbol = int(round(math.log2(self.cardinality)))
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper; bump even values so a central message exists.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
                                             samples_per_symbol=self.samples_per_symbol,
                                             messages_per_block=self.messages_per_block,
                                             channel=channel,
                                             bit_mapping=False)
        # Create the mapping layers once here: instantiating layers inside
        # call() re-creates them on every forward pass and prevents Keras
        # from tracking any state they hold.
        self.bits_to_symbols = BitsToSymbols(self.cardinality)
        self.symbols_to_bits = SymbolsToBits(self.cardinality)

    def call(self, inputs, training=None, mask=None):
        """Forward pass: bits -> symbols -> end-to-end model -> bits."""
        symbols = self.bits_to_symbols(inputs)
        received = self.e2e_model(symbols)
        return self.symbols_to_bits(received)

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """Draw uniform random bit blocks for training/evaluation.

        Args:
            num_of_blocks: Number of blocks to generate.
            return_vals: When True, additionally return the full blocks a
                second time (kept for backward compatibility with callers
                expecting a 3-tuple).

        Returns:
            ``(X, y_central)`` where ``X`` has shape
            ``(num_of_blocks, messages_per_block, bits_per_symbol)`` and
            ``y_central`` is the central message's bits, shape
            ``(num_of_blocks, bits_per_symbol)``. With ``return_vals=True``:
            ``(X, X, y_central)``.
        """
        # Index of the central message; messages_per_block is always odd.
        mid_idx = int((self.messages_per_block - 1) / 2)
        rand_bits = np.random.randint(
            2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
        out_arr = np.reshape(
            rand_bits, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
        if return_vals:
            return out_arr, out_arr, out_arr[:, mid_idx, :]
        return out_arr, out_arr[:, mid_idx, :]

    def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Alternately train the symbol-level model, then fine-tune on bits.

        Args:
            iters: Number of outer train/fine-tune rounds.
            epochs: Epochs for both the inner e2e training and the bit fit.
            num_of_blocks: Total blocks per round (split by ``train_size``).
            batch_size: Forwarded to ``fit``; None uses the Keras default.
            train_size: Fraction of blocks used for training vs validation.
            lr: Adam learning rate.
        """
        for _ in range(iters):
            # Train the underlying symbol autoencoder first, then adapt the
            # full bit-mapped model on fresh random bits.
            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
            self.e2e_model.test()
            X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
            X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
            # Recompile every round: a fresh Adam optimizer (reset moments)
            # per outer iteration, matching the original behaviour.
            opt = tf.keras.optimizers.Adam(learning_rate=lr)
            self.compile(optimizer=opt,
                         loss=losses.BinaryCrossentropy(),
                         metrics=['accuracy'],
                         loss_weights=None,
                         weighted_metrics=None,
                         run_eagerly=False)
            self.fit(x=X_train,
                     y=y_train,
                     batch_size=batch_size,
                     epochs=epochs,
                     shuffle=True,
                     validation_data=(X_test, y_test))

    def test(self, num_of_blocks=1e4):
        """Evaluation hook; not implemented yet."""
        # TODO: implement bit-error-rate evaluation over num_of_blocks blocks.
        pass
# Sampling frequency of the simulated channel, in Hz (336 GHz).
SAMPLING_FREQUENCY = 336e9
# Constellation size M; a power of two, so each symbol carries log2(M)=5 bits.
CARDINALITY = 32
# Samples per symbol (n in the paper).
SAMPLES_PER_SYMBOL = 24
# Messages per block (N in the paper); odd, so the block has a central message.
MESSAGES_PER_BLOCK = 9
# Fiber dispersion coefficient; presumably beta_2 in s^2/m — TODO confirm
# against OpticalChannel's expected units.
DISPERSION_FACTOR = -21.7 * 1e-24
# Fiber length; units not shown here (presumably km) — verify in OpticalChannel.
FIBER_LENGTH = 50
if __name__ == '__main__':
    # Build the dispersive optical channel shared by the end-to-end model.
    # The channel operates on whole blocks, hence N * n total samples.
    optical_channel = OpticalChannel(
        fs=SAMPLING_FREQUENCY,
        num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
        dispersion_factor=DISPERSION_FACTOR,
        fiber_length=FIBER_LENGTH,
    )

    # Assemble the bit-level autoencoder around that channel.
    model = BitMappingModel(
        cardinality=CARDINALITY,
        samples_per_symbol=SAMPLES_PER_SYMBOL,
        messages_per_block=MESSAGES_PER_BLOCK,
        channel=optical_channel,
    )

    # One training round on 1e4 random blocks, then print the architecture.
    model.train(iters=1, num_of_blocks=1e4, epochs=1)
    model.summary()
|