plots.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from sklearn.preprocessing import OneHotEncoder
  2. import numpy as np
  3. from tensorflow.keras import layers
  4. from end_to_end import load_model
  5. from models.custom_layers import DigitizationLayer, OpticalChannel
  6. from matplotlib import pyplot as plt
  7. import math
  8. def plot_e2e_spectrum(model_name=None, num_samples=10000, plot_theoretical_nulls=False):
  9. '''
  10. Plot frequency spectrum of the output signal at the encoder
  11. @param model_name: The name of the model to import. If None, then the latest model will be imported.
  12. @param num_samples: The number of symbols to simulate when computing the spectrum.
  13. '''
  14. # Load pre-trained model
  15. ae_model, params = load_model(model_name=model_name)
  16. # Generate a list of random symbols (one hot encoded)
  17. cat = [np.arange(params["cardinality"])]
  18. enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
  19. rand_int = np.random.randint(params["cardinality"], size=(num_samples, 1))
  20. out = enc.fit_transform(rand_int)
  21. # Encode the list of symbols using the trained encoder
  22. a = ae_model.encode_stream(out).flatten()
  23. # Pass the output of the encoder through LPF
  24. lpf = DigitizationLayer(fs=params["fs"],
  25. num_of_samples=params["samples_per_symbol"] * num_samples,
  26. sig_avg=0)(a).numpy()
  27. # Plot the frequency spectrum of the signal
  28. freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
  29. plt.plot(freq, np.abs(np.fft.fft(lpf)), 'x')
  30. plt.ylim((0, 500))
  31. plt.xlim((0, 4e10))
  32. if plot_theoretical_nulls:
  33. f_curr = np.sqrt(np.pi / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)
  34. i = 1
  35. print(f_curr)
  36. plt.axvline(x=f_curr, color="black", linestyle="--", label="null frequency")
  37. plt.axvline(x=-f_curr, color="black", linestyle="--")
  38. plt.axvline(x=2*f_curr, color="red", linestyle="--", label="double of null frequency")
  39. plt.axvline(x=-2*f_curr, color="red", linestyle="--")
  40. f_curr = np.sqrt((3 * np.pi) / (-params["dispersion_factor"] * params["fiber_length"])) / (2 * np.pi)
  41. print(f_curr)
  42. plt.axvline(x=2 * f_curr, color="red", linestyle="--")
  43. plt.axvline(x=-2 * f_curr, color="red", linestyle="--")
  44. plt.axvline(x=f_curr, color="black", linestyle="--")
  45. plt.axvline(x=-f_curr, color="black", linestyle="--")
  46. plt.legend(loc=0)
  47. plt.title("Frequency Spectrum of transmitted signal")
  48. plt.xlabel("Frequency")
  49. plt.ylabel("Magnitude")
  50. plt.tight_layout()
  51. plt.savefig('encoder_spectrum.eps', format='eps')
  52. plt.show()
  53. def plot_e2e_encoded_output(model_name=None):
  54. '''
  55. Plots the raw outputs of the encoder neural network as well as the voltage potential that modulates the laser.
  56. The distorted DD received signal is also plotted.
  57. @param model_name: The name of the model to import. If None, then the latest model will be imported.
  58. '''
  59. # Load pre-trained model
  60. ae_model, params = load_model(model_name=model_name)
  61. # Generate a random block of messages
  62. val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
  63. # Encode and flatten the messages
  64. enc = ae_model.encoder(inp)
  65. flat_enc = layers.Flatten()(enc)
  66. chan_out = ae_model.channel.layers[1](flat_enc)
  67. # Instantiate LPF layer
  68. lpf = DigitizationLayer(fs=params["fs"],
  69. num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
  70. sig_avg=0)
  71. # Apply LPF
  72. lpf_out = lpf(flat_enc)
  73. # Time axis
  74. t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
  75. if isinstance(ae_model.channel.layers[1], OpticalChannel):
  76. t = t / params["fs"]
  77. # Plot the concatenated symbols before and after LPF
  78. plt.figure(figsize=(1.4*params["messages_per_block"], 4.5))
  79. for i in range(1, params["messages_per_block"]):
  80. plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
  81. plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
  82. plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
  83. plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
  84. plt.xlim((t.min(), t.max()))
  85. plt.title("Sample block of transmitted/received waveform")
  86. print(str(val[0, :, 0]))
  87. # plt.text(0, 0, str(val[0, :, 0]))
  88. plt.legend(loc='upper right')
  89. plt.tight_layout()
  90. plt.savefig('sample_transmission_block.eps', format='eps')
  91. plt.show()
  92. if __name__ == '__main__':
  93. ae_model, params = load_model(model_name='20210502-135450')
  94. ae_model.view_encoder()
  95. # ae_model.view_sample_block()
  96. #
  97. plot_e2e_spectrum('20210502-135450', plot_theoretical_nulls=True)
  98. # plot_e2e_encoded_output('20210502-135450')