asm_compiler.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. #!/usr/bin/python3
  2. import re
  3. import math
  4. import traceback
  5. from typing import Dict, List
  6. label_re = re.compile(r"^[\w$#@~.?]+$", re.IGNORECASE)
  7. hex_re = re.compile(r"^[0-9a-f]+$", re.IGNORECASE)
  8. bin_re = re.compile(r"^[0-1_]+$", re.IGNORECASE)
  9. oct_re = re.compile(r"^[0-8]+$", re.IGNORECASE)
  10. args_re = re.compile("(?:^|,)(?=[^\"]|(\")?)\"?((?(1)[^\"]*|[^,\"]*))\"?(?=,|$)", re.IGNORECASE)
  11. func_re = re.compile("^([\w$#@~.?]+)\s*([|^<>+\-*/%@]{1,2})\s*([\w$#@~.?]+)$", re.IGNORECASE)
  12. secs_re = re.compile("^([\d]+)x([\d]+)x([\d]+)$", re.IGNORECASE)
  13. def args2operands(args):
  14. operands = ['"' + a[1] + '"' if a[0] == '"' else a[1] for a in args_re.findall(args or '') if a[1]]
  15. return operands
  16. def match(regex, s):
  17. return regex.match(s) is not None
  18. class CompilingError(Exception):
  19. def __init__(self, message):
  20. self.message = message
  21. class InstructionError(Exception):
  22. def __init__(self, message):
  23. self.message = message
  24. class Instruction:
  25. def __init__(self, name: str, opcode: str, operands=0, alias=None):
  26. name = name.strip().lower()
  27. if not name or not name.isalnum():
  28. raise InstructionError(f"Invalid instruction name '{name}'")
  29. self.name = name.strip()
  30. self.alias = alias or []
  31. self.reg_operands = 0
  32. opcode = opcode.replace('_', '')
  33. if len(opcode) == 8:
  34. if opcode[4:6] == '??':
  35. self.reg_operands += 1
  36. if opcode[6:8] == '??':
  37. self.reg_operands += 1
  38. else:
  39. raise CompilingError("Invalid opcode: " + opcode)
  40. self.opcode = int(opcode.replace('?', '0'), 2)
  41. self.imm_operands = operands
  42. self.compiler = None
  43. @property
  44. def length(self):
  45. return self.imm_operands + 1
  46. def __len__(self):
  47. return self.length
  48. def _gen_instr(self, regs):
  49. instr = self.opcode
  50. if len(regs) != self.reg_operands:
  51. raise CompilingError(f"Invalid number of registers: set {len(regs)}, required: {self.reg_operands}")
  52. if len(regs) == 2:
  53. if regs[1] is None:
  54. raise CompilingError(f"Unable to decode register name {regs[1]}")
  55. if regs[0] is None:
  56. raise CompilingError(f"Unable to decode register name {regs[0]}")
  57. instr |= regs[1] << 2 | regs[0]
  58. elif len(regs) == 1:
  59. if regs[0] is None:
  60. raise CompilingError(f"Unable to decode register name {regs[0]}")
  61. instr |= int(regs[0]) << 2
  62. return instr.to_bytes(1, 'little') # Order does not matter with 1 byte
  63. def compile(self, operands, scope):
  64. regs = []
  65. for reg in operands[:self.reg_operands]:
  66. regs.append(self.compiler.decode_reg(reg))
  67. imm = self.compiler.decode_with_labels(operands[self.reg_operands:], scope)
  68. if len(imm) != self.imm_operands:
  69. raise CompilingError(f"Instruction {self.name} has invalid argument size {len(imm)} != {self.imm_operands},"
  70. f" supplied args: 0x{imm.hex()}")
  71. instr = self._gen_instr(regs)
  72. return instr + imm
  73. class Section:
  74. def __init__(self):
  75. self.instr = []
  76. self.data = b''
  77. self.count = 0
  78. self.width = 1
  79. self.length = 1
  80. self.size = 2**8
  81. class Compiler:
  82. def __init__(self, address_size=2, byte_order='little'):
  83. self.instr_db: Dict[str, Instruction] = {}
  84. self.data = []
  85. self.labels = {}
  86. self.order = byte_order
  87. self.regnames = {}
  88. self.address_size = address_size
  89. def decode_reg(self, s: str):
  90. s = s.strip()
  91. if s in self.regnames:
  92. return self.regnames[s]
  93. raise CompilingError(f"Unrecognised register name: {s}")
  94. def decode_bytes(self, s: str):
  95. s = s.strip()
  96. typ = ""
  97. # Decimal numbers
  98. if s.isnumeric():
  99. typ = 'int'
  100. elif s.endswith('d') and s[:-1].isnumeric():
  101. s = s[:-1]
  102. typ = 'int'
  103. elif s.startswith('0d') and s[2:].isnumeric():
  104. s = s[2:]
  105. typ = 'int'
  106. # Hexadecimal numbers
  107. elif s.startswith('0') and s.endswith('h') and match(hex_re, s[1:-1]):
  108. s = s[1:-1]
  109. typ = 'hex'
  110. elif (s.startswith('$0') or s.startswith('0x') or s.startswith('$0')) and match(hex_re, s[2:]):
  111. s = s[2:]
  112. typ = 'hex'
  113. # Octal numbers
  114. elif (s.endswith('q') or s.endswith('o')) and match(oct_re, s[:-1]):
  115. s = s[:-1]
  116. typ = 'oct'
  117. elif (s.startswith('0q') or s.startswith('0o')) and match(oct_re, s[2:]):
  118. s = s[2:]
  119. typ = 'oct'
  120. # Binary number
  121. elif (s.endswith('b') or s.endswith('y')) and match(bin_re, s[:-1]):
  122. s = s[:-1].replace('_', '')
  123. typ = 'bin'
  124. elif (s.startswith('0b') or s.startswith('0y')) and match(bin_re, s[2:]):
  125. s = s[2:].replace('_', '')
  126. typ = 'bin'
  127. # ASCII
  128. elif s.startswith("'") and s.endswith("'") and len(s) == 3:
  129. s = ord(s[1:-1]).to_bytes(1, self.order)
  130. typ = 'ascii'
  131. elif (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
  132. s = s[1:-1].encode('utf-8').decode("unicode_escape").encode('utf-8')
  133. typ = 'string'
  134. # Convert with limits
  135. if typ == 'int':
  136. numb = int(s)
  137. for i in range(1, 9):
  138. if -2 ** (i * 7) < i < 2 ** (i * 8):
  139. return numb.to_bytes(i, self.order)
  140. elif typ == 'hex':
  141. numb = int(s, 16)
  142. return numb.to_bytes(int(len(s) / 2) + len(s) % 2, self.order)
  143. elif typ == 'oct':
  144. numb = int(s, 8)
  145. for i in range(1, 9):
  146. if -2 ** (i * 7) < i < 2 ** (i * 8):
  147. return numb.to_bytes(i, self.order)
  148. elif typ == 'bin':
  149. numb = int(s, 2)
  150. return numb.to_bytes(int(len(s) / 8) + len(s) % 8, self.order)
  151. else:
  152. return s
  153. def _decode_labels(self, arg, scope):
  154. immx = self.decode_bytes(arg)
  155. if isinstance(immx, str):
  156. if immx.startswith('.'):
  157. immx = scope + immx
  158. if immx in self.labels:
  159. return self.labels[immx]
  160. else:
  161. raise CompilingError(f"Unknown label: {immx}")
  162. elif isinstance(immx, bytes):
  163. return immx
  164. def decode_with_labels(self, args, scope):
  165. data = b''
  166. for arg in args:
  167. if isinstance(arg, str):
  168. funcm = func_re.match(arg)
  169. if funcm is not None:
  170. g = funcm.groups()
  171. left = self._decode_labels(g[0], scope)
  172. right = self._decode_labels(g[2], scope)
  173. data += self.proc_func(left, right, g[1])
  174. continue
  175. data += self._decode_labels(arg, scope)
  176. return data
  177. def add_reg(self, name, val):
  178. self.regnames[name] = val
  179. self.regnames['$' + name] = val
  180. def add_instr(self, instr: Instruction):
  181. instr.compiler = self
  182. operands = instr.reg_operands + instr.imm_operands
  183. if instr.name in self.instr_db:
  184. raise InstructionError(f"Instruction {instr.name} operands={operands} duplicate!")
  185. self.instr_db[instr.name] = instr
  186. for alias in instr.alias:
  187. if alias.lower() in self.instr_db:
  188. raise InstructionError(f"Instruction alias {alias} operands={operands} duplicate!")
  189. self.instr_db[alias.lower()] = instr
  190. def proc_func(self, left, right, op):
  191. if op == '|':
  192. return left | right
  193. if op == '^':
  194. return left ^ right
  195. if op == '&':
  196. return left & right
  197. if op == '<<':
  198. return left << right
  199. if op == '>>':
  200. return left >> right
  201. if op == '+':
  202. return left + right
  203. if op == '-':
  204. return left - right
  205. if op == '*':
  206. return left * right
  207. if op == '/' or op == '//':
  208. return left / right
  209. if op == '%' or op == '%%':
  210. return left % right
  211. if op == '@':
  212. return bytes([left[len(left)-int.from_bytes(right, byteorder=self.order)-1]])
  213. raise CompilingError(f"Invalid function operation {op}")
  214. def compile(self, file, code):
  215. failure = False
  216. sections: Dict[str, Section] = {}
  217. csect = None
  218. scope = None
  219. for lnum, line in enumerate(code):
  220. lnum += 1
  221. line = line.split(';', 1)[0].strip()
  222. try:
  223. line_args = [l.strip() for l in line.split(' ', 2)]
  224. # line_args = list(filter(lambda x: len(x) > 0, line_args))
  225. if len(line_args) == 0 or line_args[0] == '':
  226. continue
  227. # Section
  228. if line_args[0].lower() == 'section':
  229. if len(line_args) < 2:
  230. raise CompilingError(f"Invalid section arguments!")
  231. section_name = line_args[1].lower()
  232. if section_name not in sections:
  233. s = Section()
  234. if len(line_args) == 3:
  235. m = secs_re.match(line_args[2])
  236. if m is not None:
  237. g = m.groups()
  238. s.width = int(g[0])
  239. s.length = int(g[1])
  240. s.size = int(g[2])
  241. else:
  242. raise CompilingError(f"Invalid section argument: {line_args[2]}")
  243. sections[section_name] = s
  244. csect = sections[section_name]
  245. continue
  246. # Macros
  247. elif line_args[0].lower() == '%define':
  248. if len(line_args) != 3:
  249. raise CompilingError(f"Invalid %define arguments!")
  250. self.labels[line_args[1]] = self.decode_bytes(line_args[2])
  251. continue
  252. if csect is None:
  253. raise CompilingError(f"No section defined!")
  254. builtin_cmds = {'db'}
  255. if line_args[0].lower() not in self.instr_db and\
  256. line_args[0].lower() not in builtin_cmds: # Must be label
  257. label = line_args[0]
  258. line_args = line_args[1:]
  259. if label.startswith('.'):
  260. if scope is None:
  261. raise CompilingError(f"No local scope for {label}!")
  262. label = scope + label
  263. else:
  264. scope = label
  265. if label in self.labels:
  266. raise CompilingError(f"Label {label} duplicate")
  267. self.labels[label] = csect.count.to_bytes(csect.length, self.order)
  268. if len(line_args) == 0:
  269. continue
  270. elif len(line_args) == 1:
  271. instr_name, args = line_args[0].lower(), None
  272. else:
  273. instr_name, args = line_args[0].lower(), line_args[1]
  274. # Builtin instructions
  275. if instr_name == 'db':
  276. data = self.decode_with_labels(args2operands(args), scope)
  277. if len(data) % csect.width != 0:
  278. fill = csect.width - (len(data) % csect.width)
  279. data += b'\x00' * fill
  280. csect.instr.append(data)
  281. csect.count += int(len(data)/csect.width)
  282. continue
  283. if instr_name not in self.instr_db:
  284. raise CompilingError(f"Instruction '{instr_name}' not recognised!")
  285. instr_obj = self.instr_db[instr_name.lower()]
  286. csect.instr.append((instr_obj, args, lnum, scope))
  287. csect.count += instr_obj.length
  288. except CompilingError as e:
  289. failure = True
  290. print(f"ERROR {file}:{lnum}: {e.message}")
  291. for section in sections.values():
  292. for instr_tuple in section.instr:
  293. if isinstance(instr_tuple, bytes):
  294. section.data += instr_tuple
  295. continue
  296. instr, args, lnum, scope = instr_tuple
  297. try:
  298. operands = args2operands(args)
  299. section.data += instr.compile(operands, scope)
  300. except CompilingError as e:
  301. failure = True
  302. print(f"ERROR {file}:{lnum}: {e.message}")
  303. if failure:
  304. return None
  305. return {k: (v.width, v.length, v.size, v.data) for k, v in sections.items()}
  306. def decompile(self, binary):
  307. addr = 0
  308. res = []
  309. ibin = iter(binary)
  310. for data in ibin:
  311. norm0 = int(data)
  312. norm1 = norm0 & int('11110011', 2)
  313. norm2 = norm0 & int('11110000', 2)
  314. for instr in self.instr_db.values():
  315. if not ((instr.reg_operands == 0 and norm0 == instr.opcode) or
  316. (instr.reg_operands == 1 and norm1 == instr.opcode) or
  317. (instr.reg_operands == 2 and norm2 == instr.opcode)):
  318. continue
  319. asm = f'{addr:04x}: {instr.name.upper().ljust(6)}'
  320. args = []
  321. raw = format(norm0, '02x')
  322. if instr.reg_operands > 0:
  323. args.append(f'r{(norm0 & 12) >> 2}')
  324. if instr.reg_operands > 1:
  325. args.append(f'r{(norm0 & 3)}')
  326. if instr.imm_operands > 0:
  327. b = '0x'
  328. for i in range(instr.imm_operands):
  329. try:
  330. bi = format(int(next(ibin)), '02x')
  331. except StopIteration:
  332. break
  333. b += bi
  334. raw += bi
  335. addr += 1
  336. args.append(b)
  337. line = asm + ', '.join(args)
  338. tabs = ' ' * (27 - int(len(line)))
  339. res.append(f'{line}{tabs}[{raw}]')
  340. break
  341. addr += 1
  342. return '\n'.join(res)
  343. def convert_to_binary(data):
  344. a = '\n'.join([format(i, '08b') for i in data])
  345. return a.encode()
  346. def convert_to_mem(data, width=1, uhex=False):
  347. x = b''
  348. if uhex:
  349. if width == 2:
  350. for i in range(int(len(data)/2)):
  351. x += format(data[-(i*2) - 2], f'02x').upper().encode()
  352. x += format(data[-(i*2) - 1], f'02x').upper().encode()
  353. else:
  354. for i in range(len(data)):
  355. x += format(data[-i-1], f'02x').upper().encode()
  356. return x
  357. if width == 2:
  358. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  359. if len(data) % 2 == 1:
  360. datax.append(data[-1] << 8)
  361. else:
  362. datax = data
  363. fa = f'0{math.ceil(math.ceil(math.log2(len(datax))) / 4)}x'
  364. a = [format(d, f'0{width*2}x') for d in datax]
  365. for i in range(int(len(a) / 8) + 1):
  366. y = a[i * 8:(i + 1) * 8]
  367. if len(y) > 0:
  368. x += (' '.join(y) + ' ' * ((8 - len(y)) * 3) + ' // ' + format((i * 8 - 1) + len(y), fa) + '\n').encode()
  369. return x
  370. def convert_to_mif(data, depth=32, width=1):
  371. x = f'''-- auto-generated memory initialisation file
  372. DEPTH = {math.ceil(depth)};
  373. WIDTH = {width*8};
  374. ADDRESS_RADIX = HEX;
  375. DATA_RADIX = HEX;
  376. CONTENT
  377. BEGIN
  378. '''.encode()
  379. addr_format = f'0{math.ceil(int(math.log2(len(data))) / 4)}x'
  380. if width == 2:
  381. datax = [(x << 8) | y for x, y in zip(data[0::2], data[1::2])]
  382. if len(data) % 2 == 1:
  383. datax.append(data[-1] << 8)
  384. else:
  385. datax = data
  386. a = [format(i, f'0{width*2}x') for i in datax]
  387. for i in range(int(len(a*width) / 8) + 1):
  388. y = a[i * 8:(i + 1) * 8]
  389. if len(y) > 0:
  390. x += (format(i * 8, addr_format) + ' : ' + ' '.join(y) + ';\n').encode()
  391. x += b"END;"
  392. return x