| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import misc
- import numpy as np
- import math
- import itertools
- import tensorflow as tf
- from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
- from matplotlib import pyplot as plt
def test_bit_matrix_one_hot():
    """Check that bit-matrix -> one-hot -> bit-matrix round-trips losslessly.

    For every row width from 2 to 7 bits, a random bit array (requested
    length ``100 * width``) is reshaped into rows of ``width`` bits, encoded
    with ``misc.bit_matrix2one_hot``, decoded back with
    ``misc.one_hot2bit_matrix``, flattened, and compared with the original.
    """
    for width in range(2, 8):
        original_bits = misc.generate_random_bit_array(100 * width)
        bit_rows = original_bits.reshape((-1, width))
        one_hot = misc.bit_matrix2one_hot(bit_rows)
        recovered_bits = misc.one_hot2bit_matrix(one_hot).reshape((-1,))
        assert np.array_equal(original_bits, recovered_bits)
if __name__ == "__main__":
    # Manual visual check of the OpticalChannel layer: build a random
    # piecewise-constant signal, pass it through the channel, and overlay
    # the channel input and output on a single plot.
    SAMPLING_FREQUENCY = 336e9
    # NOTE(review): CARDINALITY is not used below; the input levels are drawn
    # from range(4) instead. Confirm which alphabet size is intended.
    CARDINALITY = 32
    SAMPLES_PER_SYMBOL = 100
    NUM_OF_SYMBOLS = 10
    DISPERSION_FACTOR = -21.7 * 1e-24
    FIBER_LENGTH = 50

    # rx_stddev / q_stddev are set to 0, presumably to disable the channel's
    # noise sources (confirm in OpticalChannel).
    optical_channel = OpticalChannel(
        fs=SAMPLING_FREQUENCY,
        num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
        dispersion_factor=DISPERSION_FACTOR,
        fiber_length=FIBER_LENGTH,
        rx_stddev=0,
        q_stddev=0,
    )

    # NUM_OF_SYMBOLS random integer levels in [0, 4), each held for
    # SAMPLES_PER_SYMBOL samples. The result has shape
    # (1, NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL).
    symbols = np.random.randint(4, size=(NUM_OF_SYMBOLS,))
    tx_signal = np.repeat(symbols, SAMPLES_PER_SYMBOL).reshape(1, -1)

    rx_signal = optical_channel(tx_signal).numpy()

    plt.plot(tx_signal.flatten(), label="channel input")
    plt.plot(rx_signal.flatten(), label="channel output")
    plt.legend()
    plt.show()
|