|
@@ -0,0 +1,171 @@
|
|
|
|
|
+`include "../fpu32/compare.sv"
|
|
|
|
|
+
|
|
|
|
|
+typedef enum logic [2:0] {
|
|
|
|
|
+ hs_input,
|
|
|
|
|
+ hs_compare_0,
|
|
|
|
|
+ hs_compare_1,
|
|
|
|
|
+ hs_compute_0,
|
|
|
|
|
+ hs_compute_1,
|
|
|
|
|
+ hs_output
|
|
|
|
|
+} hs_stage;
|
|
|
|
|
+
|
|
|
|
|
+/*
|
|
|
|
|
+ Function:
|
|
|
|
|
+ y = (x + 2.5) * 0.2
|
|
|
|
|
+ minimum = 0
|
|
|
|
|
+ maximum = 1
|
|
|
|
|
+*/
|
|
|
|
|
+module hard_sigmoid #(parameter N=32)(clk, rst, x, y, left, right);
|
|
|
|
|
+ input clk, rst;
|
|
|
|
|
+ input [N-1:0] x;
|
|
|
|
|
+ output logic [N-1:0] y;
|
|
|
|
|
+ abus_io left, right;
|
|
|
|
|
+
|
|
|
|
|
+ logic [N-1:0] value, comp_result;
|
|
|
|
|
+ hs_stage stage;
|
|
|
|
|
+
|
|
|
|
|
+ logic gt_neg;
|
|
|
|
|
+ logic lt_pos;
|
|
|
|
|
+ logic compute; // Flag to tell if mult and add compution is needed
|
|
|
|
|
+
|
|
|
|
|
+ wire join_ack, join_stb;
|
|
|
|
|
+ wire [N-1:0] join_value;
|
|
|
|
|
+ logic in_stb, out_ack;
|
|
|
|
|
+ wire in_ack, out_stb;
|
|
|
|
|
+
|
|
|
|
|
+ // Multiply by 0.2
|
|
|
|
|
+ multiplier mult0(
|
|
|
|
|
+ .clk(clk),
|
|
|
|
|
+ .rst(rst),
|
|
|
|
|
+ .input_a('h3e4ccccd),
|
|
|
|
|
+ .input_b(join_value),
|
|
|
|
|
+ .input_stb(join_stb),
|
|
|
|
|
+ .input_ack(join_ack),
|
|
|
|
|
+ .output_z(comp_result),
|
|
|
|
|
+ .output_z_ack(out_ack),
|
|
|
|
|
+ .output_z_stb(out_stb)
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ // Add +2.5
|
|
|
|
|
+ adder add0(
|
|
|
|
|
+ .clk(clk),
|
|
|
|
|
+ .rst(rst),
|
|
|
|
|
+ .input_a('h40200000),
|
|
|
|
|
+ .input_b(value),
|
|
|
|
|
+ .input_stb(in_stb),
|
|
|
|
|
+ .input_ack(in_ack),
|
|
|
|
|
+ .output_z(join_value),
|
|
|
|
|
+ .output_z_ack(join_ack),
|
|
|
|
|
+ .output_z_stb(join_stb)
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ fpu32_gt gt0(value, 'hc0200000, gt_neg); // more then -2.5
|
|
|
|
|
+ fpu32_lt lt0(value, 'h40200000, lt_pos); // less then +2.5
|
|
|
|
|
+
|
|
|
|
|
+ always_ff @(posedge clk) begin
|
|
|
|
|
+ case (stage)
|
|
|
|
|
+ hs_input: begin
|
|
|
|
|
+ left.ack <= 1;
|
|
|
|
|
+ if (left.ack && left.stb) begin
|
|
|
|
|
+ value <= x;
|
|
|
|
|
+ left.ack <= 0;
|
|
|
|
|
+ stage <= hs_compare_1;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+ hs_compare_1: begin
|
|
|
|
|
+ // if less than -2.5 output 0
|
|
|
|
|
+ if(~gt_neg) begin
|
|
|
|
|
+ y <= 0;
|
|
|
|
|
+ stage <= hs_output;
|
|
|
|
|
+ end else
|
|
|
|
|
+ // if in between -2.5 and 2.5
|
|
|
|
|
+ if(gt_neg & lt_pos) begin
|
|
|
|
|
+ in_stb <= 1;
|
|
|
|
|
+ stage <= hs_compute_0;
|
|
|
|
|
+ end else
|
|
|
|
|
+ // if more than 2.5 ouput 1
|
|
|
|
|
+ begin
|
|
|
|
|
+ y <= 'h3f800000;
|
|
|
|
|
+ stage <= hs_output;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+ hs_compute_0: begin
|
|
|
|
|
+ if (in_ack) begin
|
|
|
|
|
+ in_stb <= 0;
|
|
|
|
|
+ stage <= hs_compute_1;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+ hs_compute_1: begin
|
|
|
|
|
+ out_ack <= 1;
|
|
|
|
|
+ if (out_ack && out_stb) begin
|
|
|
|
|
+ y <= comp_result;
|
|
|
|
|
+ out_ack <= 0;
|
|
|
|
|
+ stage <= hs_output;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+ hs_output: begin
|
|
|
|
|
+ right.stb <= 1;
|
|
|
|
|
+ if (right.stb && right.ack) begin
|
|
|
|
|
+ right.stb <= 0;
|
|
|
|
|
+ stage <= hs_input;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+ endcase
|
|
|
|
|
+
|
|
|
|
|
+ if (rst == 1) begin
|
|
|
|
|
+ stage <= hs_input;
|
|
|
|
|
+ left.ack <= 0;
|
|
|
|
|
+ right.stb <= 0;
|
|
|
|
|
+ y <= 0;
|
|
|
|
|
+ end
|
|
|
|
|
+ end
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+endmodule : hard_sigmoid
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+module hard_sigmoid_tb;
|
|
|
|
|
+ reg rst, clk;
|
|
|
|
|
+ reg [31:0] x;
|
|
|
|
|
+ wire [31:0] y;
|
|
|
|
|
+ abus_io left();
|
|
|
|
|
+ abus_io right();
|
|
|
|
|
+
|
|
|
|
|
+ hard_sigmoid sigmoid0(clk, rst, x, y, left, right);
|
|
|
|
|
+
|
|
|
|
|
+ reg [31:0] test_mem [5000:0];
|
|
|
|
|
+ initial $readmemh("scripts/sigmoid_test.hex", test_mem);
|
|
|
|
|
+
|
|
|
|
|
+ initial forever #5 clk = ~clk;
|
|
|
|
|
+ initial begin
|
|
|
|
|
+ int fd, start, delta;
|
|
|
|
|
+ fd = $fopen("scripts/sigmoid_result.hex", "w");
|
|
|
|
|
+ if(!fd) $display("Failed to open file! %0d", fd);
|
|
|
|
|
+
|
|
|
|
|
+ clk = 0;
|
|
|
|
|
+ rst = 1;
|
|
|
|
|
+ left.stb = 0;
|
|
|
|
|
+ right.ack = 0;
|
|
|
|
|
+ # 10;
|
|
|
|
|
+ rst = 0;
|
|
|
|
|
+ for (int i=0; i < $size(test_mem); i++) begin
|
|
|
|
|
+ x = test_mem[i];
|
|
|
|
|
+ left.stb = 1;
|
|
|
|
|
+ wait(left.ack == 1);
|
|
|
|
|
+ start = $time;
|
|
|
|
|
+ #15;
|
|
|
|
|
+ left.stb = 0;
|
|
|
|
|
+ wait(right.stb == 1);
|
|
|
|
|
+ right.ack = 1;
|
|
|
|
|
+ delta = $time - start;
|
|
|
|
|
+ #15;
|
|
|
|
|
+ right.ack = 0;
|
|
|
|
|
+ $fdisplay(fd, "%H %H %d", x, y, delta);
|
|
|
|
|
+ end
|
|
|
|
|
+ $fclose(fd);
|
|
|
|
|
+ $finish();
|
|
|
|
|
+ end
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+endmodule : hard_sigmoid_tb
|
|
|
|
|
+
|