Sfoglia il codice sorgente

Structured cascade

Min 4 anni fa
parent
commit
6f50f75c81
1 ha cambiato i file con 157 aggiunte e 99 eliminazioni
  1. 157 99
      src/neural/comp.sv

+ 157 - 99
src/neural/comp.sv

@@ -2,89 +2,114 @@
 `include "../fpu32/fpu32.sv"
 
 /*
-          ____
-   x0 -->|ADD0|--> y0
-   x1 -->|    |
- bus0 <->|    |<-- ack0
- bus1 <->|____|--> stb0 
-          ____ 
-   x2 -->|ADD1|--> y1
-   x3 -->|    |
- bus2 <->|    |<-- ack1
- bus3 <->|____|--> stb1
+            _____
+     x0 -->|  A  |
+  abus0 <->|  D  |
+           |  D  |--> y0
+           |  E  |<-> abus_y
+     x1 -->|  R  |
+  abus1 <->|_____|
 
 */
 
-module adder4to2#(parameter N=32)(x, clk, rst, y, left, right);
+module abus_adder#(parameter N=32)(x, clk, rst, y, left0, left1, right);
     input logic clk;
     input logic rst;
-    input wire [N-1:0] x [3:0];
-    output logic [N-1:0] y [1:0];
-    abus_io left[3:0];
-    abus_io right[1:0];
+    input wire [N-1:0] x [1:0];
+    output logic [N-1:0] y;
+    abus_io left0;
+    abus_io left1;
+    abus_io right;
 
-    wire out_stb [1:0];
-    assign right.stb = out_stb[0] & out_stb[1];
+    wire left_ack, left_stb;
+    assign left0.ack = left_ack;
+    assign left1.ack = left_ack;
+    assign left_stb = left0.stb & left1.stb;
 
     adder add0 (
         .clk(clk),
         .rst(rst),
         .input_a(x[0]),
         .input_b(x[1]),
-        .input_stb(left0.stb),
-        .input_ack(left0.ack),
-        .output_z(y[0]),
+        .input_stb(left_stb),
+        .input_ack(left_ack),
+        .output_z(y),
         .output_z_ack(right.ack),
-        .output_z_stb(out_stb[0])
+        .output_z_stb(right.stb)
     );
+endmodule : abus_adder
+
+/*
+              _____
+     x[0] ==>|  A  |
+ x_stb[0] -->|  D  |
+ x_ack[0] <--|  D  |==> y
+             |  E  |--> y_stb
+     x[1] ==>|  R  |<-- y_ack
+ x_stb[1] -->|     |
+ x_ack[1] <--|_____|
+
+*/
+module cadder#(parameter N=32)(clk, rst, x, x_ack, x_stb, y, y_ack, y_stb);
+    input logic clk;
+    input logic rst;
+    input wire [N-1:0] x [1:0];
+    output logic [N-1:0] y;
+    output x_ack[1:0];
+    input x_stb[1:0];
+    input y_ack;
+    output y_stb;
+
+    wire left_ack, left_stb;
+    assign x_ack[0] = left_ack;
+    assign x_ack[1] = left_ack;
+    assign left_stb = x_stb[0] & x_stb[1];
 
-    adder add1 (
+    adder add0 (
         .clk(clk),
         .rst(rst),
-        .input_a(x[2]),
-        .input_b(x[3]),
-        .input_stb(left1.stb),
-        .input_ack(left1.ack),
-        .output_z(y[1]),
-        .output_z_ack(right.ack),
-        .output_z_stb(out_stb[1])
+        .input_a(x[0]),
+        .input_b(x[1]),
+        .input_stb(left_stb),
+        .input_ack(left_ack),
+        .output_z(y),
+        .output_z_ack(y_ack),
+        .output_z_stb(y_stb)
     );
+endmodule : cadder
 
-endmodule : adder4to2
-
-
-module adder4to2_tb();
-    logic clk, rst;
-    
-    logic [31:0] x [3:0];
-    logic [31:0] y [1:0];
-    abus_io inputBus();
-    abus_io outputBus();
-    
-    adder4to2 adder_casc(.clk(clk), .rst(rst), .x(x), .y(y), .left(inputBus.right), .right(outputBus.left));    
-    initial forever #5 clk = ~clk;
-    initial begin
-        $display("Testing adder4to2");
-        clk = 0;
-        rst = 1;
-        inputBus.stb = 0;
-        outputBus.ack = 0;
-        #20
-        rst = 0;
-        x = {'h41388000, 'h407c0000, 'h42480000, 'h42460000};
-        inputBus.stb = 1;
-        wait(inputBus.ack == 1);
-        #15 inputBus.stb = 0;
-        
-        wait(outputBus.stb == 1);
-        outputBus.ack = 1;
-        assert(y[0] == 'h42c70000);
-        assert(y[1] == 'h41778000);
-        wait(outputBus.stb == 0);
-        outputBus.ack = 0;
-    end
-    
-endmodule : adder4to2_tb
+// module adder4to2_tb();
+//     logic clk, rst;
+//
+//     logic [31:0] x [3:0];
+//     logic [31:0] y [1:0];
+//     abus_io inputBus();
+//     abus_io outputBus();
+//
+//     adder4to2 adder_casc(.clk(clk), .rst(rst), .x(x), .y(y), .left(inputBus.right), .right(outputBus.left));
+//     initial forever #5 clk = ~clk;
+//     initial begin
+//         $display("Testing adder4to2");
+//         clk = 0;
+//         rst = 1;
+//         inputBus.stb = 0;
+//         outputBus.ack = 0;
+//         #20
+//         rst = 0;
+//         x = {'h41388000, 'h407c0000, 'h42480000, 'h42460000};
+//         inputBus.stb = 1;
+//         wait(inputBus.ack == 1);
+//         #15 inputBus.stb = 0;
+//
+//         wait(outputBus.stb == 1);
+//         outputBus.ack = 1;
+//         assert(y[0] == 'h42c70000);
+//         assert(y[1] == 'h41778000);
+//         wait(outputBus.stb == 0);
+//         outputBus.ack = 0;
+//     end
+//
+// endmodule : adder4to2_tb
 
 /*
   K layers of cascade adder
@@ -108,11 +133,6 @@ IN | K3 |  K2  |  K1  | OUT
 [inputs]
 x size: 2**K
 left io size: 2**K
-
-[internal]
-layer connecting wires: 2**K - 2
-number of io buses: 2**(K-1) - 1
-adder4to2 modules: 2**(K-2)
 */
 
 module adder_casc#(parameter K,N=32)(clk, rst, x, y, left, right);
@@ -123,52 +143,89 @@ module adder_casc#(parameter K,N=32)(clk, rst, x, y, left, right);
     
     abus_io right;
     abus_io left[2**K-1:0];
-    
+
     wire [N-1:0] layer_w [2**K-3:0];
-    abus_io bus_w[2**(K-1)-2:0]();
-    
+    wire ack_w [2**K-3:0];
+    wire stb_w [2**K-3:0];
+
     genvar i,j;
     generate
-        for(i=0; i<K; i++) begin : generate_layers    
+        for(i=0; i<K; i++) begin : generate_layers
             // First layers
             if(i == 0) begin
-                for(j=0; j<2**(K-2); j++) begin : generate_casc0
-                    adder4to2 a(
+                for(j=0; j<2**(K-1); j++) begin : generate_casc0
+                    cadder a(
                       .clk(clk),
                       .rst(rst),
-                      .x(x[j*4+:4]),
+                      .x(x[j*2+:2]),
                       .y(layer_w[j]),
-                      .left0(left[j*2].right),
-                      .left1(left[j*2+1].right),
-                      .right(bus_w[j].left)
+                      .x_ack({left[j*2].ack, left[j*2+1].ack}),
+                      .x_stb({left[j*2].stb, left[j*2+1].stb}),
+                      .y_ack(ack_w[j]),
+                      .y_stb(stb_w[j])
+                      // .left0(left[j*2].right),
+                      // .left1(left[j*2+1].right),
+                      // .right(bus_w[j].left)
                     );
                 end
             end
             // Last layer
             else if((K-i) <= 1) begin
-                adder c(
+                localparam s0 = 2**K-4;
+                localparam s1 = 2**K-3;
+                cadder c(
                     .clk(clk),
                     .rst(rst),
-                    .input_a(layer_w[i-1][0]),
-                    .input_b(layer_w[i-1][1]),
-                    .input_stb(bus_w[i-1].stb),
-                    .input_ack(bus_w[i-1].ack),
-                    .output_z(y),
-                    .output_z_ack(right.ack),
-                    .output_z_stb(right.stb)
-                    );
+                    .x(layer_w[s0+:2]),
+                    .y(y),
+                    .x_ack({ack_w[s0], ack_w[s1]}),
+                    .x_stb({stb_w[s0], stb_w[s1]}),
+                    .y_ack(right.ack),
+                    .y_stb(right.stb)
+                    // .left0(bus_w[s0].right),
+                    // .left1(bus_w[s1].right),
+                    // .right(right)
+                );
             end
             // Middle layers
             else begin
-                for(j=0; j<2**(K-i-2); j++) begin : generate_casc1
-                    adder4to2 b(
-                      .clk(clk),
-                      .rst(rst),
-                      .x(layer_w[i-1][j*4+:4]),
-                      .y(layer_w[i][j*2+:2]),
-                      .left(bus_w[i-1][j].right),
-                      .right(bus_w[i][j].left)
+                for(j=0; j<2**(K-i-1); j++) begin : generate_casc1
+                    localparam s = $floor((2.0**(K-1.0) * (2.0**(i-1)-1.0)/2.0**(i-1))+j);
+                    localparam ix = s*2;
+                    localparam iy = s+2**(K-1);
+
+                    cadder b(
+                        .clk(clk),
+                        .rst(rst),
+                        .x(layer_w[ix+:2]),
+                        .y(layer_w[iy]),
+                        .x_ack(ack_w[ix+:2]),
+                        .x_stb(stb_w[ix+:2]),
+                        .y_ack(ack_w[iy]),
+                        .y_stb(stb_w[iy])
                     );
+
+                    // // localparam m = i - 1;
+                    // // localparam s0 = 2.0**(K-1.0) * (2.0**m-1.0)/2.0**m;
+                    // // localparam s = s0 + j;
+                    // localparam s = $floor((2.0**(K-1.0) * (2.0**(i-1)-1.0)/2.0**(i-1))+j);
+                    // localparam ix = s*2;
+                    // localparam ix1 = s*2+1;
+                    // localparam iy = s+2**(K-1);
+                    // abus_adder b(
+                    //   .clk(clk),
+                    //   .rst(rst),
+                    //   // .x(bus_w[index_x+:2]),
+                    // //   // .y(bus_w[(2**(K-1)*((2**(i-2)-1)/2**(i-2)) + j)+2**(K-1)].right),
+                    // //   // .left0(bus_w[(2**(K-1)*((2**(i-2)-1)/2**(i-2)) + j)*2].right),
+                    // //   // .left1(bus_w[(2**(K-1)*((2**(i-2)-1)/2**(i-2)) + j)*2+1].right),
+                    // //   // .right(bus_w[(2**(K-1)*((2**(i-2)-1)/2**(i-2)) + j)+2**(K-1)].left)
+                    //   .x(layer_w[ix+:2]),
+                    //   .y(layer_w[iy]),
+                    //   .left0(bus_v[ix].right),
+                    //   .left1(bus_v[ix1].right),
+                    //   .right(bus_v[iy].left)
+                    // );
                 end
             end
         end
@@ -179,8 +236,8 @@ endmodule : adder_casc
 module adder_casc_tb();
     logic clk, rst;
     
-    localparam K=3;
-    logic [31:0] x [7:0];
+    localparam K=4;
+    logic [31:0] x [2**K-1:0];
     logic [31:0] y;
     abus_io input_ios[2**K-1:0]();
     abus_io output_io();
@@ -208,7 +265,8 @@ module adder_casc_tb();
         output_io.ack = 0;
         #20
         rst = 0;
-        x = {'h43800000, 'h43000000, 'h42800000, 'h42000000, 'h41800000, 'h41000000, 'h40800000, 'h40000000};
+        // Initialise with floating point 2**i
+        foreach(x[i]) x[i] = ('h400 + (i*8)) << 20;
         fork
             foreach(input_vios[i]) begin
                 fork