瀏覽代碼

Improved fpu adder

Min 4 年之前
父節點
當前提交
e35aadc65c
共有 4 個文件被更改,包括 208 次插入202 次删除
  1. 2 1
      scripts/fpu_test_gen.py
  2. 25 11
      simulation/modelsim/wave_fpu16_tb.do
  3. 167 189
      src/fpu16/fp_adder.sv
  4. 14 1
      src/fpu16/fpu16.sv

+ 2 - 1
scripts/fpu_test_gen.py

@@ -18,6 +18,7 @@ def dtype_size(dtype):
     else:
         raise ValueError(f"Unknown dtype {dtype}")
 
+
 def generate_numbers(cases, dtype=np.float16):
     dsize = dtype_size(dtype)
     x = np.frombuffer(os.urandom(cases * dsize), dtype=dtype)
@@ -38,7 +39,7 @@ def generate_fp_vector(cases, filename, dtype=np.float16, big_endian=False, comp
                 t(y.tobytes()[i * dsize:i * dsize + dsize]).hex(),
                 t(sum.tobytes()[i * dsize:i * dsize + dsize]).hex(),
                 t(mul.tobytes()[i * dsize:i * dsize + dsize]).hex(),
-            ]) + '\n')
+            ]) + f'  // {x[i]:10.6e} {y[i]:10.6e} {sum[i]:10.6e} {mul[i]:10.6e}\n')
     if comp_file is not None:
         gt = x > y
         lt = x < y

File diff suppressed because it is too large
+ 25 - 11
simulation/modelsim/wave_fpu16_tb.do


+ 167 - 189
src/fpu16/fp_adder.sv

