Min 5 лет назад
Родитель
Сommit
7151c964d3
2 измененных файлов с 49 добавлено и 0 удалено
  1. 34 0
      misc.py
  2. 15 0
      tests/misc_test.py

+ 34 - 0
misc.py

@@ -20,6 +20,40 @@ def display_alphabet(alphabet, values=None, a_vals=False, title="Alphabet conste
     plt.show()
 
 
+def bit_matrix2one_hot(matrix: np.ndarray) -> np.ndarray:
+    """
+    Returns a copy of bit encoded matrix to one hot matrix. A row examples:
+    [1010] (decimal 10) => [0000 0100 0000 0000]
+    [0011] (decimal  3) => [0000 0000 0000 1000]
+    each number represents true/false value in column
+    """
+    N = matrix.shape[1]
+    encoder = 2**np.arange(N)
+    values = np.dot(matrix, encoder)
+    result = np.zeros((matrix.shape[0], 2**N), dtype=bool)
+    result[np.arange(matrix.shape[0]), values] = True
+    return result
+
+
+def one_hot2bit_matrix(matrix: np.ndarray) -> np.ndarray:
+    """
+    Returns a copy of one hot matrix to bit encoded matrix. A row examples:
+    [0000 0100 0000 0000] => [1010] (decimal 10)
+    [0000 0000 0000 1000] => [0011] (decimal  3)
+    each number represents true/false value in column
+    """
+    N = math.ceil(math.log2(matrix.shape[1]))
+    values = np.dot(matrix, np.arange(2**N))
+    return int2bit_array(values, N)
+
+
+def int2bit_array(int_arr: np.ndarray, N: int) -> np.ndarray:
+    x0 = np.array([int_arr], dtype=np.uint8)
+    x1 = np.unpackbits(x0.T, bitorder='little', axis=1)
+    result = x1[:, :N].astype(bool)  # , indices
+    return result
+
+
 def polar2rect(array, amp_column=0, phase_column=1) -> np.ndarray:
     """
     Return copy of array with amp_column and phase_column as polar coordinates replaced by rectangular coordinates

+ 15 - 0
tests/misc_test.py

@@ -0,0 +1,15 @@
+import misc
+import numpy as np
+
+
+def test_bit_matrix_one_hot():
+    for n in range(2, 8):
+        x0 = misc.generate_random_bit_array(100 * n)
+        x1 = misc.bit_matrix2one_hot(x0.reshape((-1, n)))
+        x2 = misc.one_hot2bit_matrix(x1).reshape((-1,))
+        assert np.array_equal(x0, x2)
+
+
+if __name__ == "__main__":
+    test_bit_matrix_one_hot()
+    print("Everything passed")