ソースを参照

Generate compare samples script

Min 4 年 前
コミット
c3cf999c1d
1 ファイル変更34 行追加13 行削除
  1. 34 13
      scripts/fpu_test_gen.py

+ 34 - 13
scripts/fpu_test_gen.py

@@ -1,6 +1,6 @@
 import numpy as np
 import os
-import struct
+import sys
 
 
 def reverse_endian(data):
@@ -10,34 +10,55 @@ def reverse_endian(data):
     return bytes(result)
 
 
-def generate_fp_vector(cases, filename, dtype=np.float16, big_endian=False):
-    dsize = 0
+def generate_numbers(cases, dtype=np.float16):
     if dtype == np.float16:
         dsize = 2
     elif dtype == np.float32:
         dsize = 4
     else:
         raise ValueError(f"Unknown dtype {dtype}")
-
     x = np.frombuffer(os.urandom(cases * dsize), dtype=dtype)
     y = np.frombuffer(os.urandom(cases * dsize), dtype=dtype)
+    return x, y, dsize
+
+
+def generate_fp_vector(cases, filename, dtype=np.float16, big_endian=False, comp_file=None):
+    x, y, dsize = generate_numbers(cases, dtype)
     np.seterr(all='ignore')
     sum = x + y
     mul = x * y
-    x = x.tobytes()
-    y = y.tobytes()
-    sum = sum.tobytes()
-    mul = mul.tobytes()
     with open(filename, 'w') as f:
         for i in range(cases):
             t = lambda v: reverse_endian(v) if big_endian else v
             f.write(' '.join([
-                t(x[i * dsize:i * dsize + dsize]).hex(),
-                t(y[i * dsize:i * dsize + dsize]).hex(),
-                t(sum[i * dsize:i * dsize + dsize]).hex(),
-                t(mul[i * dsize:i * dsize + dsize]).hex(),
+                t(x.tobytes()[i * dsize:i * dsize + dsize]).hex(),
+                t(y.tobytes()[i * dsize:i * dsize + dsize]).hex(),
+                t(sum.tobytes()[i * dsize:i * dsize + dsize]).hex(),
+                t(mul.tobytes()[i * dsize:i * dsize + dsize]).hex(),
             ]) + '\n')
+    if comp_file is not None:
+        gt = x > y
+        lt = x < y
+        ge = x >= y
+        le = x <= y
+        eq = x == y
+        with open(comp_file, 'w') as f:
+            for i in range(cases):
+                f.write(''.join([
+                    '1' if gt[i] else '0',
+                    '1' if lt[i] else '0',
+                    '1' if ge[i] else '0',
+                    '1' if le[i] else '0',
+                    '1' if eq[i] else '0'
+                ]) + f'  // {x[i]:10.3e} {y[i]:10.3e}\n')
 
 
 if __name__ == '__main__':
-    generate_fp_vector(30, 'fp16_test.hex', dtype=np.float16, big_endian=True)
+    generate = 50
+    if len(sys.argv) == 2 and sys.argv[1].isdigit():
+        generate = int(sys.argv[1])
+    else:
+        print(f"Usage: {sys.argv[0]} [number of generated tests]")
+
+    generate_fp_vector(generate, 'fp16_test.hex', dtype=np.float16, big_endian=True)
+    generate_fp_vector(generate, 'fp32_test.hex', dtype=np.float32, big_endian=True, comp_file='fp32_test_comp.hex')