misc_test.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import misc
  2. import numpy as np
  3. import math
  4. import itertools
  5. import tensorflow as tf
  6. from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
  7. from matplotlib import pyplot as plt
  8. def test_bit_matrix_one_hot():
  9. for n in range(2, 8):
  10. x0 = misc.generate_random_bit_array(100 * n)
  11. x1 = misc.bit_matrix2one_hot(x0.reshape((-1, n)))
  12. x2 = misc.one_hot2bit_matrix(x1).reshape((-1,))
  13. assert np.array_equal(x0, x2)
  14. if __name__ == "__main__":
  15. # cardinality = 8
  16. # messages_per_block = 3
  17. # num_of_blocks = 10
  18. # bits_per_symbol = 3
  19. #
  20. # #-----------------------------------
  21. #
  22. # mid_idx = int((messages_per_block - 1) / 2)
  23. #
  24. # ################################################################################################################
  25. #
  26. # # rand_int = np.random.randint(self.cardinality, size=(num_of_blocks * self.messages_per_block, 1))
  27. # rand_int = np.random.randint(2, size=(num_of_blocks * messages_per_block * bits_per_symbol, 1))
  28. #
  29. # # out = enc.fit_transform(rand_int)
  30. # out = rand_int
  31. #
  32. # # out_arr = np.reshape(out, (num_of_blocks, self.messages_per_block, self.cardinality))
  33. # out_arr = np.reshape(out, (num_of_blocks, messages_per_block, bits_per_symbol))
  34. #
  35. # out_arr_tf = tf.convert_to_tensor(out_arr, dtype=tf.float32)
  36. #
  37. #
  38. # n = int(math.log(cardinality, 2))
  39. # pows = tf.convert_to_tensor(np.power(2, np.linspace(n - 1, 0, n)).reshape(-1, 1), dtype=tf.float32)
  40. #
  41. # pows_np = pows.numpy()
  42. #
  43. # a = np.asarray([0, 1, 1]).reshape(1, -1)
  44. #
  45. # b = tf.tensordot(out_arr_tf, pows, axes=1).numpy()
  46. SAMPLING_FREQUENCY = 336e9
  47. CARDINALITY = 32
  48. SAMPLES_PER_SYMBOL = 100
  49. NUM_OF_SYMBOLS = 10
  50. DISPERSION_FACTOR = -21.7 * 1e-24
  51. FIBER_LENGTH = 50
  52. optical_channel = OpticalChannel(fs=SAMPLING_FREQUENCY,
  53. num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
  54. dispersion_factor=DISPERSION_FACTOR,
  55. fiber_length=FIBER_LENGTH,
  56. rx_stddev=0,
  57. q_stddev=0)
  58. inp = np.random.randint(4, size=(NUM_OF_SYMBOLS, ))
  59. inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
  60. plt.plot(inp_t.flatten())
  61. out_tf = optical_channel(inp_t)
  62. out_np = out_tf.numpy()
  63. plt.plot(out_np.flatten())
  64. plt.show()
  65. pass