Bladeren bron

Neural network WIP

Min 4 jaren geleden
bovenliggende
commit
2784ee3b6e
6 gewijzigde bestanden met toevoegingen van 249 en 76 verwijderingen
  1. 50 55
      src/fpu16/fp_adder.sv
  2. 17 0
      src/fpu16/fpu16.sv
  3. 153 0
      src/fpu32p/add32.sv
  4. 20 8
      src/fpu32p/fpu32p.sv
  5. 4 7
      src/neural/layer.sv
  6. 5 6
      src/neural/neural.sv

+ 50 - 55
src/fpu16/fp_adder.sv

@@ -13,56 +13,48 @@ module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
     output reg [N-1:0] output_z;
 
     reg [K-1:0] a_m0, b_m0; // mantissa
-    reg [K-1:0] a_m1, b_m1, z_m2;
+    reg [K-1:0] a_m1, b_m1, z_m2, z_m3;
 
     reg [K*2-1:0] z_m1a, z_m1b, z_m1z;  // Double mantissa
-    reg z_m1s;
+    reg z_m1s, z_m2s;
 
     reg [M-1:0] a_e0, b_e0; // exponent
-    reg [M-1:0] z_e1, z_e2;
+    reg [M-1:0] z_e1, z_e2, z_e3;
 
     reg a_s0, b_s0; // sign
-    reg a_s1, b_s1, z_s1, z_s2;
+    reg a_s1, b_s1, z_s1, z_s2, z_s3;
 
     grater_state greater; // 01 for a, 10 for b, 11 for both and 00 for neither
     reg [M:0] abs; // For the absolute difference between exponents
 
     always_comb begin
