new_model.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import tensorflow as tf
  2. from tensorflow.keras import losses
  3. from layers import OpticalChannel, BitsToSymbols, SymbolsToBits
  4. from end_to_end import EndToEndAutoencoder
  5. import numpy as np
  6. import math
  7. from matplotlib import pyplot as plt
  8. class BitMappingModel(tf.keras.Model):
  9. def __init__(self,
  10. cardinality,
  11. samples_per_symbol,
  12. messages_per_block,
  13. channel):
  14. super(BitMappingModel, self).__init__()
  15. # Labelled M in paper
  16. self.cardinality = cardinality
  17. self.bits_per_symbol = int(math.log(self.cardinality, 2))
  18. # Labelled n in paper
  19. self.samples_per_symbol = samples_per_symbol
  20. # Labelled N in paper
  21. if messages_per_block % 2 == 0:
  22. messages_per_block += 1
  23. self.messages_per_block = messages_per_block
  24. self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
  25. samples_per_symbol=self.samples_per_symbol,
  26. messages_per_block=self.messages_per_block,
  27. channel=channel,
  28. bit_mapping=False)
  29. self.bit_error_rate = []
  30. self.symbol_error_rate = []
  31. def call(self, inputs, training=None, mask=None):
  32. x1 = BitsToSymbols(self.cardinality)(inputs)
  33. x2 = self.e2e_model(x1)
  34. out = SymbolsToBits(self.cardinality)(x2)
  35. return out
  36. def generate_random_inputs(self, num_of_blocks, return_vals=False):
  37. """
  38. """
  39. mid_idx = int((self.messages_per_block - 1) / 2)
  40. rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
  41. out = rand_int
  42. out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
  43. if return_vals:
  44. return out_arr, out_arr, out_arr[:, mid_idx, :]
  45. return out_arr, out_arr[:, mid_idx, :]
  46. def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  47. X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
  48. X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
  49. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  50. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  51. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  52. self.compile(optimizer=opt,
  53. loss=losses.BinaryCrossentropy(),
  54. metrics=['accuracy'],
  55. loss_weights=None,
  56. weighted_metrics=None,
  57. run_eagerly=False
  58. )
  59. self.fit(x=X_train,
  60. y=y_train,
  61. batch_size=batch_size,
  62. epochs=epochs,
  63. shuffle=True,
  64. validation_data=(X_test, y_test)
  65. )
  66. def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  67. for _ in range(iters):
  68. self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
  69. self.e2e_model.test()
  70. self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
  71. self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  72. X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
  73. X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
  74. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  75. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  76. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  77. self.compile(optimizer=opt,
  78. loss=losses.BinaryCrossentropy(),
  79. metrics=['accuracy'],
  80. loss_weights=None,
  81. weighted_metrics=None,
  82. run_eagerly=False
  83. )
  84. self.fit(x=X_train,
  85. y=y_train,
  86. batch_size=batch_size,
  87. epochs=epochs,
  88. shuffle=True,
  89. validation_data=(X_test, y_test)
  90. )
  91. self.e2e_model.test()
  92. self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
  93. self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  94. SAMPLING_FREQUENCY = 336e9
  95. CARDINALITY = 32
  96. SAMPLES_PER_SYMBOL = 32
  97. MESSAGES_PER_BLOCK = 9
  98. DISPERSION_FACTOR = -21.7 * 1e-24
  99. FIBER_LENGTH = 50
  100. if __name__ == '__main__':
  101. distances = [0, 10, 20, 30, 40, 50, 60]
  102. ser = []
  103. ber = []
  104. baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
  105. bit_rate = math.log(CARDINALITY, 2) * baud_rate
  106. snr = None
  107. for d in distances:
  108. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  109. num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
  110. dispersion_factor=DISPERSION_FACTOR,
  111. fiber_length=d)
  112. model = BitMappingModel(cardinality=CARDINALITY,
  113. samples_per_symbol=SAMPLES_PER_SYMBOL,
  114. messages_per_block=MESSAGES_PER_BLOCK,
  115. channel=optical_channel)
  116. if snr is None:
  117. snr = model.e2e_model.snr
  118. elif snr != model.e2e_model.snr:
  119. print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
  120. # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
  121. model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
  122. ber.append(model.bit_error_rate[-1])
  123. ser.append(model.symbol_error_rate[-1])
  124. # plt.plot(model.bit_error_rate, label='BER')
  125. # plt.plot(model.symbol_error_rate, label='SER')
  126. # plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
  127. # plt.legend()
  128. # plt.show()
  129. # model.summary()
  130. # plt.plot(ber, label='BER')
  131. # plt.plot(ser, label='SER')
  132. # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
  133. # plt.legend(ber)