new_model.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import tensorflow as tf
  2. from tensorflow.keras import layers, losses
  3. from models.custom_layers import ExtractCentralMessage, OpticalChannel
  4. from models.end_to_end import EndToEndAutoencoder
  5. from models.custom_layers import BitsToSymbols, SymbolsToBits
  6. import numpy as np
  7. import math
  8. from matplotlib import pyplot as plt
  9. class BitMappingModel(tf.keras.Model):
  10. def __init__(self,
  11. cardinality,
  12. samples_per_symbol,
  13. messages_per_block,
  14. channel):
  15. super(BitMappingModel, self).__init__()
  16. # Labelled M in paper
  17. self.cardinality = cardinality
  18. self.bits_per_symbol = int(math.log(self.cardinality, 2))
  19. # Labelled n in paper
  20. self.samples_per_symbol = samples_per_symbol
  21. # Labelled N in paper
  22. if messages_per_block % 2 == 0:
  23. messages_per_block += 1
  24. self.messages_per_block = messages_per_block
  25. self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
  26. samples_per_symbol=self.samples_per_symbol,
  27. messages_per_block=self.messages_per_block,
  28. channel=channel,
  29. bit_mapping=False)
  30. self.bit_error_rate = []
  31. self.symbol_error_rate = []
  32. def call(self, inputs, training=None, mask=None):
  33. x1 = BitsToSymbols(self.cardinality, self.messages_per_block)(inputs)
  34. x2 = self.e2e_model(x1)
  35. out = SymbolsToBits(self.cardinality)(x2)
  36. return out
  37. def generate_random_inputs(self, num_of_blocks, return_vals=False):
  38. """
  39. """
  40. mid_idx = int((self.messages_per_block - 1) / 2)
  41. rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
  42. out = rand_int
  43. out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
  44. if return_vals:
  45. return out_arr, out_arr, out_arr[:, mid_idx, :]
  46. return out_arr, out_arr[:, mid_idx, :]
  47. def train(self, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  48. X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
  49. X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
  50. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  51. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  52. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  53. self.compile(optimizer=opt,
  54. # loss=losses.BinaryCrossentropy(),
  55. loss=losses.MeanSquaredError(),
  56. metrics=['accuracy'],
  57. loss_weights=None,
  58. weighted_metrics=None,
  59. run_eagerly=False
  60. )
  61. self.fit(x=X_train,
  62. y=y_train,
  63. batch_size=batch_size,
  64. epochs=epochs,
  65. shuffle=True,
  66. validation_data=(X_test, y_test)
  67. )
  68. def trainIterative(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  69. for i in range(int(0.5*iters)):
  70. print("Loop {}/{}".format(i, iters))
  71. self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
  72. self.e2e_model.test()
  73. self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
  74. self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  75. X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
  76. X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
  77. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  78. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  79. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  80. self.compile(optimizer=opt,
  81. loss=losses.BinaryCrossentropy(),
  82. # loss=losses.MeanSquaredError(),
  83. metrics=['accuracy'],
  84. loss_weights=None,
  85. weighted_metrics=None,
  86. run_eagerly=False
  87. )
  88. self.fit(x=X_train,
  89. y=y_train,
  90. batch_size=batch_size,
  91. epochs=epochs,
  92. shuffle=True,
  93. validation_data=(X_test, y_test)
  94. )
  95. self.e2e_model.test()
  96. self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
  97. self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  98. for i in range(int(0.5*iters)):
  99. X_train, y_train = self.generate_random_ inputs(int(1e5 * train_size))
  100. X_test, y_test = self.generate_random_inputs(int(1e5 * (1 - train_size)))
  101. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  102. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  103. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  104. self.compile(optimizer=opt,
  105. loss=losses.BinaryCrossentropy(),
  106. # loss=losses.MeanSquaredError(),
  107. metrics=['accuracy'],
  108. loss_weights=None,
  109. weighted_metrics=None,
  110. run_eagerly=False
  111. )
  112. self.fit(x=X_train,
  113. y=y_train,
  114. batch_size=batch_size,
  115. epochs=epochs,
  116. shuffle=True,
  117. validation_data=(X_test, y_test)
  118. )
  119. self.e2e_model.test()
  120. self.symbol_error_rate.append(self.e2e_model.symbol_error_rate)
  121. self.bit_error_rate.append(self.e2e_model.bit_error_rate)
  122. SAMPLING_FREQUENCY = 336e9
  123. CARDINALITY = 64
  124. SAMPLES_PER_SYMBOL = 48
  125. MESSAGES_PER_BLOCK = 11
  126. DISPERSION_FACTOR = -21.7 * 1e-24
  127. FIBER_LENGTH = 50
  128. ENOB = 6
  129. if __name__ == 'asd':
  130. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  131. num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
  132. dispersion_factor=DISPERSION_FACTOR,
  133. fiber_length=FIBER_LENGTH,
  134. sig_avg=0.5,
  135. enob=ENOB)
  136. model = BitMappingModel(cardinality=CARDINALITY,
  137. samples_per_symbol=SAMPLES_PER_SYMBOL,
  138. messages_per_block=MESSAGES_PER_BLOCK,
  139. channel=optical_channel)
  140. model.train()
  141. if __name__ == '__main__':
  142. distances = [50]
  143. ser = []
  144. ber = []
  145. baud_rate = SAMPLING_FREQUENCY / (SAMPLES_PER_SYMBOL * 1e9)
  146. bit_rate = math.log(CARDINALITY, 2) * baud_rate
  147. snr = None
  148. for d in distances:
  149. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  150. num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
  151. dispersion_factor=DISPERSION_FACTOR,
  152. fiber_length=d,
  153. sig_avg=0.5,
  154. enob=ENOB)
  155. model = BitMappingModel(cardinality=CARDINALITY,
  156. samples_per_symbol=SAMPLES_PER_SYMBOL,
  157. messages_per_block=MESSAGES_PER_BLOCK,
  158. channel=optical_channel)
  159. if snr is None:
  160. snr = model.e2e_model.snr
  161. elif snr != model.e2e_model.snr:
  162. print("SOMETHING IS GOING WRONG YOU BETTER HAVE A LOOK!")
  163. print("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
  164. model.trainIterative(iters=20, num_of_blocks=1e3, epochs=5)
  165. model.e2e_model.test(length_plot=True)
  166. ber.append(model.bit_error_rate[-1])
  167. ser.append(model.symbol_error_rate[-1])
  168. e2e_model = EndToEndAutoencoder(cardinality=CARDINALITY,
  169. samples_per_symbol=SAMPLES_PER_SYMBOL,
  170. messages_per_block=MESSAGES_PER_BLOCK,
  171. channel=optical_channel,
  172. bit_mapping=False)
  173. ber1 = []
  174. ser1 = []
  175. for i in range(int(len(model.bit_error_rate))):
  176. e2e_model.train(num_of_blocks=1e3, epochs=5)
  177. e2e_model.test()
  178. ber1.append(e2e_model.bit_error_rate)
  179. ser1.append(e2e_model.symbol_error_rate)
  180. # model2 = BitMappingModel(cardinality=CARDINALITY,
  181. # samples_per_symbol=SAMPLES_PER_SYMBOL,
  182. # messages_per_block=MESSAGES_PER_BLOCK,
  183. # channel=optical_channel)
  184. #
  185. # ber2 = []
  186. # ser2 = []
  187. #
  188. # for i in range(int(len(model.bit_error_rate) / 2)):
  189. # model2.train(num_of_blocks=1e3, epochs=5)
  190. # model2.e2e_model.test()
  191. #
  192. # ber2.append(model2.e2e_model.bit_error_rate)
  193. # ser2.append(model2.e2e_model.symbol_error_rate)
  194. plt.plot(ber1, label='BER (1)')
  195. # plt.plot(ser1, label='SER (1)')
  196. # plt.plot(np.arange(0, len(ber2), 1) * 2, ber2, label='BER (2)')
  197. # plt.plot(np.arange(0, len(ser2), 1) * 2, ser2, label='SER (2)')
  198. plt.plot(model.bit_error_rate, label='BER (3)')
  199. # plt.plot(model.symbol_error_rate, label='SER (3)')
  200. plt.title("{:.2f} Gbps along {}km of fiber with an SNR of {:.2f}".format(bit_rate, d, snr))
  201. plt.yscale('log')
  202. plt.legend()
  203. plt.show()
  204. # model.summary()
  205. # plt.plot(ber, label='BER')
  206. # plt.plot(ser, label='SER')
  207. # plt.title("BER for different lengths at {:.2f} Gbps with an SNR of {:.2f}".format(bit_rate, snr))
  208. # plt.legend(ber)