import misc
import numpy as np
import math
import itertools
import tensorflow as tf
from models.custom_layers import BitsToSymbols, SymbolsToBits, OpticalChannel
from matplotlib import pyplot as plt

# Parameters for the optical-channel demo plot produced by main().
SAMPLING_FREQUENCY = 336e9
CARDINALITY = 32
SAMPLES_PER_SYMBOL = 100
NUM_OF_SYMBOLS = 10
# NOTE(review): presumably the group-velocity dispersion beta2 in s^2/km
# (-21.7 ps^2/km), with FIBER_LENGTH in km -- confirm against OpticalChannel.
DISPERSION_FACTOR = -21.7 * 1e-24
FIBER_LENGTH = 50


def test_bit_matrix_one_hot():
    """Check that bit-matrix -> one-hot -> bit-matrix round-trips losslessly.

    For each word width n in 2..7, a random bit array of 100 * n bits is
    reshaped into rows of n bits, converted to one-hot form with
    ``misc.bit_matrix2one_hot``, converted back with
    ``misc.one_hot2bit_matrix``, flattened, and compared with the input.
    """
    for n in range(2, 8):
        bits = misc.generate_random_bit_array(100 * n)
        one_hot = misc.bit_matrix2one_hot(bits.reshape((-1, n)))
        recovered = misc.one_hot2bit_matrix(one_hot).reshape((-1,))
        assert np.array_equal(bits, recovered)


def main():
    """Plot a random stepped input signal and its output after ``OpticalChannel``."""
    optical_channel = OpticalChannel(
        fs=SAMPLING_FREQUENCY,
        num_of_samples=NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL,
        dispersion_factor=DISPERSION_FACTOR,
        fiber_length=FIBER_LENGTH,
        # NOTE(review): zero std-devs presumably disable receiver and
        # quantisation noise so only the channel response is visible -- confirm.
        rx_stddev=0,
        q_stddev=0,
    )

    # One random level in {0, 1, 2, 3} per symbol, each held for
    # SAMPLES_PER_SYMBOL samples, packed as a single row of shape
    # (1, NUM_OF_SYMBOLS * SAMPLES_PER_SYMBOL).
    inp = np.random.randint(4, size=(NUM_OF_SYMBOLS,))
    inp_t = np.repeat(inp, SAMPLES_PER_SYMBOL).reshape(1, -1)
    plt.plot(inp_t.flatten(), label="channel input")

    out_np = optical_channel(inp_t).numpy()
    plt.plot(out_np.flatten(), label="channel output")

    plt.xlabel("sample index")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()