ソースを参照

Implemented hard sigmoid

Min 4 年 前
コミット
250682bf91
5 ファイル変更269 行追加0 行削除
  1. 43 0
      scripts/sigmoid_test.py
  2. 23 0
      src/neural/layer.sv
  3. 4 0
      src/neural/neural.sv
  4. 171 0
      src/neural/sigmoid.sv
  5. 28 0
      wave_hard_sigmoid_tb.do

+ 43 - 0
scripts/sigmoid_test.py

@@ -0,0 +1,43 @@
+import numpy as np
+from matplotlib import pyplot as plt
+from fpu_test_gen import reverse_endian, dtype_size
+
+
+def generate(fname, samples, dtype=np.float32):
+    dsize = dtype_size(dtype)
+    numbers = np.linspace(-4, 4, samples, dtype=dtype)
+    data = numbers.tobytes()
+    with open(fname, 'w') as f:
+        for i in range(samples):
+            f.write(f"{reverse_endian(data[i*dsize:i*dsize+dsize]).hex()}  // {numbers[i]:0.6f}\n")
+
+
+def view_result(fname, dtype=np.float32):
+    x_bytes = b''
+    y_bytes = b''
+    timing = []
+    with open(fname, 'r') as f:
+        for line in f.readlines():
+            parts = line.split()
+            x_bytes += reverse_endian(bytes.fromhex(parts[0]))
+            y_bytes += reverse_endian(bytes.fromhex(parts[1]))
+            timing.append(int(parts[2]))
+    x = np.frombuffer(x_bytes, dtype=dtype)
+    y = np.frombuffer(y_bytes, dtype=dtype)
+    t = np.array(timing) / 10
+
+    fig, ax = plt.subplots()
+    plt.title('Digital circuit sigmoid function test')
+    ax2 = ax.twinx()
+    ax.plot(x, y, '.', markersize=0.95)
+    ax2.plot(x, t, '.', markersize=0.95, color='m')
+    ax.set_xlabel('Function input')
+    ax.set_ylabel('Function output')
+    ax2.set_ylabel('Timing in cycles', color='m')
+    plt.grid()
+    plt.show()
+
+
+if __name__ == '__main__':
+    # generate('sigmoid_test.hex', 5001)
+    view_result('sigmoid_result.hex')

+ 23 - 0
src/neural/layer.sv

