@@ -115,6 +115,20 @@ class EndToEndAutoencoder(tf.keras.Model):
             *decoding_layers
         ], name="decoding_model")
 
+    def cost(self, y_true, y_pred):
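+        # Combined objective: symbol-level categorical cross-entropy plus a
+        # weighted bit-level binary cross-entropy on the bit-mapped labels.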
+        symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)
+
+        y_bits_true = SymbolsToBits(self.cardinality)(y_true)
+        y_bits_pred = SymbolsToBits(self.cardinality)(y_pred)
+
+        bit_cost = losses.BinaryCrossentropy()(y_bits_true, y_bits_pred)
+
+        a = 1  # relative weight of the bit-level cost term
+
+        return symbol_cost + a * bit_cost
+
     def generate_random_inputs(self, num_of_blocks, return_vals=False):
         """
         A method that generates a list of one-hot encoded messages. This is utilized for generating the test/train data.
@@ -179,7 +193,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         if self.bit_mapping:
             loss_fn = losses.BinaryCrossentropy()
         else:
-            loss_fn = losses.CategoricalCrossentropy()
+            # loss_fn = losses.CategoricalCrossentropy()
+            loss_fn = self.cost
 
         self.compile(optimizer=opt,
                      loss=loss_fn,
@@ -197,7 +212,9 @@ class EndToEndAutoencoder(tf.keras.Model):
                  validation_data=(X_test, y_test)
                  )
 
-    def test(self, num_of_blocks=1e4, length_plot=False):
+    def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
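+        # plt_show=False keeps the figure open so callers can overlay
+        # several curves and call plt.show() once at the end.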
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
 
         y_out = self.call(X_test)
@@ -249,7 +266,8 @@ class EndToEndAutoencoder(tf.keras.Model):
 
         plt.plot(lengths, ber_l)
         plt.yscale('log')
-        plt.show()
+        if plt_show:
+            plt.show()
 
         print("SYMBOL ERROR RATE: {}".format(self.symbol_error_rate))
         print("BIT ERROR RATE: {}".format(self.bit_error_rate))
@@ -370,26 +388,50 @@ CARDINALITY = 32
 SAMPLES_PER_SYMBOL = 32
 MESSAGES_PER_BLOCK = 9
 DISPERSION_FACTOR = -21.7 * 1e-24
-FIBER_LENGTH = 0
+FIBER_LENGTH = 50
+FIBER_LENGTH_STDDEV = 5
+
 
 if __name__ == '__main__':
-    optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
-                                     num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
-                                     dispersion_factor=DISPERSION_FACTOR,
-                                     fiber_length=FIBER_LENGTH)
-
-    ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
-                                   samples_per_symbol=SAMPLES_PER_SYMBOL,
-                                   messages_per_block=MESSAGES_PER_BLOCK,
-                                   channel=optical_channel,
-                                   bit_mapping=False)
-
-    ae_model.train(num_of_blocks=1e5, epochs=5)
-    ae_model.test()
-    ae_model.view_encoder()
-    ae_model.view_sample_block()
-    # ae_model.summary()
-    ae_model.encoder.summary()
-    ae_model.channel.summary()
-    ae_model.decoder.summary()
+
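+    # Sweep the fiber-length spread: train one autoencoder per stddev and
+    # overlay the resulting BER-vs-length curves in a single figure.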
+    stddevs = [0, 1, 5, 10]
+    legend = []
+
+    for s in stddevs:
+        optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
+                                         num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
+                                         dispersion_factor=DISPERSION_FACTOR,
+                                         fiber_length=FIBER_LENGTH,
+                                         fiber_length_stddev=s,
+                                         lpf_cutoff=32e9,
+                                         rx_stddev=0.01,
+                                         sig_avg=0.5,
+                                         enob=10)
+
+        ae_model = EndToEndAutoencoder(cardinality=CARDINALITY,
+                                       samples_per_symbol=SAMPLES_PER_SYMBOL,
+                                       messages_per_block=MESSAGES_PER_BLOCK,
+                                       channel=optical_channel,
+                                       bit_mapping=False)
+
+        print(ae_model.snr)  # log the resulting SNR for this configuration
+
+        ae_model.train(num_of_blocks=3e5, epochs=5)
+        ae_model.test(length_plot=True, plt_show=False)
+        # plt.legend(['{} +/- {}'.format(FIBER_LENGTH, s)])
+
+        legend.append('{} +/- {}'.format(FIBER_LENGTH, s))
+
+    plt.legend(legend)
+    plt.savefig('ber_vs_length.eps', format='eps')  # save before show(), which clears the figure
+    plt.show()
+
+    # ae_model.view_encoder()
+    # ae_model.view_sample_block()
+    # # ae_model.summary()
+    # ae_model.encoder.summary()
+    # ae_model.channel.summary()
+    # ae_model.decoder.summary()
     pass