plots.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from sklearn.preprocessing import OneHotEncoder
  2. import numpy as np
  3. from tensorflow.keras import layers
  4. from end_to_end import load_model
  5. from models.layers import DigitizationLayer, OpticalChannel
  6. from matplotlib import pyplot as plt
  7. import math
  8. # plot frequency spectrum of e2e model
  9. def plot_e2e_spectrum(model_name=None):
  10. # Load pre-trained model
  11. ae_model, params = load_model(model_name=model_name)
  12. # Generate a list of random symbols (one hot encoded)
  13. cat = [np.arange(params["cardinality"])]
  14. enc = OneHotEncoder(handle_unknown='ignore', sparse=False, categories=cat)
  15. rand_int = np.random.randint(params["cardinality"], size=(10000, 1))
  16. out = enc.fit_transform(rand_int)
  17. # Encode the list of symbols using the trained encoder
  18. a = ae_model.encode_stream(out).flatten()
  19. # Pass the output of the encoder through LPF
  20. lpf = DigitizationLayer(fs=params["fs"],
  21. num_of_samples=320000,
  22. sig_avg=0)(a).numpy()
  23. # Plot the frequency spectrum of the signal
  24. freq = np.fft.fftfreq(lpf.shape[-1], d=1 / params["fs"])
  25. mul = np.exp(0.5j * params["dispersion_factor"] * params["fiber_length"] * np.power(2 * math.pi * freq, 2))
  26. a = np.fft.ifft(mul)
  27. a2 = np.power(a, 2)
  28. b = np.abs(np.fft.fft(a2))
  29. plt.plot(freq, np.fft.fft(lpf), 'x')
  30. plt.ylim((-500, 500))
  31. plt.xlim((-5e10, 5e10))
  32. plt.show()
  33. # plt.plot(freq, np.fft.fft(lpf), 'x')
  34. plt.plot(freq, b)
  35. plt.ylim((-500, 500))
  36. plt.xlim((-5e10, 5e10))
  37. plt.show()
  38. def plot_e2e_encoded_output(model_name=None):
  39. # Load pre-trained model
  40. ae_model, params = load_model(model_name=model_name)
  41. # Generate a random block of messages
  42. val, inp, _ = ae_model.generate_random_inputs(num_of_blocks=1, return_vals=True)
  43. # Encode and flatten the messages
  44. enc = ae_model.encoder(inp)
  45. flat_enc = layers.Flatten()(enc)
  46. chan_out = ae_model.channel.layers[1](flat_enc)
  47. # Instantiate LPF layer
  48. lpf = DigitizationLayer(fs=params["fs"],
  49. num_of_samples=params["messages_per_block"] * params["samples_per_symbol"],
  50. sig_avg=0)
  51. # Apply LPF
  52. lpf_out = lpf(flat_enc)
  53. # Time axis
  54. t = np.arange(params["messages_per_block"] * params["samples_per_symbol"])
  55. if isinstance(ae_model.channel.layers[1], OpticalChannel):
  56. t = t / params["fs"]
  57. # Plot the concatenated symbols before and after LPF
  58. plt.figure(figsize=(2 * params["messages_per_block"], 6))
  59. for i in range(1, params["messages_per_block"]):
  60. plt.axvline(x=t[i * params["samples_per_symbol"]], color='black')
  61. plt.axhline(y=0, color='black')
  62. plt.plot(t, flat_enc.numpy().T, 'x', label='output of encNN')
  63. plt.plot(t, lpf_out.numpy().T, label='optical field at tx')
  64. plt.plot(t, chan_out.numpy().flatten(), label='optical field at rx')
  65. plt.ylim((-0.1, 0.1))
  66. plt.xlim((t.min(), t.max()))
  67. plt.title(str(val[0, :, 0]))
  68. plt.legend(loc='upper right')
  69. plt.show()
  70. if __name__ == '__main__':
  71. # plot_e2e_spectrum()
  72. plot_e2e_encoded_output()