|
|
@@ -6,6 +6,7 @@ from models.custom_layers import BitsToSymbols, SymbolsToBits
|
|
|
import numpy as np
|
|
|
import math
|
|
|
|
|
|
+from matplotlib import pyplot as plt
|
|
|
|
|
|
class BitMappingModel(tf.keras.Model):
|
|
|
def __init__(self,
|
|
|
@@ -33,6 +34,9 @@ class BitMappingModel(tf.keras.Model):
|
|
|
channel=channel,
|
|
|
bit_mapping=False)
|
|
|
|
|
|
+ self.bit_error_rate = []
|
|
|
+ self.symbol_error_rate = []
|
|
|
+
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
|
x1 = BitsToSymbols(self.cardinality)(inputs)
|
|
|
x2 = self.e2e_model(x1)
|
|
|
@@ -57,13 +61,38 @@ class BitMappingModel(tf.keras.Model):
|
|
|
|
|
|
return out_arr, out_arr[:, mid_idx, :]
|
|
|
|
|
|
- def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
|
|
|
- """
|
|
|
+ def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
|
|
|
+ X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
|
|
|
+ X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
|
|
|
|
|
|
- """
|
|
|
+ X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
|
|
|
+ X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
|
|
|
+
|
|
|
+ opt = tf.keras.optimizers.Adam(learning_rate=lr)
|
|
|
+
|
|
|
+ self.compile(optimizer=opt,
|
|
|
+ loss=losses.BinaryCrossentropy(),
|
|
|
+ metrics=['accuracy'],
|
|
|
+ loss_weights=None,
|
|
|
+ weighted_metrics=None,
|
|
|
+ run_eagerly=False
|
|
|
+ )
|
|
|
+
|
|
|
+ self.fit(x=X_train,
|
|
|
+ y=y_train,
|
|
|
+ batch_size=batch_size,
|
|
|
+ epochs=epochs,
|
|
|
+ shuffle=True,
|
|
|
+ validation_data=(X_test, y_test)
|
|
|
+ )
|
|
|
+
|
|
|
+ def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
|
|
|
for _ in range(iters):
|
|
|
self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
|
|
|
+
|
|
|
self.e2e_model.test()
|
|
|
+ self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
|
|
|
+ self.bit_error_rate.append(self.e2e_model.bit_error_rate)
|
|
|
|
|
|
X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
|
|
|
X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
|
|
|
@@ -89,32 +118,59 @@ class BitMappingModel(tf.keras.Model):
|
|
|
validation_data=(X_test, y_test)
|
|
|
)
|
|
|
|
|
|
- def test(self, num_of_blocks=1e4):
|
|
|
- pass
|
|
|
-
|
|
|
+ self.e2e_model.test()
|
|
|
+ self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
|
|
|
+ self.bit_error_rate.append(self.e2e_model.bit_error_rate)
|
|
|
|
|
|
SAMPLING_FREQUENCY = 336e9
|
|
|
CARDINALITY = 32
|
|
|
-SAMPLES_PER_SYMBOL = 24
|
|
|
+SAMPLES_PER_SYMBOL = 32
|
|
|
MESSAGES_PER_BLOCK = 9
|
|
|
DISPERSION_FACTOR = -21.7 * 1e-24
|
|
|
FIBER_LENGTH = 50
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
|
|
|
- num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
|
|
|
- dispersion_factor=DISPERSION_FACTOR,
|
|
|
- fiber_length=FIBER_LENGTH)
|
|
|
-
|
|
|
- model = BitMappingModel(cardinality=CARDINALITY,
|
|
|
- samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
- messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
- channel=optical_channel)
|
|
|
-
|
|
|
- # a , c = model.generate_random_inputs(num_of_blocks=1)
|
|
|
- #
|
|
|
- # a = tf.convert_to_tensor(a, dtype=tf.float32)
|
|
|
- # b = model(a)
|
|
|
-
|
|
|
- model.train(iters=1, num_of_blocks=1e4, epochs=1)
|
|
|
- model.summary()
|
|
|
+
|
|
|
+ distances = [0, 10, 20, 30, 40, 50, 60]
|
|
|
+ ser = []
|
|
|
+ ber = []
|
|
|
+
|
|
|
+ baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
|
|
|
+ bit_rate = math.log(CARDINALITY, 2) * baud_rate
|
|
|
+ snr = None
|
|
|
+
|
|
|
+ for d in distances:
|
|
|
+
|
|
|
+ optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
|
|
|
+ num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
|
|
|
+ dispersion_factor=DISPERSION_FACTOR,
|
|
|
+ fiber_length=d)
|
|
|
+
|
|
|
+ model = BitMappingModel(cardinality=CARDINALITY,
|
|
|
+ samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
+ messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
+ channel=optical_channel)
|
|
|
+
|
|
|
+ if snr is None:
|
|
|
+ snr = model.e2e_model.snr
|
|
|
+ elif snr != model.e2e_model.snr:
|
|
|
+ print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
|
|
|
+
|
|
|
+ # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
|
|
|
+
|
|
|
+ model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
|
|
|
+
|
|
|
+ ber.append(model.bit_error_rate[-1])
|
|
|
+ ser.append(model.symbol_error_rate[-1])
|
|
|
+
|
|
|
+ plt.plot(model.bit_error_rate, label='BER')
|
|
|
+ plt.plot(model.symbol_error_rate, label='SER')
|
|
|
+ plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
|
|
|
+ plt.legend()
|
|
|
+ plt.show()
|
|
|
+ # model.summary()
|
|
|
+
|
|
|
+ plt.plot(ber, label='BER')
|
|
|
+ plt.plot(ser, label='SER')
|
|
|
+ plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
|
|
|
+ plt.legend(ber)
|