@@ -1,190 +1,168 @@
-module fp_adder #(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
-	input reg [N-1:0] input_a, input_b;
-	input logic clk, reset;
-	output reg [N-1:0] output_z;
-	
-	reg [N-2-M:0] a_m, b_m, z_m; // mantissa
-	reg [N-2-M:0] a_m1, b_m1; 
-	
-	reg [M-1:0] a_e, b_e, z_e, z_e1; // exponent
-	reg a_s, b_s, a_s1, b_s1, z_s, z_s1; // sign
-	
-	reg [1:0] greater; // 01 for a, 10 for b, 11 for both and 00 for neither
-	reg [M:0] abs; // For the absolute difference between exponents
-	
-	always_ff @(posedge clk)
-	begin
-		if(~reset)
-		begin
-				// Unpacking the inputs
-			a_m <= input_a[N-M-2:0];
-			a_e <= input_a[N-2:N-M-1];
-			a_s <= input_a[N-1];
-			
-			b_m <= input_b[N-M-2:0];
-			b_e <= input_b[N-2:N-M-1];
-			b_s <= input_b[N-1];
-			
-			a_m1 <= a_m;
-			a_s1 <= a_s;
-			b_m1 <= b_m;
-			b_s1 <= b_s;
-			
-			z_e1 <= z_e;
-			z_s1 <= z_s;
-			
-			// If input_a has the bigger exponent then flag it with greater and find the absolute difference
-			if (a_e > b_e)
-			begin
-				greater <= 2'b01;
-				abs <= a_e - b_e;
-				z_s <= a_s;
-				z_e <= a_e;
-			end
-			
-			// If input_a has the bigger exponent then flag it with greater and find the absolute difference
-			else if (b_e > a_e)
-			begin
-				greater <= 2'b10;
-				abs <= b_e - a_e;
-				z_s <= b_s;
-				z_e <= b_e;
-			end
-			
-			// If the inputs have equal exponent
-			else
-			begin
-				greater <= 2'b00;
-				abs <= 0;
-				z_e <= a_e;
-				// Assigning the overall sign based on the difference between the mantissa
-				if(a_m > b_m)
-				begin
-					z_s <= a_s;
-				end
-				else if(b_m > a_m)
-				begin
-					z_s <= b_s;
-				end
-				else
-				begin
-					z_s <= 0;
-				end
-			end
-			
-			// Condition for overflow is that it sets the output to the larger input
-			if (abs > N-1-M) // Shifting by N-1-M would give 0
-			begin
-				if (greater == 2'b01)
-				begin
-					z_m <= a_m1; // Input a is larger and is translated to the output
-				end
-				else if (greater == 2'b10)
-				begin
-					z_m <= b_m1; // Input b is larger and is translated to the output
-				end
-				else // Shouldn't happen as abs should be 0 for this to occur
-				begin
-					if (a_m1 >= b_m1)
-					begin
-						z_m <= a_m1; // Equal exponents but a has the larger mantissa
-					end
-					else if (b_m1 > a_m1)
-					begin
-						z_m <= b_m1; // Equal exponents but b has the larger mantissa
-					end
-				end
-			end
-			
-			else
-			begin
-			   // If a has the bigger exponent
-				if (greater == 2'b01)
-				begin
-					// If the signs are the same then add
-					if (a_s1 == b_s1)
-					begin
-						z_m <= a_m1 + (b_m1 >> (abs-1));
-					end
-					// If they are different then subtract
-					else
-					begin
-						z_m <= a_m1 - (b_m1 >> (abs-1));
-					end
-				end
-				// If b has the bigger exponent
-				else if (greater == 2'b10)
-				begin
-				// If the signs are the same then add
-					if (a_s1 == b_s1)
-					begin
-						z_m <= b_m1 + (a_m1 >> (abs-1));
-					end
-					// If they are different then subtract
-					else
-					begin
-						z_m <= b_m1 - (a_m1 >> (abs-1));
-					end
-				end
-				// If the exponents are equal
-				else
-				begin
-					// If the signs are the same then add
-					if (a_s1 == b_s1)
-					begin
-						z_m <= a_m1 + b_m1;
-					end
-					// If the signs are different then subtract
-					else
-					begin
-						// First checking which has the bigger mantissa
-						if (a_m1 > b_m1)
-						begin
-							z_m <= a_m1 - b_m1;
-						end
-						else if (b_m1 > a_m1)
-						begin
-							z_m <= b_m1 - a_m1;
-						end
-						// If the mantissa are the same as well then the result should be 0
-						else
-						begin
-							z_m <= 0;
-						end
-					end
-				end
-			end
-			output_z[N-1] <= z_s1;
-			output_z[N-2:N-1-M] <= z_e1;
-			output_z[N-2-M:0] <= z_m;
-		end
-
-	
-		else
-		begin
-			a_m <= 0;
-			a_e <= 0;
-			a_s <= 0;
-			
-			b_m <= 0;
-			b_e <= 0;
-			b_s <= 0;
-		
-			a_m1 <= 0;
-			b_m1 <= 0;
-			z_e1 <= 0;
-			z_s1 <= 0;
-			
-			z_s <= 0;
-			z_e <= 0;
-			z_s1 <= 0;
-			z_e1 <= 0;
-			z_m <= 0;
-			
-			greater <= 0;
-			abs <= 0;
-			
-			output_z <= 0;
-		end
-	end
+typedef enum logic [1:0]{
+    greater_a,
+    greater_b,
+    equal_ab
+
+} grater_state;
+
+module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
+    localparam K=N-M-1;  // Size of mantissa
+
+    input reg [N-1:0] input_a, input_b;
+    input logic clk, reset;
+    output reg [N-1:0] output_z;
+
+    reg [K-1:0] a_m0, b_m0; // mantissa
+    reg [K-1:0] a_m1, b_m1, z_m2;
+
+    reg [K*2-1:0] z_m1a, z_m1b, z_m1z;  // Double mantissa
+    reg z_m1s;
+
+    reg [M-1:0] a_e0, b_e0; // exponent
+    reg [M-1:0] z_e1, z_e2;
+
+    reg a_s0, b_s0; // sign
+    reg a_s1, b_s1, z_s1, z_s2;
+
+    grater_state greater; // 01 for a, 10 for b, 11 for both and 00 for neither
+    reg [M:0] abs; // For the absolute difference between exponents
+
+    always_comb begin
+        output_z = {z_s2, z_e2, z_m2};
+        z_m1a = {a_m1, {K{1'b0}}};
+        z_m1b = {b_m1, {K{1'b0}}};
+
+        // If a has the bigger exponent
+        if (greater == greater_a)
+            begin
+                // If the signs are the same then add
+                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1a+(z_m1b >> abs - 2);
+                    // If they are different then subtract
+                else {z_m1s, z_m1z} = z_m1a-(z_m1b >> abs - 2);
+            end
+            // If b has the bigger exponent
+        else if (greater == greater_b)
+            begin
+                // If the signs are the same then add
+                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1b+(z_m1a >> abs - 2);
+                    // If they are different then subtract
+                else {z_m1s, z_m1z} = z_m1b-(z_m1a >> abs - 2);
+            end
+            // If the exponents are equal
+        else
+            begin
+                // If the signs are the same then add
+                if (a_s1 == b_s1) {z_m1s, z_m1z} = (z_m1a + z_m1b) >> 1;
+                    // If the signs are different then subtract
+                else
+                    begin
+                        // First checking which has the bigger mantissa
+                        if (a_m1 > b_m1) {z_m1s, z_m1z} = z_m1a-z_m1b;
+                        else if (b_m1 > a_m1) {z_m1s, z_m1z} = z_m1b-z_m1a;
+                            // If the mantissa are the same as well then the result should be 0
+                        else {z_m1s, z_m1z} = 0;
+                    end
+            end
+    end
+
+    always_ff @(posedge clk)
+        begin
+            if (~reset)
+                begin
+                    // Unpacking the inputs
+                    a_m0 <= input_a[K-1:0];
+                    a_e0 <= input_a[N-2:K];
+                    a_s0 <= input_a[N-1];
+
+                    b_m0 <= input_b[K-1:0];
+                    b_e0 <= input_b[N-2:K];
+                    b_s0 <= input_b[N-1];
+
+                    // Second stage
+                    a_m1 <= a_m0;
+                    a_s1 <= a_s0;
+                    b_m1 <= b_m0;
+                    b_s1 <= b_s0;
+
+                    z_e2 <= z_e1;
+                    z_s2 <= z_s1;
+
+                    // If input_a has the bigger exponent then flag it with greater and find the absolute difference
+                    if (a_e0 > b_e0)
+                        begin
+                            greater <= greater_a;
+                            abs <= a_e0-b_e0;
+                            z_s1 <= a_s0;
+                            z_e1 <= a_e0;
+                        end
+
+                        // If input_a has the bigger exponent then flag it with greater and find the absolute difference
+                    else if (b_e0 > a_e0)
+                        begin
+                            greater <= greater_b;
+                            abs <= b_e0-a_e0;
+                            z_s1 <= b_s0;
+                            z_e1 <= b_e0;
+                        end
+
+                        // If the inputs have equal exponent
+                    else
+                        begin
+                            greater <= equal_ab;
+                            abs <= -1;
+                            z_e1 <= a_e0;
+                            // Assigning the overall sign based on the difference between the mantissa
+                            if (a_m0 > b_m0) z_s1 <= a_s0;
+                            else if (b_m0 > a_m0) z_s1 <= b_s0;
+                            else z_s1 <= 0;
+                        end
+
+                    // Condition for overflow is that it sets the output to the larger input
+                    if (abs > K) // Shifting by N-1-M would give 0
+                        begin
+                            z_m2 <= (greater == greater_a) ? a_m1 : b_m1;
+
+                            // Input a is larger and is translated to the output
+                            // if (greater == greater_a) z_m0 <= a_m1;
+
+                            // Input b is larger and is translated to the output
+                            // else if (greater == greater_b) z_m0 <= b_m1;
+
+                            // Shouldn't happen as abs should be 0 for this to occur
+                            // else begin
+                            // 	if (a_m1 >= b_m1) z_m0 <= a_m1; // Equal exponents but a has the larger mantissa
+                            // 	else if (b_m1 > a_m1) z_m0 <= b_m1; // Equal exponents but b has the larger mantissa
+                            // end
+                        end
+
+                    else
+                        begin
+							z_m2 <= z_m1z[K*2-1:K];
+                        end
+                end
+            else
+                begin
+                    a_m0 <= 0;
+                    a_e0 <= 0;
+                    a_s0 <= 0;
+
+                    b_m0 <= 0;
+                    b_e0 <= 0;
+                    b_s0 <= 0;
+
+                    a_m1 <= 0;
+                    b_m1 <= 0;
+                    z_e2 <= 0;
+                    z_s2 <= 0;
+
+                    z_s1 <= 0;
+                    z_e1 <= 0;
+                    z_s2 <= 0;
+                    z_e2 <= 0;
+                    z_m2 <= 0;
+
+                    greater <= equal_ab;
+                    abs <= 0;
+                end
+        end
 endmodule : fp_adder

+ 14 - 1
src/fpu16/fpu16.sv

@@ -3,6 +3,9 @@
 
 
 module fpu16_tb;
+	localparam N=16;
+	localparam M=5;
+
 	reg reset, clk;
 	logic [15:0] input_a, input_b, result_add, result_mult;
 	logic [15:0] expected_add, expected_mult;
@@ -12,12 +15,22 @@ module fpu16_tb;
 	fp_product multiplier1(.input_a(input_a), .input_b(input_b), .output_z(result_mult), .clk(clk), .reset(reset));
 	
 	initial forever #5 clk = ~clk;
-	localparam PIPELINES = 3;
+	localparam PIPELINES = 2;
 
 	reg [15:0] test_mem [29:0][3:0];
 
 	initial $readmemh("scripts/fp16_test.hex", test_mem);
 
+	reg [N-2-M:0] exp0_m;
+	reg [M-1:0] exp0_e;
+	reg exp0_s;
+
+	always_comb begin
+		exp0_m = expected_add[N-M-2:0];
+		exp0_e = expected_add[N-2:N-M-1];
+		exp0_s = expected_add[N-1];
+	end
+
 	initial begin
         static int num_err = 0;
         static int num_tests = $size(test_mem) * 2;