asm_compiler.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. #!/usr/bin/python3
  2. import re
  3. import math
  4. import traceback
  5. from typing import Dict, List
  6. label_re = re.compile(r"^[\w$#@~.?]+$", re.IGNORECASE)
  7. hex_re = re.compile(r"^[0-9a-f]+$", re.IGNORECASE)
  8. bin_re = re.compile(r"^[0-1_]+$", re.IGNORECASE)
  9. oct_re = re.compile(r"^[0-8]+$", re.IGNORECASE)
  10. args_re = re.compile("(?:^|,)(?=[^\"]|(\")?)\"?((?(1)[^\"]*|[^,\"]*))\"?(?=,|$)", re.IGNORECASE)
  11. func_re = re.compile("^([\w$#@~.?]+)\s*([|^<>+\-*/%@]{1,2})\s*([\w$#@~.?]+)$", re.IGNORECASE)
  12. secs_re = re.compile("^([\d]+)x([\d]+)x([\d]+)$", re.IGNORECASE)
  13. MAX_INT_BYTES = 12
  14. def args2operands(args):
  15. operands = ['"' + a[1] + '"' if a[0] == '"' else a[1] for a in args_re.findall(args or '') if a[1]]
  16. return operands
  17. def match(regex, s):
  18. return regex.match(s) is not None
  19. class CompilingError(Exception):
  20. def __init__(self, message):
  21. self.message = message
  22. class InstructionError(Exception):
  23. def __init__(self, message):
  24. self.message = message
  25. class Instruction:
  26. def __init__(self, name: str, opcode: str, operands=0, alias=None):
  27. name = name.strip().lower()
  28. if not name or not name.isalnum():
  29. raise InstructionError(f"Invalid instruction name '{name}'")
  30. self.name = name.strip()
  31. self.alias = alias or []
  32. self.reg_operands = 0
  33. opcode = opcode.replace('_', '')
  34. if len(opcode) == 8:
  35. if opcode[4:6] == '??':
  36. self.reg_operands += 1
  37. if opcode[6:8] == '??':
  38. self.reg_operands += 1
  39. else:
  40. raise CompilingError("Invalid opcode: " + opcode)
  41. self.opcode = int(opcode.replace('?', '0'), 2)
  42. self.imm_operands = operands
  43. self.compiler = None
  44. @property
  45. def length(self):
  46. return self.imm_operands + 1
  47. def __len__(self):
  48. return self.length
  49. def _gen_instr(self, regs):
  50. instr = self.opcode
  51. if len(regs) != self.reg_operands:
  52. raise CompilingError(f"Invalid number of registers: set {len(regs)}, required: {self.reg_operands}")
  53. if len(regs) == 2:
  54. if regs[1] is None:
  55. raise CompilingError(f"Unable to decode register name {regs[1]}")
  56. if regs[0] is None:
  57. raise CompilingError(f"Unable to decode register name {regs[0]}")
  58. instr |= regs[0] << 2 | regs[1]
  59. elif len(regs) == 1:
  60. if regs[0] is None:
  61. raise CompilingError(f"Unable to decode register name {regs[0]}")
  62. instr |= regs[0] << 2
  63. return instr.to_bytes(1, 'little') # Order does not matter with 1 byte
  64. def compile(self, operands, scope):
  65. regs = []
  66. for reg in operands[:self.reg_operands]:
  67. regs.append(self.compiler.decode_reg(reg))
  68. imm = self.compiler.decode_with_labels(operands[self.reg_operands:], scope)
  69. if len(imm) != self.imm_operands:
  70. raise CompilingError(f"Instruction {self.name} has invalid argument size {len(imm)} != {self.imm_operands},"
  71. f" supplied args: 0x{imm.hex()}")
  72. instr = self._gen_instr(regs)
  73. return instr + imm
  74. class Section:
  75. def __init__(self):
  76. self.instr = []
  77. self.data = b''
  78. self.count = 0
  79. self.width = 1
  80. self.length = 1
  81. self.size = 2**8
  82. class Compiler:
  83. def __init__(self, address_size=2, byte_order='little'):
  84. self.instr_db: Dict[str, Instruction] = {}
  85. self.data = []
  86. self.labels = {}
  87. self.order = byte_order
  88. self.regnames = {}
  89. self.address_size = address_size
  90. def decode_reg(self, s: str):
  91. s = s.strip()
  92. if s in self.regnames:
  93. return self.regnames[s]
  94. raise CompilingError(f"Unrecognised register name: {s}")
  95. def decode_bytes(self, s: str):
  96. s = s.strip()
  97. typ = ""
  98. # Decimal numbers
  99. if (s.startswith('+') or s.startswith('-')) and s[1:].isnumeric():
  100. typ = 'int'
  101. elif s.isnumeric():
  102. typ = 'uint'
  103. elif s.endswith('d') and s[:-1].isnumeric():
  104. s = s[:-1]
  105. typ = 'uint'
  106. elif s.startswith('0d') and s[2:].isnumeric():
  107. s = s[2:]
  108. typ = 'uint'
  109. # Hexadecimal numbers
  110. elif s.startswith('0') and s.endswith('h') and match(hex_re, s[1:-1]):
  111. s = s[1:-1]
  112. typ = 'hex'
  113. elif (s.startswith('$0') or s.startswith('0x') or s.startswith('$0')) and match(hex_re, s[2:]):
  114. s = s[2:]
  115. typ = 'hex'
  116. # Octal numbers
  117. elif (s.endswith('q') or s.endswith('o')) and match(oct_re, s[:-1]):
  118. s = s[:-1]
  119. typ = 'oct'
  120. elif (s.startswith('0q') or s.startswith('0o')) and match(oct_re, s[2:]):
  121. s = s[2:]
  122. typ = 'oct'
  123. # Binary number
  124. elif (s.endswith('b') or s.endswith('y')) and match(bin_re, s[:-1]):
  125. s = s[:-1].replace('_', '')
  126. typ = 'bin'
  127. elif (s.startswith('0b') or s.startswith('0y')) and match(bin_re, s[2:]):
  128. s = s[2:].replace('_', '')
  129. typ = 'bin'
  130. # ASCII
  131. elif s.startswith("'") and s.endswith("'") and len(s) == 3:
  132. s = ord(s[1:-1]).to_bytes(1, self.order)
  133. typ = 'ascii'
  134. elif (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
  135. s = s[1:-1].encode('utf-8').decode("unicode_escape").encode('utf-8')
  136. typ = 'string'
  137. # Convert with limits
  138. if typ == 'uint':
  139. numb = int(s)
  140. for i in range(1, MAX_INT_BYTES+1):
  141. if numb < 2 ** (i * 8):
  142. return numb.to_bytes(i, self.order)
  143. elif typ == 'int':
  144. numb = int(s)
  145. for i in range(1, MAX_INT_BYTES+1):
  146. if -2 ** (i * 7) < numb < 2 ** (i * 7):
  147. return numb.to_bytes(i, self.order)
  148. elif typ == 'hex':
  149. numb = int(s, 16)
  150. return numb.to_bytes(int(len(s) / 2) + len(s) % 2, self.order)
  151. elif typ == 'oct':
  152. numb = int(s, 8)
  153. for i in range(1, 9):
  154. if -2 ** (i * 7) < i < 2 ** (i * 8):
  155. return numb.to_bytes(i, self.order)
  156. elif typ == 'bin':
  157. numb = int(s, 2)
  158. return numb.to_bytes(int(len(s) / 8) + len(s) % 8, self.order)
  159. else:
  160. return s
  161. def _decode_labels(self, arg, scope):
  162. immx = self.decode_bytes(arg)
  163. if isinstance(immx, str):
  164. if immx.startswith('.'):
  165. immx = scope + immx
  166. if immx in self.labels:
  167. return self.labels[immx]
  168. else:
  169. raise CompilingError(f"Unknown label: {immx}")
  170. elif isinstance(immx, bytes):
  171. return immx
  172. def decode_with_labels(self, args, scope):
  173. data = b''
  174. for arg in args:
  175. if isinstance(arg, str):
  176. funcm = func_re.match(arg)
  177. if funcm is not None:
  178. g = funcm.groups()
  179. left = self._decode_labels(g[0], scope)
  180. right = self._decode_labels(g[2], scope)
  181. data += self.proc_func(left, right, g[1])
  182. continue
  183. data += self._decode_labels(arg, scope)
  184. return data
  185. def add_reg(self, name, val):
  186. self.regnames[name] = val
  187. self.regnames['$' + name] = val
  188. def add_instr(self, instr: Instruction):
  189. instr.compiler = self
  190. operands = instr.reg_operands + instr.imm_operands
  191. if instr.name in self.instr_db:
  192. raise InstructionError(f"Instruction {instr.name} operands={operands} duplicate!")
  193. self.instr_db[instr.name] = instr
  194. for alias in instr.alias:
  195. if alias.lower() in self.instr_db:
  196. raise InstructionError(f"Instruction alias {alias} operands={operands} duplicate!")
  197. self.instr_db[alias.lower()] = instr
  198. def proc_func(self, left, right, op):
  199. leftInt = int.from_bytes(left, self.order)
  200. rightInt = int.from_bytes(right, self.order)
  201. if op == '|':
  202. result = leftInt | rightInt
  203. elif op == '^':
  204. result = leftInt ^ rightInt
  205. elif op == '&':
  206. result = leftInt & rightInt
  207. elif op == '<<':
  208. result = leftInt << rightInt
  209. elif op == '>>':
  210. result = leftInt >> rightInt
  211. elif op == '+':
  212. result = leftInt + rightInt
  213. elif op == '-':
  214. result = leftInt - rightInt
  215. elif op == '*':
  216. result = leftInt * rightInt
  217. elif op == '/' or op == '//':
  218. result = leftInt // rightInt
  219. elif op == '%' or op == '%%':
  220. result = leftInt % rightInt
  221. elif op == '@':
  222. return bytes([left[len(left)-rightInt-1]])
  223. else:
  224. raise CompilingError(f"Invalid function operation {op}")
  225. return result.to_bytes(len(left), self.order)
  226. def compile(self, file, code):
  227. failure = False
  228. sections: Dict[str, Section] = {}
  229. csect = None
  230. scope = None
  231. for lnum, line in enumerate(code):
  232. lnum += 1
  233. line = line.split(';', 1)[0].strip()
  234. try:
  235. line_args = [l.strip() for l in line.split(' ', 2)]
  236. # line_args = list(filter(lambda x: len(x) > 0, line_args))
  237. if len(line_args) == 0 or line_args[0] == '':
  238. continue
  239. # Section
  240. if line_args[0].lower() == 'section':
  241. if len(line_args) < 2:
  242. raise CompilingError(f"Invalid section arguments!")
  243. section_name = line_args[1].lower()
  244. if section_name not in sections:
  245. s = Section()
  246. if len(line_args) == 3:
  247. m = secs_re.match(line_args[2])
  248. if m is not None:
  249. g = m.groups()
  250. s.width = int(g[0])
  251. s.length = int(g[1])
  252. s.size = int(g[2])
  253. else:
  254. raise CompilingError(f"Invalid section argument: {line_args[2]}")
  255. sections[section_name] = s
  256. csect = sections[section_name]
  257. continue
  258. # Macros
  259. elif line_args[0].lower() == '%define':
  260. if len(line_args) != 3:
  261. raise CompilingError(f"Invalid %define arguments!")
  262. self.labels[line_args[1]] = self.decode_bytes(line_args[2])
  263. continue
  264. elif line_args[0].lower() == '%include':
  265. if len(line_args) != 2:
  266. raise CompilingError(f"Invalid %include arguments!")
  267. raise CompilingError(f"%include is not implemented yet") # TODO: Complete
  268. continue
  269. if csect is None:
  270. raise CompilingError(f"No section defined!")
  271. builtin_cmds = {'db', 'dbe'}
  272. if line_args[0].lower() not in self.instr_db and\
  273. line_args[0].lower() not in builtin_cmds: # Must be label
  274. label = line_args[0]
  275. line_args = line_args[1:]
  276. if label.startswith('.'):
  277. if scope is None:
  278. raise CompilingError(f"No local scope for {label}!")
  279. label = scope + label
  280. else:
  281. scope = label
  282. if label in self.labels:
  283. raise CompilingError(f"Label {label} duplicate")
  284. self.labels[label] = csect.count.to_bytes(csect.length, self.order)
  285. if len(line_args) == 0:
  286. continue
  287. elif len(line_args) == 1:
  288. instr_name, args = line_args[0].lower(), None
  289. else:
  290. instr_name, args = line_args[0].lower(), line_args[1]
  291. # Builtin instructions
  292. if instr_name == 'db':
  293. data = self.decode_with_labels(args2operands(args), scope)
  294. if len(data) % csect.width != 0:
  295. fill = csect.width - (len(data) % csect.width)
  296. data += b'\x00' * fill
  297. csect.instr.append(data)
  298. csect.count += int(len(data)/csect.width)
  299. continue
  300. if instr_name == 'dbe':
  301. try:
  302. fill = int(args[0])
  303. except ValueError:
  304. raise CompilingError(f"Instruction 'dbe' invalid argument, must be a number")
  305. except IndexError:
  306. raise CompilingError(f"Instruction 'dbe' invalid argument count! Must be 1")
  307. if fill % csect.width != 0:
  308. fill += csect.width - (fill % csect.width)
  309. data = b'\x00' * fill
  310. csect.instr.append(data)
  311. csect.count += int(len(data)/csect.width)
  312. continue
  313. if instr_name not in self.instr_db:
  314. raise CompilingError(f"Instruction '{instr_name}' not recognised!")
  315. instr_obj = self.instr_db[instr_name.lower()]
  316. csect.instr.append((instr_obj, args, lnum, scope))
  317. csect.count += instr_obj.length
  318. except CompilingError as e:
  319. failure = True
  320. print(f"ERROR {file}:{lnum}: {e.message}")
  321. for section in sections.values():
  322. for instr_tuple in section.instr:
  323. if isinstance(instr_tuple, bytes):
  324. section.data += instr_tuple
  325. continue
  326. instr, args, lnum, scope = instr_tuple
  327. try:
  328. operands = args2operands(args)
  329. section.data += instr.compile(operands, scope)
  330. except CompilingError as e:
  331. failure = True
  332. print(f"ERROR {file}:{lnum}: {e.message}")
  333. if failure:
  334. return None
  335. return {k: (v.width, v.length, v.size, v.data) for k, v in sections.items()}
  336. def decompile(self, binary):
  337. addr = 0
  338. res = []
  339. ibin = iter(binary)
  340. for data in ibin:
  341. norm0 = int(data)
  342. norm1 = norm0 & int('11110011', 2)
  343. norm2 = norm0 & int('11110000', 2)
  344. for instr in self.instr_db.values():
  345. if not ((instr.reg_operands == 0 and norm0 == instr.opcode) or
  346. (instr.reg_operands == 1 and norm1 == instr.opcode) or
  347. (instr.reg_operands == 2 and norm2 == instr.opcode)):
  348. continue
  349. asm = f'{addr:04x}: {instr.name.upper().ljust(6)}'
  350. args = []
  351. raw = format(norm0, '02x')
  352. if instr.reg_operands > 0:
  353. args.append(f'r{(norm0 & 12) >> 2}')
  354. if instr.reg_operands > 1:
  355. args.append(f'r{(norm0 & 3)}')
  356. if instr.imm_operands > 0:
  357. b = '0x'
  358. for i in range(instr.imm_operands):
  359. try:
  360. bi = format(int(next(ibin)), '02x')
  361. except StopIteration:
  362. break
  363. b += bi
  364. raw += bi
  365. addr += 1
  366. args.append(b)
  367. line = asm + ', '.join(args)
  368. tabs = ' ' * (27 - int(len(line)))
  369. res.append(f'{line}{tabs}[{raw}]')
  370. break
  371. addr += 1
  372. return '\n'.join(res)
  373. def convert_to_binary(data):
  374. a = '\n'.join([format(i, '08b') for i in data])
  375. return a.encode()
  376. def convert_to_mem(data, width=1, uhex=False):
  377. x = b''
  378. if uhex:
  379. if width == 2:
  380. for i in range(int(len(data)/2)):
  381. x += format(data[-(i*2) - 2], f'02x').upper().encode()
  382. x += format(data[-(i*2) - 1], f'02x').upper().encode()
  383. else:
  384. for i in range(len(data)):
  385. x += format(data[-i-1], f'02x').upper().encode()
  386. return x
  387. if width == 2:
  388. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  389. if len(data) % 2 == 1:
  390. datax.append(data[-1] << 8)
  391. else:
  392. datax = data
  393. fa = f'0{math.ceil(math.ceil(math.log2(len(datax))) / 4)}x'
  394. a = [format(d, f'0{width*2}x') for d in datax]
  395. for i in range(int(len(a) / 8) + 1):
  396. y = a[i * 8:(i + 1) * 8]
  397. if len(y) > 0:
  398. x += (' '.join(y) + ' ' * ((8 - len(y)) * 3) + ' // ' + format((i * 8 - 1) + len(y), fa) + '\n').encode()
  399. return x
  400. def convert_to_mif(data, depth=32, width=1):
  401. x = f'''-- auto-generated memory initialisation file
  402. DEPTH = {math.ceil(depth)};
  403. WIDTH = {width*8};
  404. ADDRESS_RADIX = HEX;
  405. DATA_RADIX = HEX;
  406. CONTENT
  407. BEGIN
  408. '''.encode()
  409. addr_format = f'0{math.ceil(int(math.log2(len(data))) / 4)}x'
  410. if width == 2:
  411. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  412. if len(data) % 2 == 1:
  413. datax.append(data[-1] << 8)
  414. else:
  415. datax = data
  416. a = [format(i, f'0{width*2}x') for i in datax]
  417. for i in range(int(len(a*width) / 8) + 1):
  418. y = a[i * 8:(i + 1) * 8]
  419. if len(y) > 0:
  420. x += (format(i * 8, addr_format) + ' : ' + ' '.join(y) + ';\n').encode()
  421. x += b"END;"
  422. return x