// FPA_module_test.sv (7.6 KB)

  1. module floating_add #(parameter N=16, M=4)(input_1, input_2, sum, diff);
  2. input logic [N-1:0] input_1, input_2;
  3. output logic [N-1:0] sum;
  4. output logic [M:0] diff;
  5. logic flag_a;
  6. logic flag_b;
  7. logic [M:0] abs;
  8. logic [N-3-M:0] res;
  9. // sign_x = x[N-1]
  10. // exponent_x = x[N-2:N-2-M]
  11. // mantissa_x = x[N-3-M:0]
  12. always_comb
  13. begin
  14. if (input_1[N-2:N-2-M] > input_2[N-2:N-2-M]) // If input 1 has the bigger exponent
  15. begin
  16. // Flags input a as larger and calculates the absolute difference
  17. flag_a = 1;
  18. flag_b = 0;
  19. abs = input_1[N-2:N-2-M] - input_2[N-2:N-2-M];
  20. // ASsigning overall sign of the output
  21. sum[N-1] = input_1[N-1];
  22. // Sets output to have the same exponent
  23. sum[N-2:N-2-M] = input_1[N-2:N-2-M];
  24. end
  25. else if (input_2[N-2:N-2-M] > input_1[N-2:N-2-M]) // If input 2 has the bigger exponent
  26. begin
  27. // Similarly flags input b as larger and calculates the absolute difference
  28. flag_a = 0;
  29. flag_b = 1;
  30. abs = input_2[N-2:N-2-M] - input_1[N-2:N-2-M];
  31. // ASsigning overall sign of the output
  32. sum[N-1] = input_2[N-1];
  33. // Sets ouput to have the same exponent
  34. sum[N-2:N-2-M] = input_2[N-2:N-2-M];
  35. end
  36. else
  37. begin
  38. // THe condition that both inputs have the same exponent
  39. flag_a = 1;
  40. flag_b = 1;
  41. abs = 0;
  42. // ASsigning overall sign of the output based on size of the mantissa
  43. if (input_1[N-3-M:0] >= input_2[N-3-M:0]) sum[N-1] = input_1[N-1];
  44. else sum[N-1] = input_2[N-1];
  45. sum[N-2:N-2-M] = input_1[N-2:N-2-M];
  46. end
  47. diff = abs;
  48. end
  49. //Second pipeline stage 1
  50. pipe pipe1(.clk(clk), .reset(reset), .D(Q0), .Q(Q1));
  51. always_comb
  52. begin
  53. // Condition for overflow is that it sets the output to the larger input
  54. if (abs > 9) // Because size of mantissa is 10 bits and shifting by 10 would give 0
  55. begin
  56. if (flag_a & ~flag_b) sum = input_1; // input 1 is larger and is translated to output
  57. else if (~flag_a & flag_b) sum = input_2; // input 2 is larger and is translated to output
  58. else // exponents are the same
  59. begin
  60. if (input_1[N-3-M:0] >= input_2[N-3-M:0]) sum = input_1;// input 1 has the bigger mantissa
  61. else sum = input_2; // input 2 has the bigger mantissa
  62. end
  63. end
  64. else
  65. begin
  66. // Shifts the smaller input's mantissa to the right based on abs
  67. if (flag_a & ~flag_b)// If input 1 has the larger exponent
  68. begin
  69. // If the signs of both inputs are the same you add, otherwise you subtract
  70. if (input_1[N-1] == input_2[N-1])
  71. begin
  72. res = input_1[N-3-M:0] + (input_2[N-3-M:0] >> abs-1); // Sum the mantissa
  73. sum[N-3-M:0] = res;
  74. end
  75. else
  76. begin
  77. res = input_1[N-3-M:0] - (input_2[N-3-M:0] >> abs-1); // Subtract the mantissas
  78. sum[N-3-M:0] = res;
  79. end
  80. end
  81. else if (~flag_a & flag_b)
  82. begin
  83. // If the signs of both inputs are the same you add, otherwise you subtract
  84. if (input_1[N-1] == input_2[N-1])
  85. begin
  86. res = (input_1[N-3-M:0] >> abs-1) + input_2[N-3-M:0]; // Sum the mantissa
  87. sum[N-3-M:0] = res;
  88. end
  89. else
  90. begin
  91. res = input_2[N-3-M:0] - (input_1[N-3-M:0] >> abs-1); // Subtract the mantissas
  92. sum[N-3-M:0] = res;
  93. end
  94. end
  95. else
  96. begin
  97. if (input_1[N-1] == input_2[N-1]) // If exponents and signs equal
  98. begin
  99. res = input_1[N-3-M:0] + input_2[N-3-M:0]; // Sum the mantissa
  100. sum[N-3-M:0] = res;
  101. end
  102. else // In this case it will be a subtraction
  103. begin
  104. if (input_1[N-3-M:0] > input_2[N-3-M:0]) // Which has the larger mantissa
  105. begin
  106. res = input_1[N-3-M:0] - input_2[N-3-M:0]; // Subtract the mantissa
  107. sum[N-3-M:0] = res;
  108. end
  109. else if (input_1[N-3-M:0] < input_2[N-3-M:0])
  110. begin
  111. res = input_2[N-3-M:0] - input_1[N-3-M:0]; // Subtract the mantissa
  112. sum[N-3-M:0] = res;
  113. end
  114. else res = 0; // Both the exponent and the mantissa are equal so subtraction leads to 0
  115. sum[N-3-M:0] = res;
  116. end
  117. end
  118. end
  119. end
  120. endmodule : floating_add
  121. module floating_product #(parameter N=16, M=4)(input_1, input_2, product);
  122. input logic [N-1:0] input_1, input_2;
  123. output logic [N-1:0] product;
  124. // sign_x = x[N-1]
  125. // exponent_x = x[N-2:N-2-M]
  126. // mantissa_x = x[N-3-M:0]
  127. logic [N-2:N-2-M] sum;
  128. logic [2*(N-3-M):0] mult;
  129. // We have assigned an {M+1} bit exponent so we must have a 2^{M} offset
  130. assign sum = input_1[N-2:N-2-M] + input_2[N-2:N-2-M];
  131. assign product[N-2:N-2-M] = sum - (1'b1 << M) + 2;
  132. always_comb
  133. begin
  134. // Setting the mantissa of the output
  135. mult = input_1[N-3-M:0] * input_2[N-3-M:0];
  136. if (mult[N-3-M]) product[N-3-M:0] = mult[2*(N-3-M):2*(N-3-M)-9];
  137. else product[N-3-M:0] = mult[2*(N-3-M):2*(N-3-M)-9] << 1;
  138. product[N-1] = input_1[N-1] ^ input_2[N-1];
  139. end
  140. endmodule : floating_product
  141. module pipe #(parameter N=16)(clk, reset, Q, D);
  142. input logic clk, reset;
  143. input logic [N-1:0] D;
  144. output reg [N-1:0] Q;
  145. reg [N-1:0] in_pipe;
  146. always @(posedge clk or negedge reset)
  147. begin
  148. if(reset) in_pipe = 0;
  149. else in_pipe = D;
  150. end
  151. always @(posedge clk or negedge reset)
  152. begin
  153. if(reset) Q = 0;
  154. else Q = in_pipe;
  155. end
  156. endmodule : pipe
  157. module floating_tb;
  158. reg reset, clk;
  159. logic [15:0] input_a, input_b, result_add, result_mult;
  160. logic [4:0] diff;
  161. floating_add adder1(.input_1(input_a), .input_2(input_b), .sum(result_add), .diff(diff));
  162. floating_product multiplier1(.input_1(input_a), .input_2(input_b), .product(result_mult));
  163. reg [15:0] test_mem [29:0][3:0];
  164. initial $readmemh("../../scripts/fp16_test.hex", test_mem);
  165. initial begin
  166. static int num_err = 0;
  167. static int num_tests = $size(test_mem) * 2;
  168. for (int i=0; i < $size(test_mem); i++) begin
  169. input_a = test_mem[i][0];
  170. input_b = test_mem[i][1];
  171. #10;
  172. if(result_add != test_mem[i][2]) begin
  173. if(num_err < 20)
  174. $display("FAIL ADD: %H + %H = %H, expected %H", input_a, input_b, result_add, test_mem[i][2]);
  175. num_err = num_err + 1;
  176. end
  177. if(result_mult != test_mem[i][3]) begin
  178. if(num_err < 20)
  179. $display("FAIL MULTIPLY: %H + %H = %H, expected %H", input_a, input_b, result_mult, test_mem[i][3]);
  180. num_err = num_err + 1;
  181. end
  182. end
  183. $display("Passed %d of %d tests", num_tests-num_err, num_tests);
  184. $finish();
  185. end
  186. endmodule : floating_tb
  187. module floating32_tb;
  188. reg reset, clk;
  189. logic [31:0] input_a, input_b, result_add, result_mult;
  190. floating_add#(.N(32), .M(8)) add0(
  191. .input_1(input_a), .input_2(input_b), .sum(result_add), .diff()
  192. );
  193. floating_product#(.N(32), .M(8)) mult0(
  194. .input_1(input_a), .input_2(input_b), .product(result_mult)
  195. );
  196. reg [31:0] test_mem [29:0][3:0];
  197. initial $readmemh("scripts/fp32_test.hex", test_mem);
  198. initial begin
  199. static int num_err = 0;
  200. static int num_tests = $size(test_mem) * 2;
  201. for (int i=0; i < $size(test_mem); i++) begin
  202. input_a = test_mem[i][0];
  203. input_b = test_mem[i][1];
  204. #10;
  205. if(result_add != test_mem[i][2]) begin
  206. if(num_err < 20)
  207. $display("FAIL ADD: %H + %H = %H, expected %H", input_a, input_b, result_add, test_mem[i][2]);
  208. num_err = num_err + 1;
  209. end
  210. if(result_mult != test_mem[i][3]) begin
  211. if(num_err < 20)
  212. $display("FAIL MULTIPLY: %H + %H = %H, expected %H", input_a, input_b, result_mult, test_mem[i][3]);
  213. num_err = num_err + 1;
  214. end
  215. end
  216. $display("Passed %d of %d tests", num_tests-num_err, num_tests);
  217. $finish();
  218. end
  219. endmodule : floating32_tb