@@ -148,7 +148,7 @@ class EndToEndAutoencoder(tf.keras.Model):
         return enc_weights, dec_weights

     def encode_stream(self, x):
-        enc_weights, dec_weights = self.extract_weights()
+        enc_weights, _ = self.extract_weights()
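+        # Only the encoder half is needed here; decode_stream handles the decoder weights.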

         for i in range(len(enc_weights) // 2):
             x = np.matmul(x, enc_weights[2 * i]) + enc_weights[2 * i + 1]
@@ -160,6 +160,19 @@ class EndToEndAutoencoder(tf.keras.Model):

         return x

+    def decode_stream(self, x):
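+        # Mirrors encode_stream on the decoder half: each (kernel, bias) pair in
+        # dec_weights is one dense layer, applied in plain NumPy so received blocks
+        # can be decoded outside the Keras graph. Illustrative use (hypothetical names):
+        #   probs = model.decode_stream(received_block)   # softmax probabilities
+        #   symbols_hat = probs.argmax(axis=-1)           # hard symbol decisions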
+        _, dec_weights = self.extract_weights()
+
+        for i in range(len(dec_weights) // 2):
+            x = np.matmul(x, dec_weights[2 * i]) + dec_weights[2 * i + 1]
+
+            if i == len(dec_weights) // 2 - 1:
+                x = tf.keras.activations.softmax(x).numpy()
+            else:
+                x = tf.keras.activations.relu(x).numpy()
+
+        return x
+
     def cost(self, y_true, y_pred):
         symbol_cost = losses.CategoricalCrossentropy()(y_true, y_pred)

@@ -198,7 +211,7 @@ class EndToEndAutoencoder(tf.keras.Model):

         return out_arr, out_arr[:, mid_idx, :]

-    def train(self, num_of_blocks=1e6, epochs=1, batch_size=None, train_size=0.8, lr=1e-3):
+    def train(self, num_of_blocks=1e6, epochs=50, batch_size=None, train_size=0.8, lr=1e-3):
"""
|
|
|
Method to train the autoencoder. Further configuration to the loss function, optimizer etc. can be made in here.
|
|
|
|
|
|
@@ -217,6 +230,8 @@ class EndToEndAutoencoder(tf.keras.Model):
         else:
             loss_fn = losses.CategoricalCrossentropy()

+        callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
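+        # Stops training once the training loss fails to improve for 3 consecutive
+        # epochs; monitor='val_loss' is another option, since validation_data is supplied to fit().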
+
         self.compile(optimizer=opt,
                      loss=loss_fn,
                      metrics=['accuracy'],
@@ -225,13 +240,18 @@ class EndToEndAutoencoder(tf.keras.Model):
                      run_eagerly=False
                      )

-        self.fit(x=X_train,
-                 y=y_train,
-                 batch_size=batch_size,
-                 epochs=epochs,
-                 shuffle=True,
-                 validation_data=(X_test, y_test)
-                 )
+        history = self.fit(x=X_train,
+                           y=y_train,
+                           batch_size=batch_size,
+                           epochs=epochs,
+                           callbacks=[callback],
+                           shuffle=True,
+                           validation_data=(X_test, y_test)
+                           )
+
+        if len(history.history['loss']) == epochs:
+            print("The model trained for the maximum number of epochs and may not have converged. "
+                  "Consider increasing the number of epochs and retraining.")

     def test(self, num_of_blocks=1e4, length_plot=False, plt_show=True):
         X_test, y_test = self.generate_random_inputs(int(num_of_blocks))
@@ -434,6 +454,52 @@ def load_model(model_name=None):
     return ae_model, params


+if __name__ == 'asd':
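+    # This guard never matches '__main__', so the sweep below is effectively disabled:
+    # it retrains the autoencoder at each fiber length and plots the resulting BER.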
+
+    params = {"fs": 336e9,
+              "cardinality": 32,
+              "samples_per_symbol": 32,
+              "messages_per_block": 9,
+              "dispersion_factor": (-21.7 * 1e-24),
+              "fiber_length": 50,
+              "fiber_length_stddev": 1,
+              "lpf_cutoff": 32e9,
+              "rx_stddev": 0.01,
+              "sig_avg": 0.5,
+              "enob": 8,
+              "custom_loss_fn": True
+              }
+
+    lengths = np.linspace(40, 100, 7)
+    ber = []
+    for len_ in lengths:
+        optical_channel = OpticalChannel(fs=params["fs"],
+                                         num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
+                                         dispersion_factor=params["dispersion_factor"],
+                                         fiber_length=len_,
+                                         fiber_length_stddev=params["fiber_length_stddev"],
+                                         lpf_cutoff=params["lpf_cutoff"],
+                                         rx_stddev=0,
+                                         sig_avg=0,
+                                         enob=params["enob"])
+
+        ae_model = EndToEndAutoencoder(cardinality=params["cardinality"],
+                                       samples_per_symbol=params["samples_per_symbol"],
+                                       messages_per_block=params["messages_per_block"],
+                                       channel=optical_channel,
+                                       custom_loss_fn=params["custom_loss_fn"])
+        ae_model.train(num_of_blocks=1e5)
+        ae_model.test()
+        ber.append(ae_model.bit_error_rate)
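+        # bit_error_rate is populated by the preceding test() call.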
+
+    plt.plot(lengths, ber)
+    plt.title("Bit Error Rate at different trained fiber lengths")
+    plt.yscale('log')
+    plt.xlabel("Fiber Length / km")
+    plt.ylabel("Bit Error Rate")
+    plt.show()
+    pass
+
 if __name__ == '__main__':

     params = {"fs": 336e9,
@@ -452,7 +518,7 @@ if __name__ == '__main__':

     force_training = False

-    model_save_name = ""
+    model_save_name = "20210317-124015"
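+    # A previously exported model (exports/<timestamp>); it is loaded below when
+    # its params.json exists and force_training is False, otherwise training runs.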
     param_file_path = os.path.join("exports", model_save_name, "params.json")

     if os.path.isfile(param_file_path) and not force_training:
@@ -480,7 +546,23 @@ if __name__ == '__main__':
         ae_model.encoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "encoder"))
         ae_model.decoder = tf.keras.models.load_model(os.path.join("exports", model_save_name, "decoder"))
     else:
-        ae_model.train(num_of_blocks=1e5, epochs=5)
+        ae_model.train(num_of_blocks=1e4)
         ae_model.save_end_to_end()

+    ae_model.view_encoder()
+    ae_model.test()
+
+    # cat = [np.arange(32)]
+    # enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
+    #
+    # inp = np.asarray([9, 28, 15, 18, 23, 0, 29, 30, 2]).reshape(-1, 1)
+    # inp_oh = enc.fit_transform(inp)
+    #
+    # out = ae_model(inp_oh.reshape(1, 9, 32))
+    #
+    # a = out.numpy()
+    #
+    # plt.plot(a)
+    # plt.show()
+
     pass