|
|
@@ -8,6 +8,7 @@ import math
|
|
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
+
|
|
|
class BitMappingModel(tf.keras.Model):
|
|
|
def __init__(self,
|
|
|
cardinality,
|
|
|
@@ -38,7 +39,7 @@ class BitMappingModel(tf.keras.Model):
|
|
|
self.symbol_error_rate = []
|
|
|
|
|
|
def call(self, inputs, training=None, mask=None):
|
|
|
- x1 = BitsToSymbols(self.cardinality)(inputs)
|
|
|
+ x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
|
|
|
x2 = self.e2e_model(x1)
|
|
|
out = SymbolsToBits(self.cardinality)(x2)
|
|
|
return out
|
|
|
@@ -71,7 +72,8 @@ class BitMappingModel(tf.keras.Model):
|
|
|
opt = tf.keras.optimizers.Adam(learning_rate=lr)
|
|
|
|
|
|
self.compile(optimizer=opt,
|
|
|
- loss=losses.BinaryCrossentropy(),
|
|
|
+ # loss=losses.BinaryCrossentropy(),
|
|
|
+ loss=losses.MeanSquaredError(),
|
|
|
metrics=['accuracy'],
|
|
|
loss_weights=None,
|
|
|
weighted_metrics=None,
|
|
|
@@ -87,7 +89,9 @@ class BitMappingModel(tf.keras.Model):
|
|
|
)
|
|
|
|
|
|
def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
|
|
|
- for _ in range(iters):
|
|
|
+ for i in range(int(0.5*iters)):
|
|
|
+ print("Loop {}/{}".format(i, iters))
|
|
|
+
|
|
|
self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
|
|
|
|
|
|
self.e2e_model.test()
|
|
|
@@ -104,6 +108,7 @@ class BitMappingModel(tf.keras.Model):
|
|
|
|
|
|
self.compile(optimizer=opt,
|
|
|
loss=losses.BinaryCrossentropy(),
|
|
|
+ # loss=losses.MeanSquaredError(),
|
|
|
metrics=['accuracy'],
|
|
|
loss_weights=None,
|
|
|
weighted_metrics=None,
|
|
|
@@ -122,16 +127,64 @@ class BitMappingModel(tf.keras.Model):
|
|
|
self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
|
|
|
self.bit_error_rate.append(self.e2e_model.bit_error_rate)
|
|
|
|
|
|
+ for i in range(int(0.5*iters)):
|
|
|
+
|
|
|
+ X_train, y_train = self.generate_random_ inputs(int(1e5 * train_size))
|
|
|
+ X_test, y_test = self.generate_random_inputs(int(1e5 * (1 - train_size)))
|
|
|
+
|
|
|
+ X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
|
|
|
+ X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
|
|
|
+
|
|
|
+ opt = tf.keras.optimizers.Adam(learning_rate=lr)
|
|
|
+
|
|
|
+ self.compile(optimizer=opt,
|
|
|
+ loss=losses.BinaryCrossentropy(),
|
|
|
+ # loss=losses.MeanSquaredError(),
|
|
|
+ metrics=['accuracy'],
|
|
|
+ loss_weights=None,
|
|
|
+ weighted_metrics=None,
|
|
|
+ run_eagerly=False
|
|
|
+ )
|
|
|
+
|
|
|
+ self.fit(x=X_train,
|
|
|
+ y=y_train,
|
|
|
+ batch_size=batch_size,
|
|
|
+ epochs=epochs,
|
|
|
+ shuffle=True,
|
|
|
+ validation_data=(X_test, y_test)
|
|
|
+ )
|
|
|
+
|
|
|
+ self.e2e_model.test()
|
|
|
+ self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
|
|
|
+ self.bit_error_rate.append(self.e2e_model.bit_error_rate)
|
|
|
+
|
|
|
+
|
|
|
SAMPLING_FREQUENCY = 336e9
|
|
|
-CARDINALITY = 32
|
|
|
-SAMPLES_PER_SYMBOL = 32
|
|
|
-MESSAGES_PER_BLOCK = 9
|
|
|
+CARDINALITY = 64
|
|
|
+SAMPLES_PER_SYMBOL = 48
|
|
|
+MESSAGES_PER_BLOCK = 11
|
|
|
DISPERSION_FACTOR = -21.7 * 1e-24
|
|
|
FIBER_LENGTH = 50
|
|
|
+ENOB = 6
|
|
|
+
|
|
|
+if __name__ == 'asd':
|
|
|
+ optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
|
|
|
+ num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
|
|
|
+ dispersion_factor=DISPERSION_FACTOR,
|
|
|
+ fiber_length=FIBER_LENGTH,
|
|
|
+ sig_avg=0.5,
|
|
|
+ enob=ENOB)
|
|
|
+
|
|
|
+ model = BitMappingModel(cardinality=CARDINALITY,
|
|
|
+ samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
+ messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
+ channel=optical_channel)
|
|
|
+
|
|
|
+ model.train()
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
- distances = [0, 10, 20, 30, 40, 50, 60]
|
|
|
+ distances = [50]
|
|
|
ser = []
|
|
|
ber = []
|
|
|
|
|
|
@@ -144,7 +197,9 @@ if __name__ == '__main__':
|
|
|
optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
|
|
|
num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
|
|
|
dispersion_factor=DISPERSION_FACTOR,
|
|
|
- fiber_length=d)
|
|
|
+ fiber_length=d,
|
|
|
+ sig_avg=0.5,
|
|
|
+ enob=ENOB)
|
|
|
|
|
|
model = BitMappingModel(cardinality=CARDINALITY,
|
|
|
samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
@@ -156,21 +211,60 @@ if __name__ == '__main__':
|
|
|
elif snr != model.e2e_model.snr:
|
|
|
print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
|
|
|
|
|
|
- # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
|
|
|
+ print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
|
|
|
|
|
|
model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
|
|
|
|
|
|
+ model.e2e_model.test(length_plot=True)
|
|
|
+
|
|
|
ber.append(model.bit_error_rate[-1])
|
|
|
ser.append(model.symbol_error_rate[-1])
|
|
|
|
|
|
- plt.plot(model.bit_error_rate, label='BER')
|
|
|
- plt.plot(model.symbol_error_rate, label='SER')
|
|
|
+ e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
|
|
|
+ samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
+ messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
+ channel=optical_channel,
|
|
|
+ bit_mapping=False)
|
|
|
+
|
|
|
+ ber1 = []
|
|
|
+ ser1 = []
|
|
|
+
|
|
|
+ for i in range(int(len(model.bit_error_rate))):
|
|
|
+ e2e_model.train(num_of_blocks=1e3, epochs=5)
|
|
|
+ e2e_model.test()
|
|
|
+
|
|
|
+ ber1.append(e2e_model.bit_error_rate)
|
|
|
+ ser1.append(e2e_model.symbol_error_rate)
|
|
|
+
|
|
|
+ # model2 = BitMappingModel(cardinality=CARDINALITY,
|
|
|
+ # samples_per_symbol=SAMPLES_PER_SYMBOL,
|
|
|
+ # messages_per_block=MESSAGES_PER_BLOCK,
|
|
|
+ # channel=optical_channel)
|
|
|
+ #
|
|
|
+ # ber2 = []
|
|
|
+ # ser2 = []
|
|
|
+ #
|
|
|
+ # for i in range(int(len(model.bit_error_rate) / 2)):
|
|
|
+ # model2.train(num_of_blocks=1e3, epochs=5)
|
|
|
+ # model2.e2e_model.test()
|
|
|
+ #
|
|
|
+ # ber2.append(model2.e2e_model.bit_error_rate)
|
|
|
+ # ser2.append(model2.e2e_model.symbol_error_rate)
|
|
|
+
|
|
|
+ plt.plot(ber1, label='BER (1)')
|
|
|
+ # plt.plot(ser1, label='SER (1)')
|
|
|
+ # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
|
|
|
+ # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
|
|
|
+ plt.plot(model.bit_error_rate, label='BER (3)')
|
|
|
+ # plt.plot(model.symbol_error_rate, label='SER (3)')
|
|
|
+
|
|
|
plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
|
|
|
+ plt.yscale('log')
|
|
|
plt.legend()
|
|
|
plt.show()
|
|
|
# model.summary()
|
|
|
|
|
|
- plt.plot(ber, label='BER')
|
|
|
- plt.plot(ser, label='SER')
|
|
|
- plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
|
|
|
- plt.legend(ber)
|
|
|
+ # plt.plot(ber, label='BER')
|
|
|
+ # plt.plot(ser, label='SER')
|
|
|
+ # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
|
|
|
+ # plt.legend(ber)
|