# misc_test.py: round-trip tests for the bit-matrix / one-hot helpers in misc.
import numpy as np

import misc
  3. def test_bit_matrix_one_hot():
  4. for n in range(2, 8):
  5. x0 = misc.generate_random_bit_array(100 * n)
  6. x1 = misc.bit_matrix2one_hot(x0.reshape((-1, n)))
  7. x2 = misc.one_hot2bit_matrix(x1).reshape((-1,))
  8. assert np.array_equal(x0, x2)
  9. if __name__ == "__main__":
  10. test_bit_matrix_one_hot()
  11. print("Everything passed")