asm_compiler.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. #!/usr/bin/python3
  2. import sys
  3. import argparse
  4. from os import path
  5. def decode_bytes(val: str):
  6. try:
  7. if val.endswith('h'):
  8. return [int(val[i:i+2], 16) for i in range(0, len(val)-1, 2)]
  9. if val.startswith('0x'):
  10. return [int(val[i:i+2], 16) for i in range(2, len(val), 2)]
  11. if val.startswith('b'):
  12. val = val.replace('_', '')[1:]
  13. return [int(val[i:i+8], 2) for i in range(0, len(val), 8)]
  14. except ValueError:
  15. raise ValueError(f"Invalid binary '{val}'")
  16. if val.isdigit():
  17. i = int(val)
  18. if i > 255 or i < 0:
  19. raise ValueError(f"Invalid binary '{val}', unsigned int out of bounds")
  20. return [i]
  21. if (val.startswith('+') or val.startswith('-')) and val[1:].isdigit():
  22. i = int(val)
  23. if i > 127 or i < -128:
  24. raise ValueError(f"Invalid binary '{val}', signed int out of bounds")
  25. if i < 0: # convert to unsigned
  26. i += 2 ** 8
  27. return [i]
  28. if len(val) == 3 and ((val[0] == "'" and val[2] == "'") or (val[0] == '"' and val[2] == '"')):
  29. return [ord(val[1])]
  30. raise ValueError(f"Invalid binary '{val}'")
  31. def is_reg(r):
  32. if r.startswith('$'):
  33. r = r[1:]
  34. if r.isnumeric() and 0 <= int(r) <= 3:
  35. return True
  36. elif len(r) == 2 and r[0] == 'r' and r[1] in {'0', '1', '2', '3', 'a', 'b', 'c', 'e'}:
  37. return True
  38. return False
  39. def decode_reg(r):
  40. if r.startswith('$') and r[1:].isnumeric():
  41. r = int(r[1:])
  42. if isinstance(r, int):
  43. if 0 <= r <= 3:
  44. return r
  45. raise ValueError(f"Invalid register value {r}")
  46. rl = r.lower()
  47. if rl.startswith('$'):
  48. rl = rl[1:]
  49. if rl == 'ra' or rl == 'r0':
  50. return 0
  51. if rl == 'rb' or rl == 'r1':
  52. return 1
  53. if rl == 'rc' or rl == 'r2':
  54. return 2
  55. if rl == 're' or rl == 'r3':
  56. return 3
  57. raise ValueError(f"Invalid register name '{r}'")
  58. def assemble(file):
  59. odata = []
  60. afile = open(file, 'r')
  61. failed = False
  62. refs = dict()
  63. for lnum, line in enumerate(afile.readlines()):
  64. lnum += 1 # Line numbers start from 1, not 0
  65. if '//' in line:
  66. line = line[:line.index('//')]
  67. if ':' in line:
  68. rsplit = line.split(':', 1)
  69. ref = rsplit[0]
  70. if not ref.isalnum():
  71. print(f"{file}:{lnum}: Invalid pointer reference '{ref}'")
  72. failed = True
  73. continue
  74. if ref in refs:
  75. if refs[ref][1] is not None:
  76. print(f"{file}:{lnum}: Pointer reference '{ref}' is duplicated with {file}:{refs[ref][0]}")
  77. failed = True
  78. continue
  79. refs[ref] = [lnum, len(odata)]
  80. line = rsplit[1]
  81. line = line.replace('\n', '').replace('\r', '').replace('\t', '')
  82. line = line.strip(' ')
  83. if line == '':
  84. continue
  85. ops = line.split()
  86. instr = ops[0].upper()
  87. rops = 3
  88. if instr == 'CPY' or instr == 'COPY':
  89. iname = 'COPY'
  90. inibb = 0
  91. elif instr == 'ADD':
  92. iname = 'ADD'
  93. inibb = 1
  94. elif instr == 'SUB':
  95. iname = 'SUB'
  96. inibb = 2
  97. elif instr == 'AND':
  98. iname = 'AND'
  99. inibb = 3
  100. elif instr == 'OR':
  101. iname = 'OR'
  102. inibb = 4
  103. elif instr == 'XOR':
  104. iname = 'XOR'
  105. inibb = 5
  106. elif instr == 'MUL':
  107. iname = 'MUL'
  108. inibb = 6
  109. elif instr == 'DIV':
  110. iname = 'DIV'
  111. inibb = 7
  112. elif instr == 'SLL':
  113. iname = 'SLL'
  114. inibb = 9
  115. elif instr == 'SHFR':
  116. iname = 'SHTR'
  117. inibb = 7
  118. ops.append(1)
  119. elif instr == 'ROTR':
  120. iname = 'ROTR'
  121. inibb = 7
  122. ops.append(2)
  123. elif instr == 'LW':
  124. iname = 'LW'
  125. inibb = 8
  126. elif instr == 'SW':
  127. iname = 'SW'
  128. inibb = 9
  129. elif instr == 'JEQ':
  130. iname = 'JEQ'
  131. rops = 4
  132. inibb = 10
  133. elif instr == 'JMP' or instr == 'JUMP':
  134. iname = 'JUMP'
  135. rops = 2
  136. inibb = 11
  137. elif instr == 'PUSH':
  138. iname = 'PUSH'
  139. rops = 2
  140. inibb = 14
  141. elif instr == 'POP':
  142. iname = 'POP'
  143. rops = 2
  144. inibb = 15
  145. else:
  146. if len(ops) == 1:
  147. try:
  148. odata += decode_bytes(ops[0])
  149. continue
  150. except ValueError:
  151. pass
  152. print(f"{file}:{lnum}: Instruction '{ops[0]}' not recognised")
  153. failed = True
  154. continue
  155. if len(ops) != rops:
  156. print(f"{file}:{lnum}: {iname} instruction requires {rops - 1} arguments")
  157. failed = True
  158. continue
  159. try:
  160. if iname == 'JUMP':
  161. odata.append(inibb << 4)
  162. try:
  163. odata += decode_bytes(ops[1])
  164. except ValueError:
  165. if not ops[1].isalnum():
  166. print(f"{file}:{lnum}: Invalid pointer reference '{ops[1]}'")
  167. failed = True
  168. continue
  169. if ops[1] in refs:
  170. odata.append(refs[ops[1]][1])
  171. else:
  172. refs[ops[1]] = [lnum, None]
  173. odata.append(ops[1])
  174. continue
  175. rd = decode_reg(ops[1])
  176. if iname == 'COPY' and not is_reg(ops[2]):
  177. imm = decode_bytes(ops[2])[0]
  178. odata.append((inibb << 4) | (rd << 2) | rd)
  179. odata.append(int(imm))
  180. continue
  181. if iname == 'PUSH' or iname == 'POP':
  182. odata.append((inibb << 4) | (rd << 2) | rd)
  183. continue
  184. rs = decode_reg(ops[2])
  185. if iname == 'COPY' and rd == rs:
  186. print(f"{file}:{lnum}: {iname} cannot copy register to itself")
  187. failed = True
  188. continue
  189. odata.append((inibb << 4) | (rd << 2) | rs)
  190. if iname == 'JEQ':
  191. try:
  192. odata += decode_bytes(ops[3])
  193. except ValueError:
  194. if not ops[3].isalnum():
  195. print(f"{file}:{lnum}: Invalid pointer reference '{ops[3]}'")
  196. failed = True
  197. continue
  198. if ops[3] in refs:
  199. odata.append(refs[ops[3]][1])
  200. else:
  201. refs[ops[3]] = [lnum, None]
  202. odata.append(ops[3])
  203. continue
  204. except ValueError as e:
  205. print(f"{file}:{lnum}: {e}")
  206. failed = True
  207. continue
  208. afile.close()
  209. # Convert jumps
  210. for i, l in enumerate(odata):
  211. if isinstance(l, str):
  212. if refs[l][1] is None:
  213. print(f"{file}:{refs[l][0]}: Pointer reference '{l}' does not exist!")
  214. failed = True
  215. continue
  216. odata[i] = refs[l][1]
  217. return not failed, odata
  218. def readable_size(num, disp_bytes=True):
  219. num = abs(num)
  220. if num < 1024 and disp_bytes:
  221. return "[%3.0fB]" % num
  222. if num < 1024 and not disp_bytes:
  223. return ""
  224. num /= 1024.0
  225. for unit in ['Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
  226. if abs(num) < 1024.0:
  227. return "[%3.1f%sB]" % (num, unit)
  228. num /= 1024.0
  229. return "[%.1f%sB]" % (num, 'Yi')
  230. if __name__ == '__main__':
  231. parser = argparse.ArgumentParser(description='Assembly compiler', add_help=True)
  232. parser.add_argument('file', help='Files to compile')
  233. parser.add_argument('-t', '--output_type', choices=['bin', 'mem', 'binary'], default='mem', help='Output type')
  234. parser.add_argument('-o', '--output', help='Output file')
  235. parser.add_argument('-f', '--force', action='store_true', help='Force override output file')
  236. args = parser.parse_args(sys.argv[1:])
  237. if not path.isfile(args.file):
  238. print(f'No file {args.file}!')
  239. sys.exit(1)
  240. output = args.output
  241. if not output:
  242. opath = path.dirname(args.file)
  243. bname = path.basename(args.file).rsplit('.', 1)[0]
  244. ext = '.out'
  245. if args.output_type == 'mem':
  246. ext = '.mem'
  247. elif args.output_type == 'bin':
  248. ext = '.bin'
  249. output = path.join(opath, bname + ext)
  250. if not args.force and path.isfile(output):
  251. print(f'Output file already exists {output}!')
  252. sys.exit(1)
  253. success, data = assemble(args.file)
  254. if success:
  255. print(f"Saving {args.output_type} data to {output}")
  256. print(f"Program size: {len(data)}B {readable_size(len(data), False)}")
  257. with open(output, 'wb') as of:
  258. if args.output_type == 'binary':
  259. a = '\n'.join([format(i, '08b') for i in data])
  260. of.write(a.encode())
  261. elif args.output_type == 'mem':
  262. a = [format(i, '02x') for i in data]
  263. for i in range(int(len(a) / 8) + 1):
  264. of.write((' '.join(a[i * 8:(i + 1) * 8]) + '\n').encode())
  265. elif args.output_type == 'bin':
  266. of.write(bytes(data))
  267. else:
  268. print(f'Failed to compile {args.file}!')
  269. sys.exit(1)
  270. sys.exit(0)