|
@@ -7,15 +7,20 @@ from models.custom_layers import DigitizationLayer, OpticalChannel
|
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib import pyplot as plt
|
|
|
import math
|
|
import math
|
|
|
|
|
|
|
|
-# plot frequency spectrum of e2e model
|
|
|
|
|
-def plot_e2e_spectrum(model_name=None):
|
|
|
|
|
|
|
+
|
|
|
|
|
+def plot_e2e_spectrum(model_name=None, num_samples=10000):
|
|
|
|
|
+ '''
|
|
|
|
|
+ Plot frequency spectrum of the output signal at the encoder
|
|
|
|
|
+ @param model_name: The name of the model to import. If None, then the latest model will be imported.
|
|
|
|
|
+ @param num_samples: The number of symbols to simulate when computing the spectrum.
|
|
|
|
|
+ '''
|
|
|
# Load pre-trained model
|
|
# Load pre-trained model
|
|
|
ae_model, params = load_model(model_name=model_name)
|
|
ae_model, params = load_model(model_name=model_name)
|
|
|
|
|
|
|
|
# Generate a list of random symbols (one hot encoded)
|
|
# Generate a list of random symbols (one hot encoded)
|
|
|
cat = [np.arange(params["cardinality"])]
|
|
cat = [np.arange(params["cardinality"])]
|
|
|
enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
|
|
enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
|
|
|
- rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
|
|
|
|
|
|
|
+ rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
|
|
|
out = enc.fit_transform(rand_int)
|
|
out = enc.fit_transform(rand_int)
|
|
|
|
|
|
|
|
# Encode the list of symbols using the trained encoder
|
|
# Encode the list of symbols using the trained encoder
|
|
@@ -23,30 +28,25 @@ def plot_e2e_spectrum(model_name=None):
|
|
|
|
|
|
|
|
# Pass the output of the encoder through LPF
|
|
# Pass the output of the encoder through LPF
|
|
|
lpf = DigitizationLayer(fs=params["fs"],
|
|
lpf = DigitizationLayer(fs=params["fs"],
|
|
|
- num_of_samples=320000,
|
|
|
|
|
|
|
+ num_of_samples=params["cardinality"] * num_samples,
|
|
|
sig_avg=0)(a).numpy()
|
|
sig_avg=0)(a).numpy()
|
|
|
|
|
|
|
|
# Plot the frequency spectrum of the signal
|
|
# Plot the frequency spectrum of the signal
|
|
|
freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
|
|
freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
|
|
|
- mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
|
|
|
|
|
-
|
|
|
|
|
- a = np.fft.ifft(mul)
|
|
|
|
|
- a2 = np.power(a, 2)
|
|
|
|
|
- b = np.abs(np.fft.fft(a2))
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
plt.plot(freq, np.fft.fft(lpf), 'x')
|
|
plt.plot(freq, np.fft.fft(lpf), 'x')
|
|
|
plt.ylim((-500, 500))
|
|
plt.ylim((-500, 500))
|
|
|
plt.xlim((-5e10, 5e10))
|
|
plt.xlim((-5e10, 5e10))
|
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
|
|
- # plt.plot(freq, np.fft.fft(lpf), 'x')
|
|
|
|
|
- plt.plot(freq, b)
|
|
|
|
|
- plt.ylim((-500, 500))
|
|
|
|
|
- plt.xlim((-5e10, 5e10))
|
|
|
|
|
- plt.show()
|
|
|
|
|
|
|
|
|
|
def plot_e2e_encoded_output(model_name=None):
|
|
def plot_e2e_encoded_output(model_name=None):
|
|
|
|
|
+ '''
|
|
|
|
|
+ Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
|
|
|
|
|
+ The distorted DD received signal is also plotted.
|
|
|
|
|
+
|
|
|
|
|
+ @param model_name: The name of the model to import. If None, then the latest model will be imported.
|
|
|
|
|
+ '''
|
|
|
# Load pre-trained model
|
|
# Load pre-trained model
|
|
|
ae_model, params = load_model(model_name=model_name)
|
|
ae_model, params = load_model(model_name=model_name)
|
|
|
|
|
|
|
@@ -76,16 +76,15 @@ def plot_e2e_encoded_output(model_name=None):
|
|
|
|
|
|
|
|
for i in range(1, params["messages_per_block"]):
|
|
for i in range(1, params["messages_per_block"]):
|
|
|
plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
|
|
plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
|
|
|
- plt.axhline(y=0, color='black')
|
|
|
|
|
plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
|
|
plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
|
|
|
plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
|
|
plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
|
|
|
plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
|
|
plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
|
|
|
- plt.ylim((-0.1, 0.1))
|
|
|
|
|
plt.xlim((t.min(), t.max()))
|
|
plt.xlim((t.min(), t.max()))
|
|
|
plt.title(str(val[0, :, 0]))
|
|
plt.title(str(val[0, :, 0]))
|
|
|
plt.legend(loc='upper right')
|
|
plt.legend(loc='upper right')
|
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
|
- # plot_e2e_spectrum()
|
|
|
|
|
- plot_e2e_encoded_output()
|
|
|
|
|
|
|
+ plot_e2e_spectrum()
|
|
|
|
|
+ plot_e2e_encoded_output()
|