from sklearn.preprocessing import OneHotEncoder
import numpy as np
from tensorflow.keras import layers
from end_to_end import load_model
from models.custom_layers import DigitizationLayer, OpticalChannel
from matplotlib import pyplot as plt
import math


def _null_frequency(k, params):
    '''
    Return the k-th theoretical null frequency marked by plot_e2e_spectrum.

    Evaluates sqrt((2k - 1) * pi / (-dispersion_factor * fiber_length)) / (2 * pi),
    so k=1 uses pi and k=2 uses 3*pi.

    @param k: 1-based index of the null.
    @param params: Model parameter dict; reads "dispersion_factor" and "fiber_length".
    '''
    # NOTE(review): the negation suggests dispersion_factor is expected to be
    # negative so the square-root argument is positive; otherwise this yields nan.
    phase = (2 * k - 1) * np.pi
    return np.sqrt(phase / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)


def plot_e2e_spectrum(model_name=None, num_samples=10000, plot_theoretical_nulls=False,
                      xlim=(0, 4e10), ylim=(0, 500), save_path='encoder_spectrum.eps'):
    '''
    Plot frequency spectrum of the output signal at the encoder

    @param model_name: The name of the model to import.
                       If None, then the latest model will be imported.
    @param num_samples: The number of symbols to simulate when computing the spectrum.
    @param plot_theoretical_nulls: If True, draw dashed vertical lines at +/- the first two
                                   theoretical null frequencies (black) and at +/- twice
                                   those frequencies (red), and add a legend.
    @param xlim: Frequency-axis limits of the plot.
    @param ylim: Magnitude-axis limits of the plot.
    @param save_path: File the figure is written to (EPS format).
    '''
    # Load pre-trained model
    ae_model, params = load_model(model_name=model_name)

    # Generate a list of random symbols (one hot encoded).
    # The encoder's default sparse output is densified with .toarray(): the old
    # `sparse=False` keyword was renamed to `sparse_output` in scikit-learn 1.2
    # and removed in 1.4, so passing it breaks on current versions.
    cat = [np.arange(params["cardinality"])]
    enc = OneHotEncoder(handle_unknown='ignore', categories=cat)
    rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
    out = enc.fit_transform(rand_int).toarray()

    # Encode the list of symbols using the trained encoder
    a = ae_model.encode_stream(out).flatten()

    # Pass the output of the encoder through LPF
    lpf = DigitizationLayer(fs=params["fs"],
                            num_of_samples=params["samples_per_symbol"] * num_samples,
                            sig_avg=0)(a).numpy()

    # Plot the magnitude of the FFT against the matching frequency bins
    # (fftfreq with spacing 1/fs returns frequencies in the units of fs).
    freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
    plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
    plt.ylim(ylim)
    plt.xlim(xlim)

    if plot_theoretical_nulls:
        # First null frequency and its double (these carry the legend labels)
        f_curr = _null_frequency(1, params)
        print(f_curr)
        plt.axvline(x=f_curr, color="black", linestyle="--", label="null frequency")
        plt.axvline(x=-f_curr, color="black", linestyle="--")
        plt.axvline(x=2 * f_curr, color="red", linestyle="--", label="double of null frequency")
        plt.axvline(x=-2 * f_curr, color="red", linestyle="--")

        # Second null frequency and its double (same styling, unlabelled)
        f_curr = _null_frequency(2, params)
        print(f_curr)
        plt.axvline(x=2 * f_curr, color="red", linestyle="--")
        plt.axvline(x=-2 * f_curr, color="red", linestyle="--")
        plt.axvline(x=f_curr, color="black", linestyle="--")
        plt.axvline(x=-f_curr, color="black", linestyle="--")

        # Only the null markers have labels; calling legend() without them
        # makes matplotlib warn that no labelled artists were found.
        plt.legend(loc=0)

    plt.title("Frequency Spectrum of transmitted signal")
    plt.xlabel("Frequency")
    plt.ylabel("Magnitude")
    plt.tight_layout()
    plt.savefig(save_path, format='eps')
    plt.show()


def plot_e2e_encoded_output(model_name=None, save_path='sample_transmission_block.eps'):
    '''
    Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
    The distorted DD received signal is also plotted.

    @param model_name: The name of the model to import.
                       If None, then the latest model will be imported.
    @param save_path: File the figure is written to (EPS format).
    '''
    # Load pre-trained model
    ae_model, params = load_model(model_name=model_name)

    # Generate a random block of messages
    val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)

    # Encode and flatten the messages
    enc = ae_model.encoder(inp)
    flat_enc = layers.Flatten()(enc)
    # NOTE(review): assumes channel.layers[1] is the channel model that produces
    # the received waveform — confirm against the channel definition.
    chan_out = ae_model.channel.layers[1](flat_enc)

    samples_per_symbol = params["samples_per_symbol"]
    num_block_samples = params["messages_per_block"] * samples_per_symbol

    # Instantiate LPF layer
    lpf = DigitizationLayer(fs=params["fs"],
                            num_of_samples=num_block_samples,
                            sig_avg=0)

    # Apply LPF
    lpf_out = lpf(flat_enc)

    # Time axis: sample indices, rescaled by 1/fs when the channel is optical
    t = np.arange(num_block_samples)
    if isinstance(ae_model.channel.layers[1], OpticalChannel):
        t = t / params["fs"]

    # Plot the concatenated symbols before and after LPF
    plt.figure(figsize=(1.4 * params["messages_per_block"], 4.5))

    # Vertical separators at every symbol boundary except t=0
    for boundary in t[samples_per_symbol::samples_per_symbol]:
        plt.axvline(x=boundary, color='black')

    plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
    plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
    plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
    plt.xlim((t.min(), t.max()))
    plt.title("Sample block of transmitted/received waveform")

    # Print the message values of the plotted block for reference
    print(str(val[0, :, 0]))

    plt.legend(loc='upper right')
    plt.tight_layout()
    plt.savefig(save_path, format='eps')
    plt.show()


if __name__ == '__main__':
    ae_model, params = load_model(model_name='20210502-135450')
    ae_model.view_encoder()
    # Other diagnostics available for the same model:
    # ae_model.view_sample_block()
    # plot_e2e_spectrum('20210502-135450', plot_theoretical_nulls=True)
    # plot_e2e_encoded_output('20210502-135450')