import math

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, losses

from models.custom_layers import (
    BitsToSymbols,
    ExtractCentralMessage,
    OpticalChannel,
    SymbolsToBits,
)
from models.end_to_end import EndToEndAutoencoder


class BitMappingModel(tf.keras.Model):
    """Bit-level wrapper around an ``EndToEndAutoencoder``.

    A forward pass applies ``BitsToSymbols``, then the wrapped end-to-end
    model, then ``SymbolsToBits``.
    """

    def __init__(self, cardinality, samples_per_symbol, messages_per_block, channel):
        """Build the wrapped autoencoder and the two bit/symbol mapping layers.

        Args:
            cardinality: Labelled M in paper. ``bits_per_symbol`` is derived
                from it as ``floor(log2(cardinality))``.
            samples_per_symbol: Labelled n in paper.
            messages_per_block: Labelled N in paper. An even value is
                incremented by one so that the block has a middle message.
            channel: Channel object handed unchanged to
                ``EndToEndAutoencoder``.
        """
        super().__init__()

        # Labelled M in paper
        self.cardinality = cardinality
        # math.log2 is more accurate than math.log(x, 2), whose result is a
        # quotient of two rounded logarithms and may not be exact before
        # truncation.
        self.bits_per_symbol = int(math.log2(self.cardinality))

        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol

        # Labelled N in paper
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block

        self.e2e_model = EndToEndAutoencoder(
            cardinality=self.cardinality,
            samples_per_symbol=self.samples_per_symbol,
            messages_per_block=self.messages_per_block,
            channel=channel,
            bit_mapping=False,
        )

        # Create the mapping layers once so Keras tracks them and any state
        # they hold is reused, instead of rebuilding them on every call.
        self.bits_to_symbols = BitsToSymbols(self.cardinality)
        self.symbols_to_bits = SymbolsToBits(self.cardinality)

    def call(self, inputs, training=None, mask=None):
        """Run bits -> symbols -> end-to-end model -> bits.

        Args:
            inputs: Bit tensor; presumably shaped like the first value
                returned by ``generate_random_inputs`` (TODO confirm against
                ``BitsToSymbols``).
            training: Accepted for Keras API compatibility; not used directly.
            mask: Accepted for Keras API compatibility; not used directly.

        Returns:
            The output of ``SymbolsToBits`` applied to the wrapped model's
            output.
        """
        symbols = self.bits_to_symbols(inputs)
        decoded = self.e2e_model(symbols)
        return self.symbols_to_bits(decoded)

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """Draw random bit blocks and the matching middle-message labels.

        Args:
            num_of_blocks: Number of blocks to generate. Converted with
                ``int`` so float counts such as ``1e4`` are accepted.
            return_vals: When true, the bit array is returned twice ahead of
                the labels (a three-element tuple).

        Returns:
            ``(bits, labels)``, or ``(bits, bits, labels)`` when
            ``return_vals`` is true. ``bits`` has shape
            ``(num_of_blocks, messages_per_block, bits_per_symbol)`` with
            values in ``{0, 1}``; ``labels`` is ``bits[:, mid, :]`` where
            ``mid`` is the index of the middle message of each block.
        """
        num_of_blocks = int(num_of_blocks)
        mid_idx = (self.messages_per_block - 1) // 2

        bits = np.random.randint(
            2,
            size=(num_of_blocks, self.messages_per_block, self.bits_per_symbol),
        )
        labels = bits[:, mid_idx, :]

        if return_vals:
            return bits, bits, labels
        return bits, labels

    def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Alternate training of the wrapped autoencoder and of this model.

        Each iteration calls ``e2e_model.train`` and ``e2e_model.test``, then
        generates fresh random data, compiles this model with a new Adam
        optimizer and binary cross-entropy loss, and fits it.

        Args:
            iters: Number of iterations of the sequence above.
            epochs: Epoch count used for both ``e2e_model.train`` and ``fit``.
            num_of_blocks: Total number of blocks per iteration; passed as
                given to ``e2e_model.train`` and split into training and
                validation sets for ``fit``.
            batch_size: Batch size passed to ``fit`` only.
            train_size: Fraction of ``num_of_blocks`` used for training; the
                remainder is used as validation data.
            lr: Learning rate of the Adam optimizer.
        """
        # Take the validation count as the remainder: computing
        # int(num_of_blocks * (1 - train_size)) loses a block to float
        # rounding (1e4 * (1 - 0.8) truncates to 1999).
        total_blocks = int(num_of_blocks)
        train_blocks = int(round(total_blocks * train_size))
        test_blocks = total_blocks - train_blocks

        for _ in range(iters):
            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
            self.e2e_model.test()

            X_train, y_train = self.generate_random_inputs(train_blocks)
            X_test, y_test = self.generate_random_inputs(test_blocks)

            X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
            X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)

            # Compiling inside the loop means every iteration starts with a
            # freshly created optimizer.
            opt = tf.keras.optimizers.Adam(learning_rate=lr)
            self.compile(
                optimizer=opt,
                loss=losses.BinaryCrossentropy(),
                metrics=['accuracy'],
                loss_weights=None,
                weighted_metrics=None,
                run_eagerly=False,
            )

            self.fit(
                x=X_train,
                y=y_train,
                batch_size=batch_size,
                epochs=epochs,
                shuffle=True,
                validation_data=(X_test, y_test),
            )

    def test(self, num_of_blocks=1e4):
        """Placeholder: not implemented, does nothing and returns ``None``."""
        pass


# Parameters of the example run below.
SAMPLING_FREQUENCY = 336e9
CARDINALITY = 32
SAMPLES_PER_SYMBOL = 24
MESSAGES_PER_BLOCK = 9
DISPERSION_FACTOR = -21.7 * 1e-24
FIBER_LENGTH = 50

if __name__ == '__main__':
    # The channel is sized to one full block of samples.
    optical_channel = OpticalChannel(
        fs=SAMPLING_FREQUENCY,
        num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
        dispersion_factor=DISPERSION_FACTOR,
        fiber_length=FIBER_LENGTH,
    )

    model = BitMappingModel(
        cardinality=CARDINALITY,
        samples_per_symbol=SAMPLES_PER_SYMBOL,
        messages_per_block=MESSAGES_PER_BLOCK,
        channel=optical_channel,
    )

    # Single forward pass, kept for manual debugging:
    # a, c = model.generate_random_inputs(num_of_blocks=1)
    # a = tf.convert_to_tensor(a, dtype=tf.float32)
    # b = model(a)

    model.train(iters=1, num_of_blocks=1e4, epochs=1)
    model.summary()