| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- from sklearn.preprocessing import OneHotEncoder
- import numpy as np
- from tensorflow.keras import layers
- from end_to_end import load_model
- from models.custom_layers import DigitizationLayer, OpticalChannel
- from matplotlib import pyplot as plt
- import math
def _dispersion_null_frequencies(dispersion_factor, fiber_length, cutoff):
    '''
    Return the positive theoretical null frequencies that lie strictly below ``cutoff``.

    The k-th null (k = 0, 1, 2, ...) is taken at
        f_k = sqrt((2k + 1) * pi / (-dispersion_factor * fiber_length)) / (2 * pi).
    @param dispersion_factor: Value of params["dispersion_factor"] of the loaded model.
    @param fiber_length: Value of params["fiber_length"] of the loaded model.
    @param cutoff: Exclusive upper frequency bound (the LPF cutoff).
    @return: List of null frequencies in increasing order; empty if even the first is not below cutoff.
    '''
    denom = -dispersion_factor * fiber_length
    nulls = []
    k = 0
    while True:
        f_k = np.sqrt((2 * k + 1) * np.pi / denom) / (2 * np.pi)
        # `not f_k < cutoff` (instead of `f_k >= cutoff`) also stops on NaN, which np.sqrt
        # returns when the ratio above is negative; otherwise the loop could never end.
        if not f_k < cutoff:
            return nulls
        nulls.append(f_k)
        k += 1


def plot_e2e_spectrum(model_name=None, num_samples=10000, plot_theoretical_nulls=False):
    '''
    Plot the frequency spectrum of the encoder output after it has passed through the LPF.
    @param model_name: The name of the model to import. If None, then the latest model will be imported.
    @param num_samples: The number of symbols to simulate when computing the spectrum.
    @param plot_theoretical_nulls: If True, overlay the theoretical null frequencies (black, dashed)
        and twice those frequencies (red, dashed), drawing only lines below params["lpf_cutoff"].
    '''
    # Load pre-trained model
    ae_model, params = load_model(model_name=model_name)
    cardinality = params["cardinality"]
    fs = params["fs"]

    # Generate random symbols, one-hot encoded. Row-indexing an identity matrix gives the same
    # float64 (num_samples, cardinality) array as OneHotEncoder(sparse=False) fitted on
    # arange(cardinality), without the `sparse` keyword that scikit-learn renamed to
    # `sparse_output` (1.2) and then removed (1.4).
    rand_int = np.random.randint(cardinality, size=(num_samples, 1))
    out = np.eye(cardinality)[rand_int[:, 0]]

    # Encode the list of symbols using the trained encoder
    a = ae_model.encode_stream(out).flatten()

    # Pass the output of the encoder through the LPF. The layer is sized to the actual length of
    # the encoded stream (plot_e2e_encoded_output likewise sizes it to its input length);
    # cardinality * num_samples only matches that when cardinality equals samples per symbol.
    lpf = DigitizationLayer(fs=fs, num_of_samples=a.shape[-1], sig_avg=0)(a).numpy()

    # Plot the frequency spectrum of the signal. Only the real part of the FFT is drawn, which is
    # what matplotlib did implicitly (with a ComplexWarning) when handed the complex array.
    # NOTE(review): np.abs(...) may be what is wanted for a magnitude spectrum — confirm.
    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / fs)
    plt.plot(freq, np.fft.fft(lpf).real, 'x')
    plt.ylim((-500, 500))
    plt.xlim((-5e10, 5e10))

    if plot_theoretical_nulls:
        cutoff = params["lpf_cutoff"]
        nulls = _dispersion_null_frequencies(params["dispersion_factor"], params["fiber_length"], cutoff)
        for k, f_null in enumerate(nulls):
            # Label only the first line of each colour so each appears once in the legend
            # (matplotlib leaves label=None lines out of the legend).
            null_label = "null frequency" if k == 0 else None
            double_label = "double of null frequency" if k == 0 else None
            plt.axvline(x=f_null, color="black", linestyle="--", label=null_label)
            plt.axvline(x=-f_null, color="black", linestyle="--")
            if 2 * f_null < cutoff:
                plt.axvline(x=2 * f_null, color="red", linestyle="--", label=double_label)
                plt.axvline(x=-2 * f_null, color="red", linestyle="--")
        # Avoid matplotlib's "no artists with labels" warning when no null is below the cutoff
        if nulls:
            plt.legend(loc=0)
    plt.show()
def plot_e2e_encoded_output(model_name=None):
    '''
    Show one random block of messages as it passes through the end-to-end system.

    Three traces share one figure:
    - the raw (flattened) encoder network outputs;
    - those outputs after the DigitizationLayer LPF;
    - the output of the channel's second layer.
    Black vertical lines mark the boundaries between consecutive symbols, and the title shows
    val[0, :, 0] from the generated block.
    @param model_name: The name of the model to import. If None, then the latest model will be imported.
    '''
    ae_model, params = load_model(model_name=model_name)
    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)

    # Encode, flatten, and push the block through the channel's second layer
    encoded = layers.Flatten()(ae_model.encoder(inp))
    channel_layer = ae_model.channel.layers[1]
    received = channel_layer(encoded)

    n_msgs, sps = params["messages_per_block"], params["samples_per_symbol"]
    n_total = n_msgs * sps

    # Low-pass filter the flattened encoder output; the layer is sized to the whole block
    filtered = DigitizationLayer(fs=params["fs"], num_of_samples=n_total, sig_avg=0)(encoded)

    # Sample-index axis, divided by fs (i.e. expressed in time) when the channel is optical
    t = np.arange(n_total)
    if isinstance(channel_layer, OpticalChannel):
        t = t / params["fs"]

    plt.figure(figsize=(2 * n_msgs, 6))
    for k in range(1, n_msgs):
        plt.axvline(x=t[k * sps], color='black')
    plt.plot(t, encoded.numpy().T, 'x', label='output of encNN')
    plt.plot(t, filtered.numpy().T, label='optical field at tx')
    plt.plot(t, received.numpy().flatten(), label='optical field at rx')
    plt.xlim((t.min(), t.max()))
    plt.title(str(val[0, :, 0]))
    plt.legend(loc='upper right')
    plt.show()
if __name__ == '__main__':
    # Spectrum of a specific saved model, with the theoretical null frequencies overlaid
    plot_e2e_spectrum(model_name='20210317-124015', plot_theoretical_nulls=True)
    # Alternatively, inspect the per-symbol waveforms of the latest model:
    # plot_e2e_encoded_output()
|