new_model.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import tensorflow as tf
  2. from tensorflow.keras import layers, losses
  3. from models.custom_layers import ExtractCentralMessage, OpticalChannel
  4. from models.end_to_end import EndToEndAutoencoder
  5. from models.custom_layers import BitsToSymbols, SymbolsToBits
  6. import numpy as np
  7. import math
  8. from matplotlib import pyplot as plt
  class BitMappingModel(tf.keras.Model):
      """Bit-level wrapper around an ``EndToEndAutoencoder``.

      ``call`` passes its inputs through ``BitsToSymbols``, the wrapped
      ``EndToEndAutoencoder`` (built with ``bit_mapping=False``) and finally
      ``SymbolsToBits``, and the model is trained against bit labels with a
      binary cross-entropy loss.
      """

      def __init__(self,
                   cardinality,
                   samples_per_symbol,
                   messages_per_block,
                   channel):
          """Build the bit-mapping model.

          Args:
              cardinality: Number of distinct symbols (labelled M in the paper).
                  ``bits_per_symbol`` is derived as ``int(log2(cardinality))``,
                  so a power of two is presumably expected.
              samples_per_symbol: Samples per symbol (labelled n in the paper).
              messages_per_block: Messages per block (labelled N in the paper).
                  Even values are bumped up by one so that every block has a
                  single central message.
              channel: Channel object forwarded unchanged to
                  ``EndToEndAutoencoder``.
          """
          super().__init__()
          # Labelled M in paper
          self.cardinality = cardinality
          # math.log2 is more accurate than math.log(x, 2) (which divides two
          # rounded logarithms) and is exact for powers of two, so int() cannot
          # truncate a value that landed just below an integer.
          self.bits_per_symbol = int(math.log2(self.cardinality))
          # Labelled n in paper
          self.samples_per_symbol = samples_per_symbol
          # Labelled N in paper; forced odd so generate_random_inputs has a
          # well-defined central message to use as the label.
          if messages_per_block % 2 == 0:
              messages_per_block += 1
          self.messages_per_block = messages_per_block
          self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
                                               samples_per_symbol=self.samples_per_symbol,
                                               messages_per_block=self.messages_per_block,
                                               channel=channel,
                                               bit_mapping=False)
          # Create the conversion layers once so Keras tracks them as
          # sub-layers, instead of instantiating new layer objects (and any
          # weights they own) on every call / graph trace.
          self.bits_to_symbols = BitsToSymbols(self.cardinality)
          self.symbols_to_bits = SymbolsToBits(self.cardinality)
          # Error-rate history appended to by trainIterative.
          self.bit_error_rate = []
          self.symbol_error_rate = []

      def call(self, inputs, training=None, mask=None):
          """Map bits to symbols, run the end-to-end model, map back to bits.

          ``training`` and ``mask`` are part of the Keras ``call`` signature
          and are not forwarded explicitly (as in the original code).
          """
          symbols = self.bits_to_symbols(inputs)
          decoded = self.e2e_model(symbols)
          return self.symbols_to_bits(decoded)

      def generate_random_inputs(self, num_of_blocks, return_vals=False):
          """Draw uniformly random bit blocks and their central-message labels.

          Args:
              num_of_blocks: Number of blocks to generate (an int).
              return_vals: If True, return a 3-tuple with the input array
                  repeated, instead of a 2-tuple.

          Returns:
              ``(bits, centre_bits)``, or ``(bits, bits, centre_bits)`` when
              ``return_vals`` is set. ``bits`` has shape
              ``(num_of_blocks, messages_per_block, bits_per_symbol)`` with
              integer entries in {0, 1}; ``centre_bits`` is the slice
              ``bits[:, mid, :]`` for the central message index ``mid``.
          """
          mid_idx = int((self.messages_per_block - 1) // 2)
          # Drawing directly in the final shape consumes the RNG in the same
          # C order as drawing a flat column and reshaping it.
          bits = np.random.randint(
              2, size=(num_of_blocks, self.messages_per_block, self.bits_per_symbol))
          centre_bits = bits[:, mid_idx, :]
          if return_vals:
              return bits, bits, centre_bits
          return bits, centre_bits

      def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
          """Compile with a fresh Adam optimiser and fit on random bit blocks.

          A new optimiser is created on every call, so each call starts from
          a fresh Adam state.

          Args:
              epochs: Number of passes over the training set.
              num_of_blocks: Total number of blocks to generate (may be a float
                  such as ``1e6``); split between training and validation.
              batch_size: Forwarded to ``fit`` (``None`` = Keras default).
              train_size: Fraction of ``num_of_blocks`` used for training.
              lr: Adam learning rate.
          """
          n_train = int(num_of_blocks * train_size)
          # Derive the validation count by subtraction: int(N * (1 - train_size))
          # truncates e.g. 1000 * 0.19999999999999996 to 199 and drops a block.
          n_test = int(num_of_blocks) - n_train
          X_train, y_train = self.generate_random_inputs(n_train)
          X_test, y_test = self.generate_random_inputs(n_test)
          X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
          X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
          # Skip validation entirely rather than handing fit an empty set.
          validation_data = (X_test, y_test) if n_test > 0 else None

          self.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                       loss=losses.BinaryCrossentropy(),
                       metrics=['accuracy'],
                       run_eagerly=False)
          self.fit(x=X_train,
                   y=y_train,
                   batch_size=batch_size,
                   epochs=epochs,
                   shuffle=True,
                   validation_data=validation_data)

      def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
          """Alternate training of the wrapped symbol model and this bit model.

          Each iteration calls ``e2e_model.train`` (with ``num_of_blocks`` and
          ``epochs`` only), records error rates, then runs :meth:`train` with
          all arguments and records error rates again. Two entries are
          therefore appended to each history list per iteration.
          """
          for _ in range(iters):
              self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
              self._record_error_rates()
              self.train(epochs=epochs,
                         num_of_blocks=num_of_blocks,
                         batch_size=batch_size,
                         train_size=train_size,
                         lr=lr)
              self._record_error_rates()

      def _record_error_rates(self):
          """Run ``e2e_model.test()`` and append its SER/BER to the history."""
          self.e2e_model.test()
          self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
          self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  # Simulation parameters for the fiber-length sweep below.
  SAMPLING_FREQUENCY = 336e9           # passed to OpticalChannel as fs
  CARDINALITY = 32                     # symbol alphabet size (M)
  SAMPLES_PER_SYMBOL = 32              # samples per symbol (n)
  MESSAGES_PER_BLOCK = 9               # messages per block (N); kept odd
  DISPERSION_FACTOR = -21.7 * 1e-24    # passed to OpticalChannel; units presumably s^2/km - TODO confirm
  FIBER_LENGTH = 50                    # not used by the sweep below

  if __name__ == '__main__':
      # Symbol rate in GBaud and line rate in Gbit/s; only referenced by the
      # disabled plot titles further down.
      baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
      bit_rate = math.log(CARDINALITY, 2) * baud_rate

      distances = [0, 10, 20, 30, 40, 50, 60]
      ber, ser = [], []
      # SNR reported by the first model; every later model should report the same.
      snr = None

      for fiber_km in distances:
          channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
                                   num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
                                   dispersion_factor=DISPERSION_FACTOR,
                                   fiber_length=fiber_km)
          model = BitMappingModel(cardinality=CARDINALITY,
                                  samples_per_symbol=SAMPLES_PER_SYMBOL,
                                  messages_per_block=MESSAGES_PER_BLOCK,
                                  channel=channel)

          model_snr = model.e2e_model.snr
          if snr is None:
              snr = model_snr
          elif snr != model_snr:
              print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
          # print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, fiber_km, snr))

          model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
          # Keep only the error rates recorded at the end of the last iteration.
          ber.append(model.bit_error_rate[-1])
          ser.append(model.symbol_error_rate[-1])

          # Optional per-distance training curves:
          # plt.plot(model.bit_error_rate, label='BER')
          # plt.plot(model.symbol_error_rate, label='SER')
          # plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, fiber_km, snr))
          # plt.legend()
          # plt.show()

      # model.summary()
      # Optional summary plot of final error rates versus fiber length:
      # plt.plot(distances, ber, label='BER')
      # plt.plot(distances, ser, label='SER')
      # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
      # plt.legend()