| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- `include "../fpu32/compare.sv"
- // Control states of the hard_sigmoid state machine (3-bit encoding).
- // The explicit values equal the default sequential enum numbering.
- typedef enum logic [2:0] {
-     hs_input     = 3'd0, // accept an operand on the input handshake
-     hs_compare_0 = 3'd1, // declared but never entered by the hard_sigmoid FSM in this file
-     hs_compare_1 = 3'd2, // classify the operand against the -2.5 / +2.5 limits
-     hs_compute_0 = 3'd3, // hand the operand to the adder
-     hs_compute_1 = 3'd4, // collect the product from the multiplier
-     hs_output    = 3'd5  // present the result on the output handshake
- } hs_stage;
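- /*
- Function (hard sigmoid):
- y = (x + 2.5) * 0.2, clamped to the range [0, 1]
- minimum = 0 (output when x is not greater than -2.5)
- maximum = 1 (output when x is not less than +2.5)
- */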
- // hard_sigmoid: stb/ack-handshaked hard-sigmoid unit.
- //
- // Operation per accepted operand:
- //   * x not greater than -2.5      -> y = 0
- //   * x between -2.5 and +2.5      -> y = (x + 2.5) * 0.2, computed by add0 then mult0
- //   * otherwise                    -> y = 32'h3f800000 (1.0)
- //
- // Ports:
- //   clk, rst : clock and synchronous, active-high reset (rst is sampled on posedge clk).
- //   x        : operand, captured on the clock edge where left.stb and left.ack are both high.
- //   y        : result, written before right.stb is raised and held until the next result.
- //   left     : input-side handshake  (left.stb is read, left.ack is driven here).
- //   right    : output-side handshake (right.stb is driven here, right.ack is read).
- //
- // NOTE(review): the constants below are 32-bit IEEE-754 single-precision bit
- // patterns, so the module is presumably only meaningful with N = 32 - confirm
- // against the adder/multiplier/fpu32 comparator definitions.
- module hard_sigmoid #(parameter N=32)(clk, rst, x, y, left, right);
-     input clk, rst;
-     input [N-1:0] x;
-     output logic [N-1:0] y;
-     abus_io left, right;
-     // Single-precision bit patterns used by the data path and the range checks.
-     localparam logic [31:0] FP_NEG_2_5 = 32'hc0200000; // -2.5
-     localparam logic [31:0] FP_POS_2_5 = 32'h40200000; // +2.5
-     localparam logic [31:0] FP_0_2     = 32'h3e4ccccd; //  0.2
-     localparam logic [31:0] FP_ONE     = 32'h3f800000; //  1.0
-     logic [N-1:0] value, comp_result; // latched operand / multiplier output
-     hs_stage stage;
-     logic gt_neg; // value is greater than -2.5
-     logic lt_pos; // value is less than +2.5
-     // Handshake between the adder output and the multiplier input.
-     wire join_ack, join_stb;
-     wire [N-1:0] join_value;
-     // Handshake driven by the FSM: in_* feeds the adder, out_* drains the multiplier.
-     logic in_stb, out_ack;
-     wire in_ack, out_stb;
-     // Multiply by 0.2
-     multiplier mult0(
-         .clk(clk),
-         .rst(rst),
-         .input_a(FP_0_2),
-         .input_b(join_value),
-         .input_stb(join_stb),
-         .input_ack(join_ack),
-         .output_z(comp_result),
-         .output_z_ack(out_ack),
-         .output_z_stb(out_stb)
-     );
-     // Add +2.5
-     adder add0(
-         .clk(clk),
-         .rst(rst),
-         .input_a(FP_POS_2_5),
-         .input_b(value),
-         .input_stb(in_stb),
-         .input_ack(in_ack),
-         .output_z(join_value),
-         .output_z_ack(join_ack),
-         .output_z_stb(join_stb)
-     );
-     fpu32_gt gt0(value, FP_NEG_2_5, gt_neg); // more than -2.5
-     fpu32_lt lt0(value, FP_POS_2_5, lt_pos); // less than +2.5
-     always_ff @(posedge clk) begin
-         case (stage)
-             // Offer ack, then latch x on the edge where both ack and stb are high.
-             hs_input: begin
-                 left.ack <= 1;
-                 if (left.ack && left.stb) begin
-                     value    <= x;
-                     left.ack <= 0;
-                     stage    <= hs_compare_1;
-                 end
-             end
-             // Pick a saturated result directly, or start the add/multiply chain.
-             hs_compare_1: begin
-                 if (~gt_neg) begin
-                     // not greater than -2.5: output 0
-                     y     <= 0;
-                     stage <= hs_output;
-                 end else if (gt_neg & lt_pos) begin
-                     // strictly between -2.5 and +2.5: compute (value + 2.5) * 0.2
-                     in_stb <= 1;
-                     stage  <= hs_compute_0;
-                 end else begin
-                     // not less than +2.5: output 1
-                     y     <= FP_ONE;
-                     stage <= hs_output;
-                 end
-             end
-             // Hold the adder strobe until the adder acknowledges the operand.
-             hs_compute_0: begin
-                 if (in_ack) begin
-                     in_stb <= 0;
-                     stage  <= hs_compute_1;
-                 end
-             end
-             // Offer ack to the multiplier and capture its result when it strobes.
-             hs_compute_1: begin
-                 out_ack <= 1;
-                 if (out_ack && out_stb) begin
-                     y       <= comp_result;
-                     out_ack <= 0;
-                     stage   <= hs_output;
-                 end
-             end
-             // Present y downstream until the consumer acknowledges it.
-             hs_output: begin
-                 right.stb <= 1;
-                 if (right.stb && right.ack) begin
-                     right.stb <= 0;
-                     stage     <= hs_input;
-                 end
-             end
-             // Unused or illegal encodings: return to idle instead of locking up.
-             default: begin
-                 stage <= hs_input;
-             end
-         endcase
-         // Synchronous reset. Placed last so it overrides the assignments above.
-         if (rst == 1) begin
-             stage     <= hs_input;
-             left.ack  <= 0;
-             right.stb <= 0;
-             // Internal handshake registers must also start deasserted; otherwise
-             // the adder strobe and multiplier ack are undefined after reset.
-             in_stb    <= 0;
-             out_ack   <= 0;
-             y         <= 0;
-         end
-     end
- endmodule : hard_sigmoid
- // Testbench for hard_sigmoid.
- //
- // Loads stimulus words from scripts/sigmoid_test.hex, pushes each one through
- // the DUT using the left/right stb-ack handshakes, and writes one line per
- // vector to scripts/sigmoid_result.hex in the form "<x> <y> <delta>".
- // delta is the simulation time between observing left.ack high and observing
- // right.stb high for that vector.
- module hard_sigmoid_tb;
-     reg rst, clk;
-     reg [31:0] x;
-     wire [31:0] y;
-     abus_io left();
-     abus_io right();
-     hard_sigmoid sigmoid0(clk, rst, x, y, left, right);
-     // Stimulus storage: 5001 32-bit words filled by $readmemh.
-     reg [31:0] test_mem [5000:0];
-     initial $readmemh("scripts/sigmoid_test.hex", test_mem);
-     // Free-running clock with a period of 10 time units.
-     initial forever #5 clk = ~clk;
-     initial begin
-         int fd, start, delta;
-         fd = $fopen("scripts/sigmoid_result.hex", "w");
-         // Without a result file the run produces nothing useful, so stop here
-         // rather than simulating every vector against an invalid descriptor.
-         if (!fd) $fatal(1, "Failed to open file! %0d", fd);
-         clk = 0;
-         rst = 1;
-         left.stb = 0;
-         right.ack = 0;
-         // Hold reset across the first rising clock edge (t = 5).
-         # 10;
-         rst = 0;
-         for (int i=0; i < $size(test_mem); i++) begin
-             // Drive the operand and request a transfer.
-             x = test_mem[i];
-             left.stb = 1;
-             wait(left.ack == 1);
-             start = $time;
-             // Keep stb high past the next rising edge so the DUT samples it.
-             #15;
-             left.stb = 0;
-             // Wait for the result, then acknowledge it.
-             wait(right.stb == 1);
-             right.ack = 1;
-             delta = $time - start;
-             // Keep ack high past the next rising edge so the DUT samples it.
-             #15;
-             right.ack = 0;
-             $fdisplay(fd, "%H %H %d", x, y, delta);
-         end
-         $fclose(fd);
-         $finish();
-     end
- endmodule : hard_sigmoid_tb
|