new_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import tensorflow as tf
  2. from tensorflow.keras import layers, losses
  3. from models.custom_layers import ExtractCentralMessage, OpticalChannel
  4. from models.end_to_end import EndToEndAutoencoder
  5. from models.custom_layers import BitsToSymbols, SymbolsToBits
  6. import numpy as np
  7. import math
  8. class BitMappingModel(tf.keras.Model):
  9. def __init__(self,
  10. cardinality,
  11. samples_per_symbol,
  12. messages_per_block,
  13. channel):
  14. super(BitMappingModel, self).__init__()
  15. # Labelled M in paper
  16. self.cardinality = cardinality
  17. self.bits_per_symbol = int(math.log(self.cardinality, 2))
  18. # Labelled n in paper
  19. self.samples_per_symbol = samples_per_symbol
  20. # Labelled N in paper
  21. if messages_per_block % 2 == 0:
  22. messages_per_block += 1
  23. self.messages_per_block = messages_per_block
  24. self.e2e_model = EndToEndAutoencoder(cardinality=self.cardinality,
  25. samples_per_symbol=self.samples_per_symbol,
  26. messages_per_block=self.messages_per_block,
  27. channel=channel,
  28. bit_mapping=False)
  29. def call(self, inputs, training=None, mask=None):
  30. x1 = BitsToSymbols(self.cardinality)(inputs)
  31. x2 = self.e2e_model(x1)
  32. out = SymbolsToBits(self.cardinality)(x2)
  33. return out
  34. def generate_random_inputs(self, num_of_blocks, return_vals=False):
  35. """
  36. """
  37. mid_idx = int((self.messages_per_block - 1) / 2)
  38. rand_int = np.random.randint(2, size=(num_of_blocks * self.messages_per_block * self.bits_per_symbol, 1))
  39. out = rand_int
  40. out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.bits_per_symbol))
  41. if return_vals:
  42. return out_arr, out_arr, out_arr[:, mid_idx, :]
  43. return out_arr, out_arr[:, mid_idx, :]
  44. def train(self, iters=1, epochs=1, num_of_blocks=1e6, batch_size=None, train_size=0.8, lr=1e-3):
  45. """
  46. """
  47. for _ in range(iters):
  48. self.e2e_model.train(num_of_blocks=num_of_blocks, epochs=epochs)
  49. self.e2e_model.test()
  50. X_train, y_train = self.generate_random_inputs(int(num_of_blocks * train_size))
  51. X_test, y_test = self.generate_random_inputs(int(num_of_blocks * (1 - train_size)))
  52. X_train = tf.convert_to_tensor(X_train, dtype=tf.float32)
  53. X_test = tf.convert_to_tensor(X_test, dtype=tf.float32)
  54. opt = tf.keras.optimizers.Adam(learning_rate=lr)
  55. self.compile(optimizer=opt,
  56. loss=losses.BinaryCrossentropy(),
  57. metrics=['accuracy'],
  58. loss_weights=None,
  59. weighted_metrics=None,
  60. run_eagerly=False
  61. )
  62. self.fit(x=X_train,
  63. y=y_train,
  64. batch_size=batch_size,
  65. epochs=epochs,
  66. shuffle=True,
  67. validation_data=(X_test, y_test)
  68. )
  69. def test(self, num_of_blocks=1e4):
  70. pass
  71. SAMPLING_FREQUENCY = 336e9
  72. CARDINALITY = 32
  73. SAMPLES_PER_SYMBOL = 24
  74. MESSAGES_PER_BLOCK = 9
  75. DISPERSION_FACTOR = -21.7 * 1e-24
  76. FIBER_LENGTH = 50
  77. if __name__ == '__main__':
  78. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  79. num_of_samples=MESSAGES_PER_BLOCK * SAMPLES_PER_SYMBOL,
  80. dispersion_factor=DISPERSION_FACTOR,
  81. fiber_length=FIBER_LENGTH)
  82. model = BitMappingModel(cardinality=CARDINALITY,
  83. samples_per_symbol=SAMPLES_PER_SYMBOL,
  84. messages_per_block=MESSAGES_PER_BLOCK,
  85. channel=optical_channel)
  86. # a , c = model.generate_random_inputs(num_of_blocks=1)
  87. #
  88. # a = tf.convert_to_tensor(a, dtype=tf.float32)
  89. # b = model(a)
  90. model.train(iters=1, num_of_blocks=1e4, epochs=1)
  91. model.summary()