-        output_z = {z_s2, z_e2, z_m2};
+        output_z = {z_s3, z_e3, z_m3};
         z_m1a = {a_m1, {K{1'b0}}};
         z_m1b = {b_m1, {K{1'b0}}};
 
-        // If a has the bigger exponent
-        if (greater == greater_a)
-            begin
-                // If the signs are the same then add
-                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1a+(z_m1b >> abs - 2);
-                    // If they are different then subtract
-                else {z_m1s, z_m1z} = z_m1a-(z_m1b >> abs - 2);
+        case (greater)
+            greater_a: begin
+                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1a + (z_m1b >> (abs - 1));
+                else {z_m1s, z_m1z} = z_m1a - (z_m1b >> (abs - 1));
             end
-            // If b has the bigger exponent
-        else if (greater == greater_b)
-            begin
-                // If the signs are the same then add
-                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1b+(z_m1a >> abs - 2);
-                    // If they are different then subtract
-                else {z_m1s, z_m1z} = z_m1b-(z_m1a >> abs - 2);
+            greater_b: begin
+                if (a_s1 == b_s1) {z_m1s, z_m1z} = z_m1b + (z_m1a >> (abs - 1));
+                else {z_m1s, z_m1z} = z_m1b - (z_m1a >> (abs - 1));
             end
-            // If the exponents are equal
-        else
-            begin
+            equal_ab: begin
                 // If the signs are the same then add
                 if (a_s1 == b_s1) {z_m1s, z_m1z} = (z_m1a + z_m1b) >> 1;
+                // If the signs are different then subtract
+                else begin
+                    // First checking which has the bigger mantissa
+                    if (a_m1 > b_m1) {z_m1s, z_m1z} = z_m1a - z_m1b;
                     // If the signs are different then subtract
-                else
-                    begin
-                        // First checking which has the bigger mantissa
-                        if (a_m1 > b_m1) {z_m1s, z_m1z} = z_m1a-z_m1b;
-                        else if (b_m1 > a_m1) {z_m1s, z_m1z} = z_m1b-z_m1a;
-                            // If the mantissa are the same as well then the result should be 0
-                        else {z_m1s, z_m1z} = 0;
-                    end
+                    else if (b_m1 > a_m1) {z_m1s, z_m1z} = z_m1b - z_m1a;
+                    // If the mantissa are the same as well then the result should be 0
+                    else {z_m1s, z_m1z} = 0;
+                end
             end
+        endcase
     end
 
     always_ff @(posedge clk)
@@ -91,18 +83,18 @@ module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
                     if (a_e0 > b_e0)
                         begin
                             greater <= greater_a;
-                            abs <= a_e0 - b_e0 - 1;
+                            abs <= a_e0 - b_e0;
                             z_s1 <= a_s0;
-                            z_e1 <= a_e0;
+                            z_e1 <= a_e0 + 1;
                         end
 
                         // If input_b has the bigger exponent then flag it with greater and find the absolute difference
                     else if (b_e0 > a_e0)
                         begin
                             greater <= greater_b;
-                            abs <= b_e0 - a_e0 - 1;
+                            abs <= b_e0 - a_e0;
                             z_s1 <= b_s0;
-                            z_e1 <= b_e0;
+                            z_e1 <= b_e0 + 1;
                         end
 
                         // If the inputs have equal exponent
@@ -110,7 +102,7 @@ module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
                         begin
                             greater <= equal_ab;
                             abs <= 0;
-                            z_e1 <= a_e0;
+                            z_e1 <= a_e0 + 1;
                             // Assigning the overall sign based on the difference between the mantissa
                             if (a_m0 > b_m0) z_s1 <= a_s0;
                             else if (b_m0 > a_m0) z_s1 <= b_s0;
@@ -120,26 +112,26 @@ module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
                     // Condition for overflow is that it sets the output to the larger input
                     if (abs > K) // Shifting by N-1-M would give 0
                         begin
-                            z_m2 <= (greater == greater_a) ? a_m1 : b_m1;
-
-                            // Input a is larger and is translated to the output
-                            // if (greater == greater_a) z_m0 <= a_m1;
-
-                            // Input b is larger and is translated to the output
-                            // else if (greater == greater_b) z_m0 <= b_m1;
-
-                            // Shouldn't happen as abs should be 0 for this to occur
-                            // else begin
-                            // 	if (a_m1 >= b_m1) z_m0 <= a_m1; // Equal exponents but a has the larger mantissa
-                            // 	else if (b_m1 > a_m1) z_m0 <= b_m1; // Equal exponents but b has the larger mantissa
-                            // end
+                            case (greater)
+                                greater_a: z_m2 <= a_m1;
+                                greater_b: z_m2 <= b_m1;
+                            endcase
                         end
-
                     else
                         begin
-							z_m2 <= z_m1z[K*2-1:K];
+                            z_m2 <= z_m1z[K*2-1:K];
+                            z_m2s <= z_m1s;
                         end
-                end
+
+                    if(z_m2s) begin
+                        z_e3 <= z_e2 + 1;
+                    end else begin
+                        z_e3 <= z_e2;
+                    end
+                    z_m3 <= z_m2;
+                    z_s3 <= z_s2;
+
+                end // end ~reset
             else
                 begin
                     a_m0 <= 0;
@@ -152,14 +144,17 @@ module fp_adder#(parameter N=16, M=5)(input_a, input_b, output_z, clk, reset);
 
                     a_m1 <= 0;
                     b_m1 <= 0;
-                    z_e2 <= 0;
-                    z_s2 <= 0;
-
                     z_s1 <= 0;
                     z_e1 <= 0;
-                    z_s2 <= 0;
+
                     z_e2 <= 0;
+                    z_s2 <= 0;
                     z_m2 <= 0;
+                    z_m2s <= 0;
+
+                    z_s3 <= 0;
+                    z_e3 <= 0;
+                    z_m3 <= 0;
 
                     greater <= equal_ab;
                     abs <= 0;

+ 17 - 0
src/fpu16/fpu16.sv

@@ -61,6 +61,23 @@ module fpu16_tb;
 		expected_add = 0;
 		expected_mult = 0;
 
+		// test some common values
+		input_a = 'h41ac3000;
+		input_b = 'h431f9000;
+		# 30;
+		expected_add = 'h43351600;
+		input_a = 'h42ea6000;
+		input_b = 'h41ea6000;
+		# 30;
+		expected_add = 'h43127c00;
+		input_a = 'h413cc000;
+		input_b = 'h411ba000;
+		# 30;
+		expected_add = 'h41ac3000;
+		input_a = 'bx;
+		input_b = 'bx;
+		# 10;
+
         for (int i=0; i < $size(test_mem)+PIPELINES_MUL; i++) begin
 			if(i >= PIPELINES_ADD) expected_add = test_mem[i-PIPELINES_ADD][2];
 			if(i >= PIPELINES_MUL) expected_mult = test_mem[i-PIPELINES_MUL][3];

+ 153 - 0
src/fpu32p/add32.sv

@@ -0,0 +1,153 @@
+// Adapted from
+// https://github.com/naveethmeeran/floating-point-coprocessor/blob/master/single_precision/single_precision.srcs/sources_1/new/single.v
+
+// adder_32: pipelined adder for two 32-bit operands laid out as
+// {sign[31], exponent[30:23], mantissa[22:0]}.
+//
+// Registers carry a numeric suffix naming the pipeline stage (0..4) that
+// writes them. `out` is driven from the stage-4 registers, so a result is
+// visible after the fifth rising clock edge following its operands.
+// There is no reset input.
+//
+// NOTE(review): no code path special-cases exponent 0x00 or 0xFF, so
+// zero/denormal/infinity/NaN operands are treated like normal numbers.
+// NOTE(review): for operands of opposite sign with unequal exponents the
+// stage-3 shift is always 0, i.e. the difference is passed on without
+// renormalisation -- verify this path before relying on it.
+// NOTE(review): when the operands cancel exactly (x + (-x)) the mantissa
+// becomes 0 but the exponent is only reduced by 24, not forced to 0.
+module adder_32(
+    input [31:0] a, b, input clk, output reg [31:0] out
+);
+    // Result exponent as it travels down the pipeline; adjusted in stage 4.
+    reg [7:0] outexponent0, outexponent1, outexponent2, outexponent3, outexponent4;
+    // mantissaa*: mantissa of the operand with the larger exponent (for equal
+    // exponents: the larger mantissa). mantissab*: the other operand's mantissa.
+    reg [22:0] mantissaa0, mantissab0, mantissaa1, mantissab1;
+    reg [22:0] mantissaanew2, mantissabshift2, outmantissaa3, outmantissa4;
+    reg signa0, signb0, equal0, equal1, equal2, outsign0, outsign1, outsign2, outsign3, outsign4;
+    // eop*: operand signs differ, i.e. the mantissas are subtracted.
+    reg eop1, eop2;
+    // Normalisation shift. Five bits wide so that it can hold the full result
+    // of leadingzerodetector() (0..23) on the cancellation path.
+    reg [4:0] shift3;
+    reg [7:0] d0, d1;   // difference between the two exponents
+    reg stickybit2;
+    reg [4:0] count1;   // trailing zeros of mantissab0
+    wire carrya2;
+    wire [22:0] mantissabnew2, outmantissaa2;
+
+    // Counts the consecutive zero bits at the LSB end of `mant` (23 if mant == 0).
+    function [4:0] trailingzerodetector;
+        input [22:0] mant;
+        reg [22:0] ma;
+        reg [4:0] co;
+        begin
+            ma = mant;
+            co = 5'b0;
+            // Once a 1 reaches bit 0 the remaining iterations change nothing.
+            repeat (23)
+                begin
+                    if (ma[0] == 1'b0)
+                        begin
+                            co = co+1;
+                            ma = ma >> 1;
+                        end
+                end
+            trailingzerodetector = co;
+        end
+    endfunction : trailingzerodetector
+
+    // Counts the consecutive zero bits at the MSB end of `mant` (23 if mant == 0).
+    function [4:0] leadingzerodetector;
+        input [22:0] mant;
+        reg [22:0] ma;
+        reg [4:0] co;
+        begin
+            ma = mant;
+            co = 5'b0;
+            // Once a 1 reaches bit 22 the remaining iterations change nothing.
+            repeat (23)
+                begin
+                    if (ma[22] == 1'b0)
+                        begin
+                            co = co+1;
+                            ma = ma << 1;
+                        end
+                end
+            leadingzerodetector = co;
+        end
+    endfunction : leadingzerodetector
+
+    // Stage 0: compare the operands and order them so that mantissaa0 belongs
+    // to the larger one; latch the exponent difference and the result sign.
+    always_ff @(posedge clk)
+        begin
+            signa0 <= a[31];
+            signb0 <= b[31];
+            if (b[30:23] > a[30:23])
+                begin
+                    equal0 <= 1'b0;
+                    d0 <= b[30:23] - a[30:23];
+                    mantissab0 <= a[22:0];
+                    mantissaa0 <= b[22:0];
+                    outexponent0 <= b[30:23];
+                    outsign0 <= b[31];
+                end
+            else if (a[30:23] > b[30:23])
+                begin
+                    equal0 <= 1'b0;
+                    d0 <= a[30:23] - b[30:23];
+                    mantissaa0 <= a[22:0];
+                    mantissab0 <= b[22:0];
+                    outexponent0 <= a[30:23];
+                    outsign0 <= a[31];
+                end
+            else
+                begin
+                    equal0 <= 1'b1;
+                    d0 <= 8'b0;
+                    // With equal exponents the mantissa decides which operand
+                    // is larger, and the result takes the sign of that operand.
+                    if (a[22:0] > b[22:0])
+                        begin
+                            mantissaa0 <= a[22:0];
+                            mantissab0 <= b[22:0];
+                            outsign0 <= a[31];
+                        end
+                    else
+                        begin
+                            mantissaa0 <= b[22:0];
+                            mantissab0 <= a[22:0];
+                            outsign0 <= b[31];
+                        end
+                    outexponent0 <= a[30:23];
+                end
+        end
+
+    // Stage 1: forward the stage-0 values, derive add/subtract from the signs
+    // and count the trailing zeros of the smaller operand's mantissa.
+    always_ff @(posedge clk)
+        begin
+            mantissab1 <= mantissab0;
+            mantissaa1 <= mantissaa0;
+            d1 <= d0;
+            outsign1 <= outsign0;
+            outexponent1 <= outexponent0;
+            equal1 <= equal0;
+            eop1 <= signa0 ^ signb0;
+            count1 <= trailingzerodetector(mantissab0);
+        end
+
+    // Stage 2: align the smaller operand by shifting it right by the exponent
+    // difference. stickybit2 is set when d1 exceeds the trailing-zero count.
+    always_ff @(posedge clk)
+        begin
+            stickybit2 <= (d1 > count1) ? 1 : 0;
+            // The 24-bit shift result is truncated to 23 bits on assignment.
+            mantissabshift2 <= {1'b1, mantissab1} >> d1;
+            mantissaanew2 <= mantissaa1;
+            outsign2 <= outsign1;
+            outexponent2 <= outexponent1;
+            equal2 <= equal1;
+            eop2 <= eop1;
+        end
+
+    // Subtraction is done as a two's-complement add: invert the aligned
+    // operand here and add eop2 as the +1 in the sum below.
+    assign mantissabnew2 = (eop2) ? ~(mantissabshift2 + stickybit2) : mantissabshift2 + stickybit2;
+    assign {carrya2, outmantissaa2} = mantissaanew2 + mantissabnew2 + eop2;
+
+    wire cond2;
+    reg ext3, cond3;
+
+    // cond2: equal exponents and opposite signs -- the leading bits cancel
+    // and the difference has to be shifted left to normalise it.
+    assign cond2 = eop2 && equal2;
+
+    // Stage 3: latch the raw sum and decide how far stage 4 has to shift.
+    always_ff @(posedge clk)
+        begin
+            outsign3 <= outsign2;
+            outexponent3 <= outexponent2;
+            outmantissaa3 <= outmantissaa2;
+            ext3 <= equal2 & (carrya2 & (~eop2));
+            cond3 <= cond2;
+            if (cond2)
+                shift3 <= leadingzerodetector(outmantissaa2);
+            else
+                // Concatenation keeps the 1-bit expression self-determined.
+                shift3 <= {4'b0, equal2 | (carrya2 & (~eop2))};
+        end
+
+    // Stage 4: normalise. On the cancellation path the mantissa moves left by
+    // shift3+1 and the exponent drops by the same amount; otherwise the
+    // mantissa moves right by shift3 (0 or 1) and the exponent rises by it.
+    always_ff @(posedge clk)
+        begin
+            outsign4 <= outsign3;
+            if (cond3)
+                begin
+                    outexponent4 <= outexponent3 - (shift3 + 1);
+                    outmantissa4 <= outmantissaa3 << (shift3+1);
+                end
+            else
+                begin
+                    outexponent4 <= outexponent3 + shift3;
+                    outmantissa4 <= {ext3, outmantissaa3} >> shift3;
+                end
+        end
+
+    assign out = {outsign4, outexponent4, outmantissa4};
+endmodule : adder_32

+ 20 - 8
src/fpu32p/fpu32p.sv

@@ -1,4 +1,5 @@
 `include "mult32.v"
+`include "add32.sv"
 
 
 module fpu32p_tb;
@@ -7,12 +8,18 @@ module fpu32p_tb;
     logic [31:0] input_a, input_b, result_add, result_mult;
     logic [31:0] expected_add, expected_mult;
 
-    fp_adder#(.N(32), .M(8)) adder1(
+    wire [7:0] w_exp_exp, w_res_exp;
+    wire w_exp_sign, w_res_sign;
+    wire [22:0] w_exp_man, w_res_man;
+
+    assign {w_exp_sign, w_exp_exp, w_exp_man} = expected_add;
+    assign {w_res_sign, w_res_exp, w_res_man} = result_add;
+
+    adder_32 adder1(
         .clk(clk),
-        .reset(reset),
-        .input_a(input_a),
-        .input_b(input_b),
-        .output_z(result_add)
+        .a(input_a),
+        .b(input_b),
+        .out(result_add)
     );
 
     mult_32 multiplier1(
@@ -26,7 +33,7 @@ module fpu32p_tb;
     );
 
     initial forever #5 clk = ~clk;
-    localparam PIPELINES_ADD = 2;
+    localparam PIPELINES_ADD = 4;
     localparam PIPELINES_MUL = 12;
 
     reg [31:0] test_mem [29:0][3:0];
@@ -40,12 +47,17 @@ module fpu32p_tb;
         clk = 0;
         reset = 1;
 
-        #15;
+        #5;
         reset = 0;
 
-        expected_add = 0;
+        input_a = 'hbec64dc6;
+        input_b = 'h3ecc3194;
+
+        expected_add = 'h3c3c79c0;
         expected_mult = 0;
 
+        #50;
+
         for (int i=0; i < $size(test_mem)+PIPELINES_MUL; i++) begin
             if(i >= PIPELINES_ADD) expected_add = test_mem[i-PIPELINES_ADD][2];
             if(i >= PIPELINES_MUL) expected_mult = test_mem[i-PIPELINES_MUL][3];

+ 4 - 7
src/neural/layer.sv

@@ -163,7 +163,7 @@ module neuron_network_tb;
     );
 
     /* ******************
-    Synchronious network
+    Pipelined network
     ********************/
     reg [31:0] layer1_s [7:0];
     reg [31:0] layer2_s [7:0];
@@ -175,8 +175,7 @@ module neuron_network_tb;
         .rst(rst),
         .x(x),
         .y(layer1_s),
-        .w(layer1_w),
-        .b(layer1_b)
+        .w(layer1_w), .b(layer1_b)
     );
 
     neuron_layer_p#(.C(3), .K(3)) layer_s2(
@@ -184,8 +183,7 @@ module neuron_network_tb;
         .rst(rst),
         .x(layer1_s),
         .y(layer2_s),
-        .w(layer2_w),
-        .b(layer2_b)
+        .w(layer2_w), .b(layer2_b)
     );
 
     neuron_layer_p#(.C(3), .K(1)) layer_s3(
@@ -193,8 +191,7 @@ module neuron_network_tb;
         .rst(rst),
         .x(layer2_s),
         .y(layer3_s),
-        .w(layer3_w),
-        .b(layer3_b)
+        .w(layer3_w), .b(layer3_b)
     );
 
     hard_sigmoid_p sigmoid_s0(

+ 5 - 6
src/neural/neural.sv

@@ -16,12 +16,11 @@ module neural_adder(clk, rst, x0, x1, y);
     //     .clk(clk),
     //     .reset(rst)
     // );
-    fp_adder#(.N(32), .M(8)) adder1(
+    adder_32 adder0(
         .clk(clk),
-        .reset(rst),
-        .input_a(x0),
-        .input_b(x1),
-        .output_z(y)
+        .a(x0),
+        .b(x1),
+        .out(y)
     );
 
 endmodule : neural_adder
@@ -38,7 +37,7 @@ module neural_mult(clk, rst, x0, x1, y);
     //     .clk(clk),
     //     .reset(rst)
     // );
-    mult_32 multiplier1(
+    mult_32 multiplier0(
         .clk(clk),
         .a(x0),
         .b(x1),