2 Revize b8ff3d2e85 ... 2d60347040

Autor SHA1 Zpráva Datum
  Tharmetharan Balendran 2d60347040 resolved conflict před 4 roky
  Tharmetharan Balendran ab42d5f6a3 Refactoring/Clean-up před 4 roky
2 změnil soubory, kde provedl 25 přidání a 25 odebrání
  1. 1 1
      models/end_to_end.py
  2. 24 24
      models/plots.py

+ 1 - 1
models/end_to_end.py

@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import accuracy_score
 from sklearn.metrics import accuracy_score
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.preprocessing import OneHotEncoder
 from tensorflow.keras import layers, losses
 from tensorflow.keras import layers, losses
-from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, BitsToSymbols, SymbolsToBits
+from models.custom_layers import ExtractCentralMessage, OpticalChannel, DigitizationLayer, SymbolsToBits
 
 
 
 
 class EndToEndAutoencoder(tf.keras.Model):
 class EndToEndAutoencoder(tf.keras.Model):

+ 24 - 24
models/plots.py

@@ -7,44 +7,46 @@ from models.custom_layers import DigitizationLayer, OpticalChannel
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 import math
 import math
 
 
-# plot frequency spectrum of e2e model
-def plot_e2e_spectrum(model_name=None):
+
+def plot_e2e_spectrum(model_name=None, num_samples=10000):
+    '''
+    Plot frequency spectrum of the output signal at the encoder
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    @param num_samples: The number of symbols to simulate when computing the spectrum.
+    '''
     # Load pre-trained model
     # Load pre-trained model
     ae_model, params = load_model(model_name=model_name)
     ae_model, params = load_model(model_name=model_name)
 
 
     # Generate a list of random symbols (one hot encoded)
     # Generate a list of random symbols (one hot encoded)
     cat = [np.arange(params["cardinality"])]
     cat = [np.arange(params["cardinality"])]
     enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
     enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
-    rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
+    rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
     out = enc.fit_transform(rand_int)
     out = enc.fit_transform(rand_int)
 
 
     # Encode the list of symbols using the trained encoder
     # Encode the list of symbols using the trained encoder
-    enc = ae_model.encode_stream(out).flatten()
+    a = ae_model.encode_stream(out).flatten()
 
 
     # Pass the output of the encoder through LPF
     # Pass the output of the encoder through LPF
     lpf = DigitizationLayer(fs=params["fs"],
     lpf = DigitizationLayer(fs=params["fs"],
-                            num_of_samples=320000,
-                            sig_avg=0)(enc).numpy()
+                            num_of_samples=params["cardinality"] * num_samples,
+                            sig_avg=0)(a).numpy()
 
 
     # Plot the frequency spectrum of the signal
     # Plot the frequency spectrum of the signal
     freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
     freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
-    mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
-
-    a = np.fft.ifft(mul)
-    a2 = np.abs(np.power(a, 2))
-    b = np.fft.fft(a2)
 
 
-    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
-    plt.title("Spectrum of Modulating Potential at Encoder")
-    plt.ylim((0, 500))
+    plt.plot(freq, np.fft.fft(lpf), 'x')
+    plt.ylim((-500, 500))
     plt.xlim((-5e10, 5e10))
     plt.xlim((-5e10, 5e10))
-    plt.xlabel("Freuquency / Hz")
-    plt.ylabel("Magnitude / au")
-    # plt.savefig('nn_encoder_spectrum.eps', format='eps')
-    # plt.savefig('nn_encoder_spectrum.png', format='png')
     plt.show()
     plt.show()
 
 
+
 def plot_e2e_encoded_output(model_name=None):
 def plot_e2e_encoded_output(model_name=None):
+    '''
+    Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
+    The distorted DD received signal is also plotted.
+
+    @param model_name: The name of the model to import. If None, then the latest model will be imported.
+    '''
     # Load pre-trained model
     # Load pre-trained model
     ae_model, params = load_model(model_name=model_name)
     ae_model, params = load_model(model_name=model_name)
 
 
@@ -74,17 +76,15 @@ def plot_e2e_encoded_output(model_name=None):
 
 
     for i in range(1, params["messages_per_block"]):
     for i in range(1, params["messages_per_block"]):
         plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
         plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
-
     plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
     plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
-    plt.plot(t, lpf_out.numpy().T, label='Modulating potential')
-    plt.plot(t, chan_out.numpy().flatten(), label='DD received signal')
+    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
+    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
+    plt.xlim((t.min(), t.max()))
     plt.title(str(val[0, :, 0]))
     plt.title(str(val[0, :, 0]))
     plt.legend(loc='upper right')
     plt.legend(loc='upper right')
-    plt.xlabel("Time / s")
-    plt.ylabel("Amplitude / V")
     plt.show()
     plt.show()
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     plot_e2e_spectrum()
     plot_e2e_spectrum()
-    # plot_e2e_encoded_output()
+    plot_e2e_encoded_output()