| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import tensorflow as tf
- from tensorflow.keras import losses
- from models.custom_layers import OpticalChannel
- from models.end_to_end import EndToEndAutoencoder
- from models.custom_layers import BitsToSymbols, SymbolsToBits
- import numpy as np
- import math
- from matplotlib import pyplot as plt
class BitMappingModel(tf.keras.Model):
    """Wrap an ``EndToEndAutoencoder`` with bit <-> symbol mapping layers.

    Forward pass: bits -> ``BitsToSymbols`` -> wrapped end-to-end model ->
    ``SymbolsToBits`` -> bits. The per-round symbol and bit error rates
    reported by the wrapped model's ``test()`` are accumulated in
    ``symbol_error_rate`` and ``bit_error_rate``.
    """

    def __init__(self,
                 cardinality,
                 samples_per_symbol,
                 messages_per_block,
                 channel):
        """Build the wrapper and its inner end-to-end autoencoder.

        Args:
            cardinality: number of symbols (M in the paper). Presumably a
                power of two, since ``bits_per_symbol`` is its base-2 log.
            samples_per_symbol: samples per symbol (n in the paper).
            messages_per_block: messages per block (N in the paper). Even
                values are bumped to the next odd value.
            channel: channel layer handed to ``EndToEndAutoencoder``.
        """
        super(BitMappingModel, self).__init__()
        # Labelled M in paper
        self.cardinality = cardinality
        # math.log2 is exact for powers of two; math.log(x, 2) divides two
        # rounded logarithms and can land a hair off the integer result.
        self.bits_per_symbol = int(math.log2(self.cardinality))
        # Labelled n in paper
        self.samples_per_symbol = samples_per_symbol
        # Labelled N in paper. Forced odd so that generate_random_inputs has a
        # single, well-defined middle message.
        if messages_per_block % 2 == 0:
            messages_per_block += 1
        self.messages_per_block = messages_per_block
        self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
                                             samples_per_symbol=self.samples_per_symbol,
                                             messages_per_block=self.messages_per_block,
                                             channel=channel,
                                             custom_loss_fn=False)
        # Mapping layers are created once here instead of inside call(), so
        # Keras tracks them and any state they own is not rebuilt per step.
        self.bits_to_symbols = BitsToSymbols(self.cardinality, self.messages_per_block)
        self.symbols_to_bits = SymbolsToBits(self.cardinality)
        # Histories appended to by trainIterative(), one entry per test round.
        self.bit_error_rate = []
        self.symbol_error_rate = []

    def call(self, inputs, training=None, mask=None):
        """Map input bits to symbols, run the end-to-end model, map back to bits.

        Args:
            inputs: bit tensor; generate_random_inputs produces arrays shaped
                (num_of_blocks, messages_per_block, bits_per_symbol).
            training: accepted for the Keras ``call`` signature; not forwarded.
            mask: accepted for the Keras ``call`` signature; unused.

        Returns:
            The output of ``SymbolsToBits`` applied to the end-to-end output.
        """
        symbols = self.bits_to_symbols(inputs)
        decoded = self.e2e_model(symbols)
        return self.symbols_to_bits(decoded)

    def generate_random_inputs(self, num_of_blocks, return_vals=False):
        """Draw uniformly random 0/1 bits for ``num_of_blocks`` blocks.

        Args:
            num_of_blocks: number of blocks to generate (must be an int).
            return_vals: when True, return a 3-tuple (see below).

        Returns:
            ``(blocks, middle)`` where ``blocks`` has shape
            (num_of_blocks, messages_per_block, bits_per_symbol) and
            ``middle = blocks[:, mid_idx, :]`` holds the bits of the centre
            message of each block. With ``return_vals=True`` the result is
            ``(blocks, blocks, middle)``.
        """
        mid_idx = (self.messages_per_block - 1) // 2
        blocks = np.random.randint(
            2, size=(num_of_blocks, self.messages_per_block, self.bits_per_symbol))
        middle = blocks[:, mid_idx, :]
        if return_vals:
            return blocks, blocks, middle
        return blocks, middle

    def _fit_on_random_bits(self, loss, epochs, num_of_blocks, batch_size, train_size, lr):
        """Generate fresh random train/validation bits, compile with ``loss`` and fit."""
        X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
        X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
        X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
        X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
        # A fresh Adam optimizer on every call, so optimizer state is reset.
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        self.compile(optimizer=opt,
                     loss=loss,
                     metrics=['accuracy'],
                     loss_weights=None,
                     weighted_metrics=None,
                     run_eagerly=False
                     )
        self.fit(x=X_train,
                 y=y_train,
                 batch_size=batch_size,
                 epochs=epochs,
                 shuffle=True,
                 validation_data=(X_test, y_test)
                 )

    def _record_error_rates(self):
        """Run the wrapped model's test() and append its latest SER/BER."""
        self.e2e_model.test()
        self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
        self.bit_error_rate.append(self.e2e_model.bit_error_rate)

    def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Train the whole bit-mapping model once on random bits (MSE loss).

        Args:
            epochs: Keras epochs.
            num_of_blocks: total blocks, split into train/validation by
                ``train_size``.
            batch_size: passed to ``fit``.
            train_size: fraction of blocks used for training.
            lr: Adam learning rate.
        """
        # NOTE(review): trainIterative() uses binary cross-entropy for the same
        # bit targets; confirm which loss is intended here.
        self._fit_on_random_bits(losses.MeanSquaredError(), epochs,
                                 num_of_blocks, batch_size, train_size, lr)

    def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
        """Alternate end-to-end training and bit-mapping training.

        Each iteration does the following:
          1. trains the inner end-to-end model and records its SER/BER;
          2. trains this wrapper on random bits with binary cross-entropy;
          3. records SER/BER again.

        Two entries are therefore appended to each history per iteration.
        """
        iters = int(iters)
        for i in range(iters):
            # 1-based so the last message reads "Loop iters/iters".
            print("Loop {}/{}".format(i + 1, iters))
            self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
            self._record_error_rates()
            self._fit_on_random_bits(losses.BinaryCrossentropy(), epochs,
                                     num_of_blocks, batch_size, train_size, lr)
            self._record_error_rates()
# Simulation parameters shared by the script entry points below.
# Passed as ``fs`` to OpticalChannel. The script divides it by
# SAMPLES_PER_SYMBOL * 1e9 to obtain the baud rate, which it reports in Gbps terms.
SAMPLING_FREQUENCY = 336e9
# Number of symbols (M); its base-2 log is the number of bits per symbol.
CARDINALITY = 64
# Samples per symbol (n).
SAMPLES_PER_SYMBOL = 48
# Messages per block (N); BitMappingModel bumps even values to the next odd one.
MESSAGES_PER_BLOCK = 11
# Passed as ``dispersion_factor`` to OpticalChannel.
# NOTE(review): looks like a chromatic-dispersion coefficient of
# -21.7 ps^2/km expressed in s^2/km; confirm against OpticalChannel.
DISPERSION_FACTOR = -21.7 * 1e-24
# Passed as ``fiber_length`` to OpticalChannel; the script reports lengths in km.
FIBER_LENGTH = 50
# Passed as ``enob`` to OpticalChannel (presumably the converter's effective
# number of bits; confirm).
ENOB = 6
def run_single_training():
    """Train one BitMappingModel over a FIBER_LENGTH optical channel.

    This is an alternative to the ``__main__`` script. It used to sit under an
    ``if __name__ == 'asd':`` guard that can never be true, which made it dead
    code. As a function it still does nothing at import time, but it can now be
    invoked deliberately.

    Returns:
        The trained BitMappingModel.
    """
    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                     dispersion_factor=DISPERSION_FACTOR,
                                     fiber_length=FIBER_LENGTH,
                                     sig_avg=0.5,
                                     enob=ENOB)
    model = BitMappingModel(cardinality=CARDINALITY,
                            samples_per_symbol=SAMPLES_PER_SYMBOL,
                            messages_per_block=MESSAGES_PER_BLOCK,
                            channel=optical_channel)
    model.train()
    return model
def main():
    """Train a BitMappingModel per fiber length and compare against a baseline.

    For each distance a BitMappingModel is built over an OpticalChannel and
    trained with ``trainIterative``. For the last distance, a plain
    EndToEndAutoencoder is then trained for the same number of test rounds.
    The BER histories of both models are plotted on a log scale.

    Returns:
        (ber, ser): final bit / symbol error rate recorded for each distance.
    """
    distances = [50]
    ser = []
    ber = []
    # Symbol rate in GBaud (the 1e9 scales it), so bit_rate is in Gbps.
    baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
    bit_rate = math.log2(CARDINALITY) * baud_rate
    snr = None
    title = None
    for d in distances:
        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                         dispersion_factor=DISPERSION_FACTOR,
                                         fiber_length=d,
                                         sig_avg=0.5,
                                         enob=ENOB)
        model = BitMappingModel(cardinality=CARDINALITY,
                                samples_per_symbol=SAMPLES_PER_SYMBOL,
                                messages_per_block=MESSAGES_PER_BLOCK,
                                channel=optical_channel)
        # All models are expected to report the same SNR; flag any mismatch.
        if snr is None:
            snr = model.e2e_model.snr
        elif snr != model.e2e_model.snr:
            print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
        title = "{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr)
        print(title)
        model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
        model.e2e_model.test(length_plot=True)
        ber.append(model.bit_error_rate[-1])
        ser.append(model.symbol_error_rate[-1])

    # Baseline: a plain end-to-end model on the last channel, trained for as
    # many test rounds as the bit-mapping model recorded, to compare BER curves.
    baseline = EndToEndAutoencoder(cardinality=CARDINALITY,
                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
                                   messages_per_block=MESSAGES_PER_BLOCK,
                                   channel=optical_channel,
                                   custom_loss_fn=False)
    baseline_ber = []
    for _ in range(len(model.bit_error_rate)):
        baseline.train(num_of_blocks=1e3, epochs=5)
        baseline.test()
        baseline_ber.append(baseline.bit_error_rate)

    plt.plot(baseline_ber, label='BER (1)')
    plt.plot(model.bit_error_rate, label='BER (3)')
    plt.title(title)
    plt.yscale('log')
    plt.legend()
    plt.show()
    return ber, ser


if __name__ == '__main__':
    main()
|