소스 검색

plot formatting

Tharmetharan Balendran 4 년 전
부모
커밋
690737b0ad
2개의 변경된 파일45개의 추가작업 그리고 20개의 파일을 삭제
  1. 15 4
      models/end_to_end.py
  2. 30 16
      models/plots.py

+ 15 - 4
models/end_to_end.py

@@ -341,8 +341,11 @@ class EndToEndAutoencoder(tf.keras.Model):
         num_x = int(2 ** i)
         num_y = int(self.cardinality / num_x)
 
+        # num_x = 6
+        # num_y = 11
+
         # Plot all symbols
-        fig, axs = plt.subplots(num_y, num_x, figsize=(2.5 * num_x, 2 * num_y))
+        fig, axs = plt.subplots(num_y, num_x, figsize=(1.875 * num_x, 1.5 * num_y))
 
         t = np.arange(self.samples_per_symbol)
         if isinstance(self.channel.layers[1], OpticalChannel):
@@ -351,9 +354,14 @@ class EndToEndAutoencoder(tf.keras.Model):
         sym_idx = 0
         for y in range(num_y):
             for x in range(num_x):
-                axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'x')
-                axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
-                sym_idx += 1
+                try:
+                    axs[y, x].plot(t, enc_messages[sym_idx].numpy().flatten(), 'k.')
+                    # axs[y, x].vlines(t, 0, enc_messages[sym_idx].numpy().flatten(), linestyle="dashed", color='k', linewidth=0.1)
+                    # axs[y, x].set_title('Symbol {}'.format(str(sym_idx)))
+                    sym_idx += 1
+                except tf.python.framework.errors_impl.InvalidArgumentError:
+                    break
+
 
         for ax in axs.flat:
             ax.set(xlabel='Time', ylabel='Amplitude', ylim=(0, 1))
@@ -361,6 +369,9 @@ class EndToEndAutoencoder(tf.keras.Model):
         for ax in axs.flat:
             ax.label_outer()
 
+        fig.suptitle("List of encoded symbols [{}-{}]".format(0, self.cardinality-1))
+        plt.tight_layout()
+        plt.savefig('list_encoded_symbols.eps', format='eps')
         plt.show()
         pass
 

+ 30 - 16
models/plots.py

@@ -28,38 +28,43 @@ def plot_e2e_spectrum(model_name=None, num_samples=10000, plot_theoretical_nulls
 
     # Pass the output of the encoder through LPF
     lpf = DigitizationLayer(fs=params["fs"],
-                            num_of_samples=params["cardinality"] * num_samples,
+                            num_of_samples=params["samples_per_symbol"] * num_samples,
                             sig_avg=0)(a).numpy()
 
     # Plot the frequency spectrum of the signal
     freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
 
-    plt.plot(freq, np.fft.fft(lpf), 'x')
-    plt.ylim((-500, 500))
-    plt.xlim((-5e10, 5e10))
+    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
+    plt.ylim((0, 500))
+    plt.xlim((0, 4e10))
 
     if plot_theoretical_nulls:
         f_curr = np.sqrt(np.pi / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)
         i = 1
 
+        print(f_curr)
+
         plt.axvline(x=f_curr, color="black", linestyle="--", label="null frequency")
         plt.axvline(x=-f_curr, color="black", linestyle="--")
         plt.axvline(x=2*f_curr, color="red", linestyle="--", label="double of null frequency")
         plt.axvline(x=-2*f_curr, color="red", linestyle="--")
 
-        while f_curr < params["lpf_cutoff"]:
-            f_curr = np.sqrt(((2 * i + 1) * np.pi) / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)
+        f_curr = np.sqrt((3 * np.pi) / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)
 
-            if 2 * f_curr < params["lpf_cutoff"]:
-                plt.axvline(x=2 * f_curr, color="red", linestyle="--")
-                plt.axvline(x=-2 * f_curr, color="red", linestyle="--")
+        print(f_curr)
+        plt.axvline(x=2 * f_curr, color="red", linestyle="--")
+        plt.axvline(x=-2 * f_curr, color="red", linestyle="--")
 
-            plt.axvline(x=f_curr, color="black", linestyle="--")
-            plt.axvline(x=-f_curr, color="black", linestyle="--")
-            i += 1
+        plt.axvline(x=f_curr, color="black", linestyle="--")
+        plt.axvline(x=-f_curr, color="black", linestyle="--")
 
         plt.legend(loc=0)
 
+    plt.title("Frequency Spectrum of transmitted signal")
+    plt.xlabel("Frequency")
+    plt.ylabel("Magnitude")
+    plt.tight_layout()
+    plt.savefig('encoder_spectrum.eps', format='eps')
     plt.show()
 
 def plot_e2e_encoded_output(model_name=None):
@@ -94,7 +99,7 @@ def plot_e2e_encoded_output(model_name=None):
         t = t / params["fs"]
 
     # Plot the concatenated symbols before and after LPF
-    plt.figure(figsize=(2 * params["messages_per_block"], 6))
+    plt.figure(figsize=(1.4*params["messages_per_block"], 4.5))
 
     for i in range(1, params["messages_per_block"]):
         plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
@@ -102,11 +107,20 @@ def plot_e2e_encoded_output(model_name=None):
     plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
     plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
     plt.xlim((t.min(), t.max()))
-    plt.title(str(val[0, :, 0]))
+    plt.title("Sample block of transmitted/received waveform")
+    print(str(val[0, :, 0]))
+    # plt.text(0, 0, str(val[0, :, 0]))
     plt.legend(loc='upper right')
+    plt.tight_layout()
+    plt.savefig('sample_transmission_block.eps', format='eps')
     plt.show()
 
 
 if __name__ == '__main__':
-    plot_e2e_spectrum('20210317-124015', plot_theoretical_nulls=True)
-    # plot_e2e_encoded_output()
+    ae_model, params = load_model(model_name='20210502-135450')
+
+    ae_model.view_encoder()
+    # ae_model.view_sample_block()
+    #
+    plot_e2e_spectrum('20210502-135450', plot_theoretical_nulls=True)
+    # plot_e2e_encoded_output('20210502-135450')