@@ -0,0 +1,23 @@
+// synopsys translate_off
+`timescale 1 ps / 1 ps
+// synopsys translate_on
+
+/*
+
+    =>
+    =>
+    =>
+    =>
+
+*/
+module neuron_layer#(parameter K, N=32)(clk, rst, x, y, w, b, left, right);
+    localparam M = 2**K;
+    input wire clk, rst;
+    input wire [N-1:0] x [M-1:0];
+    input wire [N-1:0] w [M-1:0];
+    input wire [N-1:0] b;
+    output logic [N-1:0] y;
+    abus left, right;
+
+
+endmodule : neuron_layer

+ 4 - 0
src/neural/neural.sv

@@ -0,0 +1,4 @@
+`include "comp.sv"
+`include "sigmoid.sv"
+`include "neuron.sv"
+`include "layer.sv"

+ 171 - 0
src/neural/sigmoid.sv

@@ -0,0 +1,171 @@
+`include "../fpu32/compare.sv"
+
+typedef enum logic [2:0] {
+    hs_input,
+    hs_compare_0,
+    hs_compare_1,
+    hs_compute_0,
+    hs_compute_1,
+    hs_output
+} hs_stage;
+
+/*
+    Function:
+    y = (x + 2.5) * 0.2
+    minimum = 0
+    maximum = 1
+*/
+module hard_sigmoid #(parameter N=32)(clk, rst, x, y, left, right);
+    input clk, rst;
+    input [N-1:0] x;
+    output logic [N-1:0] y;
+    abus_io left, right;
+
+    logic [N-1:0] value, comp_result;
+    hs_stage stage;
+
+    logic gt_neg;
+    logic lt_pos;
+    logic compute;  // Flag to tell if mult and add compution is needed
+
+    wire join_ack, join_stb;
+    wire [N-1:0] join_value;
+    logic in_stb, out_ack;
+    wire in_ack, out_stb;
+
+    // Multiply by 0.2
+    multiplier mult0(
+        .clk(clk),
+        .rst(rst),
+        .input_a('h3e4ccccd),
+        .input_b(join_value),
+        .input_stb(join_stb),
+        .input_ack(join_ack),
+        .output_z(comp_result),
+        .output_z_ack(out_ack),
+        .output_z_stb(out_stb)
+    );
+
+    // Add +2.5
+    adder add0(
+        .clk(clk),
+        .rst(rst),
+        .input_a('h40200000),
+        .input_b(value),
+        .input_stb(in_stb),
+        .input_ack(in_ack),
+        .output_z(join_value),
+        .output_z_ack(join_ack),
+        .output_z_stb(join_stb)
+    );
+
+    fpu32_gt gt0(value, 'hc0200000, gt_neg); // more then -2.5
+    fpu32_lt lt0(value, 'h40200000, lt_pos); // less then +2.5
+
+    always_ff @(posedge clk) begin
+        case (stage)
+            hs_input: begin
+                left.ack <= 1;
+                if (left.ack && left.stb) begin
+                    value <= x;
+                    left.ack <= 0;
+                    stage <= hs_compare_1;
+                end
+            end
+            hs_compare_1: begin
+                // if less than -2.5 output 0
+                if(~gt_neg) begin
+                    y <= 0;
+                    stage <= hs_output;
+                end else
+                // if in between -2.5 and 2.5
+                if(gt_neg & lt_pos) begin
+                    in_stb <= 1;
+                    stage <= hs_compute_0;
+                end else
+                // if more than 2.5 ouput 1
+                begin
+                    y <= 'h3f800000;
+                    stage <= hs_output;
+                end
+            end
+            hs_compute_0: begin
+                if (in_ack) begin
+                    in_stb <= 0;
+                    stage <= hs_compute_1;
+                end
+            end
+            hs_compute_1: begin
+                out_ack <= 1;
+                if (out_ack && out_stb) begin
+                    y <= comp_result;
+                    out_ack <= 0;
+                    stage <= hs_output;
+                end
+            end
+            hs_output: begin
+                right.stb <= 1;
+                if (right.stb && right.ack) begin
+                    right.stb <= 0;
+                    stage <= hs_input;
+                end
+            end
+        endcase
+
+        if (rst == 1) begin
+            stage <= hs_input;
+            left.ack <= 0;
+            right.stb <= 0;
+            y <= 0;
+        end
+    end
+
+
+endmodule : hard_sigmoid
+
+
+module hard_sigmoid_tb;
+    reg rst, clk;
+    reg [31:0] x;
+    wire [31:0] y;
+    abus_io left();
+    abus_io right();
+
+    hard_sigmoid sigmoid0(clk, rst, x, y, left, right);
+
+    reg [31:0] test_mem [5000:0];
+    initial $readmemh("scripts/sigmoid_test.hex", test_mem);
+
+    initial forever #5 clk = ~clk;
+    initial begin
+        int fd, start, delta;
+        fd = $fopen("scripts/sigmoid_result.hex", "w");
+        if(!fd) $display("Failed to open file! %0d", fd);
+
+        clk = 0;
+        rst = 1;
+        left.stb = 0;
+        right.ack = 0;
+        # 10;
+        rst = 0;
+        for (int i=0; i < $size(test_mem); i++) begin
+            x = test_mem[i];
+            left.stb = 1;
+            wait(left.ack == 1);
+            start = $time;
+            #15;
+            left.stb = 0;
+            wait(right.stb == 1);
+            right.ack = 1;
+            delta = $time - start;
+            #15;
+            right.ack = 0;
+            $fdisplay(fd, "%H %H %d", x, y, delta);
+        end
+        $fclose(fd);
+        $finish();
+    end
+
+
+endmodule : hard_sigmoid_tb
+

+ 28 - 0
wave_hard_sigmoid_tb.do

@@ -0,0 +1,28 @@
+onerror {resume}
+quietly WaveActivateNextPane {} 0
+add wave -noupdate /hard_sigmoid_tb/rst
+add wave -noupdate /hard_sigmoid_tb/clk
+add wave -noupdate -radix float32 /hard_sigmoid_tb/x
+add wave -noupdate -radix float32 /hard_sigmoid_tb/y
+add wave -noupdate /hard_sigmoid_tb/left/stb
+add wave -noupdate /hard_sigmoid_tb/left/ack
+add wave -noupdate /hard_sigmoid_tb/right/stb
+add wave -noupdate /hard_sigmoid_tb/right/ack
+TreeUpdate [SetDefaultTree]
+WaveRestoreCursors {{Cursor 1} {46937 ps} 0}
+quietly wave cursor active 1
+configure wave -namecolwidth 294
+configure wave -valuecolwidth 88
+configure wave -justifyvalue left
+configure wave -signalnamewidth 0
+configure wave -snapdistance 10
+configure wave -datasetprefix 0
+configure wave -rowmargin 4
+configure wave -childrowmargin 2
+configure wave -gridoffset 0
+configure wave -gridperiod 1
+configure wave -griddelta 40
+configure wave -timeline 0
+configure wave -timelineunits ns
+update
+WaveRestoreZoom {1087134 ps} {1087351 ps}