| 123456789101112131415 |
- import misc
- import numpy as np
def test_bit_matrix_one_hot():
    """Check that bit matrix -> one-hot -> bit matrix round-trips exactly.

    For every row width ``n`` from 2 through 7, an array of ``100 * n``
    elements is obtained from ``misc.generate_random_bit_array``, reshaped
    into rows of ``n`` columns, encoded with ``misc.bit_matrix2one_hot``,
    decoded with ``misc.one_hot2bit_matrix``, flattened back to 1-D, and
    required to equal the starting array element-wise.
    """
    for width in range(2, 8):
        # NOTE(review): assumed to return a 1-D numpy array of 0/1 values
        # (it is reshaped and compared with np.array_equal below) — confirm.
        original = misc.generate_random_bit_array(100 * width)
        bit_rows = original.reshape((-1, width))
        encoded = misc.bit_matrix2one_hot(bit_rows)
        decoded = misc.one_hot2bit_matrix(encoded)
        restored = decoded.reshape((-1,))
        assert np.array_equal(original, restored)
def _main():
    """Run the round-trip self-test and report success on stdout."""
    test_bit_matrix_one_hot()
    # Reached only if the test above did not raise.
    print("Everything passed")


if __name__ == "__main__":
    _main()
|