asm_compiler.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. #!/usr/bin/python3
  2. import re
  3. import math
  4. import traceback
  5. from typing import Dict, List
  6. label_re = re.compile(r"^[\w$#@~.?]+$", re.IGNORECASE)
  7. hex_re = re.compile(r"^[0-9a-f]+$", re.IGNORECASE)
  8. bin_re = re.compile(r"^[0-1_]+$", re.IGNORECASE)
  9. oct_re = re.compile(r"^[0-8]+$", re.IGNORECASE)
  10. args_re = re.compile("(?:^|,)(?=[^\"]|(\")?)\"?((?(1)[^\"]*|[^,\"]*))\"?(?=,|$)", re.IGNORECASE)
  11. func_re = re.compile("^([\w$#@~.?]+)\s*([|^<>+\-*/%@]{1,2})\s*([\w$#@~.?]+)$", re.IGNORECASE)
  12. func2_re = re.compile("^([\w$#@~.?]+)\s*\(\s*([\w$#@~.?]+)*\)$", re.IGNORECASE)
  13. brackets_re = re.compile(r"(\((?:\(??[^\(]*?\)))", re.IGNORECASE)
  14. secs_re = re.compile("^([\d]+)x([\d]+)x([\d]+)$", re.IGNORECASE)
  15. funcc_re = re.compile("^([\w$#@~.?]+)\(([\w,]+)\)(.*)", re.IGNORECASE)
  16. MAX_INT_BYTES = 12
  17. def args2operands(args):
  18. operands = ['"' + a[1] + '"' if a[0] == '"' else a[1] for a in args_re.findall(args or '') if a[1]]
  19. return operands
  20. def match(regex, s):
  21. return regex.match(s) is not None
  22. class CompilingError(Exception):
  23. def __init__(self, message):
  24. self.message = message
  25. class InstructionError(Exception):
  26. def __init__(self, message):
  27. self.message = message
  28. class Instruction:
  29. def __init__(self, name: str, opcode: str, operands=0, alias=None):
  30. name = name.strip().lower()
  31. if not name or not name.isalnum():
  32. raise InstructionError(f"Invalid instruction name '{name}'")
  33. self.name = name.strip()
  34. self.alias = alias or []
  35. self.reg_operands = 0
  36. opcode = opcode.replace('_', '')
  37. if len(opcode) == 8:
  38. if opcode[4:6] == '??':
  39. self.reg_operands += 1
  40. if opcode[6:8] == '??':
  41. self.reg_operands += 1
  42. else:
  43. raise CompilingError("Invalid opcode: " + opcode)
  44. self.opcode = int(opcode.replace('?', '0'), 2)
  45. self.imm_operands = operands
  46. self.compiler = None
  47. @property
  48. def length(self):
  49. return self.imm_operands + 1
  50. def __len__(self):
  51. return self.length
  52. def _gen_instr(self, regs):
  53. instr = self.opcode
  54. if len(regs) != self.reg_operands:
  55. raise CompilingError(f"Invalid number of registers: set {len(regs)}, required: {self.reg_operands}")
  56. if len(regs) == 2:
  57. if regs[1] is None:
  58. raise CompilingError(f"Unable to decode register name {regs[1]}")
  59. if regs[0] is None:
  60. raise CompilingError(f"Unable to decode register name {regs[0]}")
  61. instr |= regs[0] << 2 | regs[1]
  62. elif len(regs) == 1:
  63. if regs[0] is None:
  64. raise CompilingError(f"Unable to decode register name {regs[0]}")
  65. instr |= regs[0] << 2
  66. return instr.to_bytes(1, 'little') # Order does not matter with 1 byte
  67. def compile(self, operands, scope):
  68. regs = []
  69. for reg in operands[:self.reg_operands]:
  70. regs.append(self.compiler.decode_reg(reg))
  71. imm = self.compiler.decode_with_labels(operands[self.reg_operands:], scope)
  72. if len(imm) != self.imm_operands:
  73. raise CompilingError(f"Instruction {self.name} has invalid argument size {len(imm)} != {self.imm_operands},"
  74. f" supplied args: 0x{imm.hex()}")
  75. instr = self._gen_instr(regs)
  76. return instr + imm
  77. class Section:
  78. def __init__(self):
  79. self.instr = []
  80. self.data = b''
  81. self.count = 0
  82. self.width = 1
  83. self.length = 1
  84. self.size = 2 ** 8
  85. class Compiler:
  86. def __init__(self, address_size=2, byte_order='little'):
  87. self.instr_db: Dict[str, Instruction] = {}
  88. self.data = []
  89. self.labels = {}
  90. self.macros = {}
  91. self.order = byte_order
  92. self.regnames = {}
  93. self.address_size = address_size
  94. def decode_reg(self, s: str):
  95. s = s.strip()
  96. if s in self.regnames:
  97. return self.regnames[s]
  98. raise CompilingError(f"Unrecognised register name: {s}")
  99. def decode_bytes(self, s: str):
  100. s = s.strip()
  101. typ = ""
  102. # Decimal numbers
  103. if (s.startswith('+') or s.startswith('-')) and s[1:].isnumeric():
  104. typ = 'int'
  105. elif s.isnumeric():
  106. typ = 'uint'
  107. elif s.endswith('d') and s[:-1].isnumeric():
  108. s = s[:-1]
  109. typ = 'uint'
  110. elif s.startswith('0d') and s[2:].isnumeric():
  111. s = s[2:]
  112. typ = 'uint'
  113. # Hexadecimal numbers
  114. elif s.startswith('0') and s.endswith('h') and match(hex_re, s[1:-1]):
  115. s = s[1:-1]
  116. typ = 'hex'
  117. elif (s.startswith('$0') or s.startswith('0x') or s.startswith('$0')) and match(hex_re, s[2:]):
  118. s = s[2:]
  119. typ = 'hex'
  120. # Octal numbers
  121. elif (s.endswith('q') or s.endswith('o')) and match(oct_re, s[:-1]):
  122. s = s[:-1]
  123. typ = 'oct'
  124. elif (s.startswith('0q') or s.startswith('0o')) and match(oct_re, s[2:]):
  125. s = s[2:]
  126. typ = 'oct'
  127. # Binary number
  128. elif (s.endswith('b') or s.endswith('y')) and match(bin_re, s[:-1]):
  129. s = s[:-1].replace('_', '')
  130. typ = 'bin'
  131. elif (s.startswith('0b') or s.startswith('0y')) and match(bin_re, s[2:]):
  132. s = s[2:].replace('_', '')
  133. typ = 'bin'
  134. # ASCII
  135. elif s.startswith("'") and s.endswith("'") and len(s) == 3:
  136. s = ord(s[1:-1]).to_bytes(1, self.order)
  137. typ = 'ascii'
  138. elif (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
  139. s = s[1:-1].encode('utf-8').decode("unicode_escape").encode('utf-8')
  140. typ = 'string'
  141. # Convert with limits
  142. if typ == 'uint':
  143. numb = int(s)
  144. for i in range(1, MAX_INT_BYTES + 1):
  145. if numb < 2 ** (i * 8):
  146. return numb.to_bytes(i, self.order)
  147. elif typ == 'int':
  148. numb = int(s)
  149. for i in range(1, MAX_INT_BYTES + 1):
  150. if -2 ** (i * 7) < numb < 2 ** (i * 7):
  151. return numb.to_bytes(i, self.order)
  152. elif typ == 'hex':
  153. numb = int(s, 16)
  154. return numb.to_bytes(int(len(s) / 2) + len(s) % 2, self.order)
  155. elif typ == 'oct':
  156. numb = int(s, 8)
  157. for i in range(1, 9):
  158. if -2 ** (i * 7) < i < 2 ** (i * 8):
  159. return numb.to_bytes(i, self.order)
  160. elif typ == 'bin':
  161. numb = int(s, 2)
  162. return numb.to_bytes(int(len(s) / 8) + len(s) % 8, self.order)
  163. else:
  164. return s
  165. def _decode_labels(self, arg, scope):
  166. immx = self.decode_bytes(arg)
  167. if isinstance(immx, str):
  168. if immx.startswith('.'):
  169. immx = scope + immx
  170. if immx in self.labels:
  171. return self.labels[immx]
  172. else:
  173. raise CompilingError(f"Unknown label: {immx}")
  174. elif isinstance(immx, bytes):
  175. return immx
  176. def decode_with_labels(self, args, scope):
  177. data = b''
  178. for arg in args:
  179. if isinstance(arg, str):
  180. funcm = func_re.match(arg)
  181. if funcm is not None:
  182. g = funcm.groups()
  183. left = self._decode_labels(g[0], scope)
  184. right = self._decode_labels(g[2], scope)
  185. data += self.proc_func(left, right, g[1])
  186. continue
  187. data += self._decode_labels(arg, scope)
  188. return data
  189. def add_reg(self, name, val):
  190. self.regnames[name] = val
  191. self.regnames['$' + name] = val
  192. def add_instr(self, instr: Instruction):
  193. instr.compiler = self
  194. operands = instr.reg_operands + instr.imm_operands
  195. if instr.name in self.instr_db:
  196. raise InstructionError(f"Instruction {instr.name} operands={operands} duplicate!")
  197. self.instr_db[instr.name] = instr
  198. for alias in instr.alias:
  199. if alias.lower() in self.instr_db:
  200. raise InstructionError(f"Instruction alias {alias} operands={operands} duplicate!")
  201. self.instr_db[alias.lower()] = instr
  202. def proc_func(self, left, right, op):
  203. leftInt = int.from_bytes(left, self.order)
  204. rightInt = int.from_bytes(right, self.order)
  205. if op == '|':
  206. result = leftInt | rightInt
  207. elif op == '^':
  208. result = leftInt ^ rightInt
  209. elif op == '&':
  210. result = leftInt & rightInt
  211. elif op == '<<':
  212. result = leftInt << rightInt
  213. elif op == '>>':
  214. result = leftInt >> rightInt
  215. elif op == '+':
  216. result = leftInt + rightInt
  217. elif op == '-':
  218. result = leftInt - rightInt
  219. elif op == '*':
  220. result = leftInt * rightInt
  221. elif op == '/' or op == '//':
  222. result = leftInt // rightInt
  223. elif op == '%' or op == '%%':
  224. result = leftInt % rightInt
  225. elif op == '@':
  226. return bytes([left[len(left) - rightInt - 1]])
  227. else:
  228. raise CompilingError(f"Invalid function operation {op}")
  229. return result.to_bytes(len(left), self.order)
  230. def __code_compiler(self, file, lnum, line_args, csect, scope, macro):
  231. builtin_cmds = {'db', 'dbe'}
  232. if line_args[0].endswith(':') and label_re.match(line_args[0][:-1]) is not None:
  233. # Must be label
  234. label = line_args[0][:-1]
  235. line_args = line_args[1:]
  236. if label.startswith('.'):
  237. if scope is None:
  238. raise CompilingError(f"No local scope for {label}!")
  239. label = scope + label
  240. elif not macro:
  241. scope = label
  242. if label in self.labels:
  243. raise CompilingError(f"Label {label} duplicate")
  244. self.labels[label] = csect.count.to_bytes(csect.length, self.order)
  245. if len(line_args) == 0:
  246. return scope
  247. elif len(line_args) == 1:
  248. args = None
  249. else:
  250. args = line_args[1]
  251. instr_name = line_args[0].lower()
  252. # Builtin instructions
  253. if instr_name == 'db':
  254. data = self.decode_with_labels(args2operands(args), scope)
  255. if len(data) % csect.width != 0:
  256. fill = csect.width - (len(data) % csect.width)
  257. data += b'\x00' * fill
  258. csect.instr.append(data)
  259. csect.count += len(data) // csect.width
  260. return scope
  261. if instr_name == 'dbe':
  262. try:
  263. fill = int(args[0])
  264. except ValueError:
  265. raise CompilingError(f"Instruction 'dbe' invalid argument, must be a number")
  266. except IndexError:
  267. raise CompilingError(f"Instruction 'dbe' invalid argument count! Must be 1")
  268. if fill % csect.width != 0:
  269. fill += csect.width - (fill % csect.width)
  270. data = b'\x00' * fill
  271. csect.instr.append(data)
  272. csect.count += len(data) // csect.width
  273. return scope
  274. if instr_name in self.macros:
  275. argsp = args2operands(args)
  276. if len(argsp) != self.macros[instr_name][0]:
  277. raise CompilingError(f"Invalid macro argument count!")
  278. self.macros[instr_name][3] += 1 # How many time macro been used (used for macro labels)
  279. mlabel = f'{instr_name}.{self.macros[instr_name][3]}'
  280. for slnum, sline in enumerate(self.macros[instr_name][1]):
  281. slnum += 1
  282. mline = sline.copy()
  283. for i, mline0 in enumerate(mline):
  284. for j in range(len(argsp)):
  285. mline0 = mline0.replace(f'%{j + 1}', argsp[j])
  286. mline[i] = re.sub(r'(%{2})([\w$#~.?]+)', mlabel+r'.\2', mline0)
  287. try:
  288. scope = self.__code_compiler(file, lnum, mline, csect, scope, True)
  289. except CompilingError as e:
  290. print(f"ERROR {file}:{self.macros[instr_name][2] + slnum}: {e.message}")
  291. raise CompilingError(f"Previous error")
  292. return scope
  293. if instr_name not in self.instr_db:
  294. raise CompilingError(f"Instruction '{instr_name}' not recognised!")
  295. instr_obj = self.instr_db[instr_name.lower()]
  296. csect.instr.append((instr_obj, args, lnum, scope))
  297. csect.count += instr_obj.length
  298. return scope
  299. @staticmethod
  300. def __line_generator(code):
  301. for lnum, line in enumerate(code):
  302. lnum += 1
  303. line = line.split(';', 1)[0]
  304. line = re.sub(' +', ' ', line) # replace multiple spaces
  305. line = line.strip()
  306. line_args = [l.strip() for l in line.split(' ', 2)]
  307. # line_args = list(filter(lambda x: len(x) > 0, line_args))
  308. if len(line_args) == 0 or line_args[0] == '':
  309. continue
  310. yield lnum, line_args
  311. def compile_file(self, file):
  312. try:
  313. with open(file, 'r') as f:
  314. data = self.compile(file, f.readlines())
  315. return data
  316. except IOError:
  317. return None
  318. def compile(self, file, code):
  319. failure = False
  320. sections: Dict[str, Section] = {}
  321. csect = None
  322. scope = None
  323. macro = None
  324. for lnum, line_args in self.__line_generator(code):
  325. try:
  326. # Inside macro
  327. if macro is not None:
  328. if line_args[0].lower() == '%endmacro':
  329. macro = None
  330. continue
  331. self.macros[macro][1].append(line_args)
  332. continue
  333. # Section
  334. if line_args[0].lower() == 'section':
  335. if len(line_args) < 2:
  336. raise CompilingError(f"Invalid section arguments!")
  337. section_name = line_args[1].lower()
  338. if section_name not in sections:
  339. s = Section()
  340. if len(line_args) == 3:
  341. m = secs_re.match(line_args[2])
  342. if m is not None:
  343. g = m.groups()
  344. s.width = int(g[0])
  345. s.length = int(g[1])
  346. s.size = int(g[2])
  347. else:
  348. raise CompilingError(f"Invalid section argument: {line_args[2]}")
  349. sections[section_name] = s
  350. csect = sections[section_name]
  351. continue
  352. # Macros
  353. elif line_args[0].lower() == '%define':
  354. if len(line_args) != 3:
  355. raise CompilingError(f"Invalid %define arguments!")
  356. self.labels[line_args[1]] = self.decode_bytes(line_args[2])
  357. continue
  358. elif line_args[0].lower() == '%macro':
  359. if len(line_args) != 3:
  360. raise CompilingError(f"Invalid %macro arguments!")
  361. if line_args[1] in self.macros:
  362. raise CompilingError(f"Macro '{line_args[1]}' already in use")
  363. if not line_args[2].isdigit():
  364. raise CompilingError(f"%macro argument 2 must be a number")
  365. macro = line_args[1].lower()
  366. self.macros[macro] = [int(line_args[2]), [], lnum, 0]
  367. continue
  368. elif line_args[0].lower() == '%include':
  369. if len(line_args) != 2:
  370. raise CompilingError(f"Invalid %include arguments!")
  371. raise CompilingError(f"%include is not implemented yet") # TODO: Complete
  372. continue
  373. if csect is None:
  374. raise CompilingError(f"No section defined!")
  375. scope = self.__code_compiler(file, lnum, line_args, csect, scope, False)
  376. except CompilingError as e:
  377. failure = True
  378. print(f"ERROR {file}:{lnum}: {e.message}")
  379. for section in sections.values():
  380. for instr_tuple in section.instr:
  381. if isinstance(instr_tuple, bytes):
  382. section.data += instr_tuple
  383. continue
  384. instr, args, lnum, scope = instr_tuple
  385. try:
  386. operands = args2operands(args)
  387. section.data += instr.compile(operands, scope)
  388. except CompilingError as e:
  389. failure = True
  390. print(f"ERROR {file}:{lnum}: {e.message}")
  391. if failure:
  392. return None
  393. return {k: (v.width, v.length, v.size, v.data) for k, v in sections.items()}
  394. def decompile(self, binary):
  395. addr = 0
  396. res = []
  397. ibin = iter(binary)
  398. for data in ibin:
  399. norm0 = int(data)
  400. norm1 = norm0 & int('11110011', 2)
  401. norm2 = norm0 & int('11110000', 2)
  402. for instr in self.instr_db.values():
  403. if not ((instr.reg_operands == 0 and norm0 == instr.opcode) or
  404. (instr.reg_operands == 1 and norm1 == instr.opcode) or
  405. (instr.reg_operands == 2 and norm2 == instr.opcode)):
  406. continue
  407. asm = f'{addr:04x}: {instr.name.upper().ljust(6)}'
  408. args = []
  409. raw = format(norm0, '02x')
  410. if instr.reg_operands > 0:
  411. args.append(f'r{(norm0 & 12) >> 2}')
  412. if instr.reg_operands > 1:
  413. args.append(f'r{(norm0 & 3)}')
  414. if instr.imm_operands > 0:
  415. b = '0x'
  416. for i in range(instr.imm_operands):
  417. try:
  418. bi = format(int(next(ibin)), '02x')
  419. except StopIteration:
  420. break
  421. b += bi
  422. raw += bi
  423. addr += 1
  424. args.append(b)
  425. line = asm + ', '.join(args)
  426. tabs = ' ' * (27 - int(len(line)))
  427. res.append(f'{line}{tabs}[{raw}]')
  428. break
  429. addr += 1
  430. return '\n'.join(res)
  431. def convert_to_binary(data):
  432. a = '\n'.join([format(i, '08b') for i in data])
  433. return a.encode()
  434. def convert_to_mem(data, width=1, uhex=False):
  435. x = b''
  436. if uhex:
  437. if width == 2:
  438. for i in range(int(len(data) / 2)):
  439. x += format(data[-(i * 2) - 2], f'02x').upper().encode()
  440. x += format(data[-(i * 2) - 1], f'02x').upper().encode()
  441. else:
  442. for i in range(len(data)):
  443. x += format(data[-i - 1], f'02x').upper().encode()
  444. return x
  445. if width == 2:
  446. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  447. if len(data) % 2 == 1:
  448. datax.append(data[-1] << 8)
  449. else:
  450. datax = data
  451. fa = f'0{math.ceil(math.ceil(math.log2(len(datax))) / 4)}x'
  452. a = [format(d, f'0{width * 2}x') for d in datax]
  453. for i in range(int(len(a) / 8) + 1):
  454. y = a[i * 8:(i + 1) * 8]
  455. if len(y) > 0:
  456. x += (' '.join(y) + ' ' * ((8 - len(y)) * 3) + ' // ' + format((i * 8 - 1) + len(y), fa) + '\n').encode()
  457. return x
  458. def convert_to_mif(data, depth=32, width=1):
  459. x = f'''-- auto-generated memory initialisation file
  460. DEPTH = {math.ceil(depth)};
  461. WIDTH = {width * 8};
  462. ADDRESS_RADIX = HEX;
  463. DATA_RADIX = HEX;
  464. CONTENT
  465. BEGIN
  466. '''.encode()
  467. addr_format = f'0{math.ceil(int(math.log2(len(data))) / 4)}x'
  468. if width == 2:
  469. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  470. if len(data) % 2 == 1:
  471. datax.append(data[-1] << 8)
  472. else:
  473. datax = data
  474. a = [format(i, f'0{width * 2}x') for i in datax]
  475. for i in range(int(len(a * width) / 8) + 1):
  476. y = a[i * 8:(i + 1) * 8]
  477. if len(y) > 0:
  478. x += (format(i * 8, addr_format) + ' : ' + ' '.join(y) + ';\n').encode()
  479. x += b"END;"
  480. return x