asm_compiler.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. #!/usr/bin/python3
  2. import re
  3. import math
  4. import traceback
  5. label_re = re.compile(r"^[\w\$\#\@\~\.\?]+$", re.IGNORECASE)
  6. hex_re = re.compile(r"^[0-9a-f]+$", re.IGNORECASE)
  7. bin_re = re.compile(r"^[0-1_]+$", re.IGNORECASE)
  8. oct_re = re.compile(r"^[0-8]+$", re.IGNORECASE)
  9. def match(regex, s):
  10. return regex.match(s) is not None
  11. def decode_bytes(val: str):
  12. try:
  13. if val.endswith('h'):
  14. return [int(val[i:i + 2], 16) for i in range(0, len(val) - 1, 2)]
  15. if val.startswith('0x'):
  16. return [int(val[i:i + 2], 16) for i in range(2, len(val), 2)]
  17. if val.startswith('b'):
  18. val = val.replace('_', '')[1:]
  19. return [int(val[i:i + 8], 2) for i in range(0, len(val), 8)]
  20. except ValueError:
  21. raise ValueError(f"Invalid binary '{val}'")
  22. if val.isdigit():
  23. i = int(val)
  24. if i > 255 or i < 0:
  25. raise ValueError(f"Invalid binary '{val}', unsigned int out of bounds")
  26. return [i]
  27. if (val.startswith('+') or val.startswith('-')) and val[1:].isdigit():
  28. i = int(val)
  29. if i > 127 or i < -128:
  30. raise ValueError(f"Invalid binary '{val}', signed int out of bounds")
  31. if i < 0: # convert to unsigned
  32. i += 2 ** 8
  33. return [i]
  34. if len(val) == 3 and ((val[0] == "'" and val[2] == "'") or (val[0] == '"' and val[2] == '"')):
  35. return [ord(val[1])]
  36. raise ValueError(f"Invalid binary '{val}'")
  37. def is_reg(r):
  38. if r.startswith('$'):
  39. r = r[1:]
  40. if r.isnumeric() and 0 <= int(r) <= 3:
  41. return True
  42. elif len(r) == 2 and r[0] == 'r' and r[1] in {'0', '1', '2', '3', 'a', 'b', 'c', 'e'}:
  43. return True
  44. return False
  45. def decode_reg(r):
  46. if r.startswith('$') and r[1:].isnumeric():
  47. r = int(r[1:])
  48. if isinstance(r, int):
  49. if 0 <= r <= 3:
  50. return r
  51. raise ValueError(f"Invalid register value {r}")
  52. rl = r.lower()
  53. if rl.startswith('$'):
  54. rl = rl[1:]
  55. if rl == 'ra' or rl == 'r0':
  56. return 0
  57. if rl == 'rb' or rl == 'r1':
  58. return 1
  59. if rl == 'rc' or rl == 'r2':
  60. return 2
  61. if rl == 're' or rl == 'r3':
  62. return 3
  63. raise ValueError(f"Invalid register name '{r}'")
  64. class CompilingError(Exception):
  65. def __init__(self, message):
  66. self.message = message
  67. class InstructionError(Exception):
  68. def __init__(self, message):
  69. self.message = message
  70. class Instruction:
  71. def __init__(self, name: str, opcode: str, operands=0, alias=None):
  72. name = name.strip().lower()
  73. if not name or not name.isalnum():
  74. raise InstructionError(f"Invalid instruction name '{name}'")
  75. self.name = name.strip()
  76. self.alias = alias or []
  77. self.opcode = decode_bytes(opcode.replace('?', '0'))[0]
  78. self.reg_operands = 0
  79. if len(opcode) == 10:
  80. if opcode[6:8] == '??':
  81. self.reg_operands += 1
  82. if opcode[8:10] == '??':
  83. self.reg_operands += 1
  84. self.imm_operands = operands
  85. self.compiler = None
  86. @property
  87. def length(self):
  88. return self.imm_operands + 1
  89. def __len__(self):
  90. return self.length
  91. def _gen_instr(self, regs, imm):
  92. instr = self.opcode
  93. if len(regs) != self.reg_operands:
  94. raise CompilingError(f"Invalid number of registers: set {len(regs)}, required: {self.reg_operands}")
  95. limm = 0
  96. for i in imm:
  97. if isinstance(i, str):
  98. if i in self.compiler.labels:
  99. d = self.compiler.labels[i]
  100. limm += len(d)
  101. else:
  102. limm += self.compiler.address_size
  103. else:
  104. limm += len(i)
  105. if limm != self.imm_operands:
  106. raise CompilingError(f"Invalid number of immediate: set {limm}, required: {self.reg_operands}")
  107. if len(regs) == 2:
  108. if regs[1] is None:
  109. raise CompilingError(f"Unable to decode register name {regs[1]}")
  110. if regs[0] is None:
  111. raise CompilingError(f"Unable to decode register name {regs[0]}")
  112. instr |= regs[1] << 2 | regs[0]
  113. elif len(regs) == 1:
  114. if regs[0] is None:
  115. raise CompilingError(f"Unable to decode register name {regs[0]}")
  116. instr |= int(regs[0]) << 2
  117. return instr
  118. def compile(self, operands):
  119. regs = []
  120. imm = []
  121. for i, arg in enumerate(operands):
  122. if self.reg_operands > i:
  123. regs.append(self.compiler.decode_reg(arg))
  124. else:
  125. imm.append(self.compiler.decode_bytes(arg))
  126. instr = self._gen_instr(regs, imm)
  127. return [instr] + imm
  128. class CompObject:
  129. def __init__(self, instr, operands, line_num):
  130. self.instr = instr
  131. self.operands = operands
  132. self.line_num = line_num
  133. self.code = []
  134. self.code_ref = 0
  135. def compile(self):
  136. self.code = self.instr.compile(self.operands)
  137. return self.code
  138. class Compiler:
  139. def __init__(self, address_size=2, byte_order='little'):
  140. self.instr_db = {}
  141. self.data = []
  142. self.caddress = 0
  143. self.labels = {}
  144. self.order = byte_order
  145. self.regnames = {}
  146. self.address_size = address_size
  147. def decode_reg(self, s: str):
  148. s = s.strip()
  149. # if s in self.labels:
  150. # b = self.labels[s]
  151. if s in self.regnames:
  152. b = self.regnames[s]
  153. else:
  154. b = self.decode_bytes(s)
  155. if isinstance(b, bytes):
  156. i = int.from_bytes(b, byteorder=self.order)
  157. elif isinstance(b, int):
  158. i = b
  159. else:
  160. raise CompilingError(f"Unrecognised register name: {s}")
  161. if i not in self.regnames.values():
  162. raise CompilingError(f"Invalid register: {s}")
  163. return i
  164. def decode_bytes(self, s: str):
  165. s = s.strip()
  166. typ = ""
  167. # Decimal numbers
  168. if s.isnumeric():
  169. typ = 'int'
  170. elif s.endswith('d') and s[:-1].isnumeric():
  171. s = s[:-1]
  172. typ = 'int'
  173. elif s.startswith('0d') and s[2:].isnumeric():
  174. s = s[2:]
  175. typ = 'int'
  176. # Hexadecimal numbers
  177. elif s.startswith('0') and s.endswith('h') and match(hex_re, s[1:-1]):
  178. s = s[1:-1]
  179. typ = 'hex'
  180. elif (s.startswith('$0') or s.startswith('0x') or s.startswith('$0')) and match(hex_re, s[2:]):
  181. s = s[2:]
  182. typ = 'hex'
  183. # Octal numbers
  184. elif (s.endswith('q') or s.endswith('o')) and match(oct_re, s[:-1]):
  185. s = s[:-1]
  186. typ = 'oct'
  187. elif (s.startswith('0q') or s.startswith('0o')) and match(oct_re, s[2:]):
  188. s = s[2:]
  189. typ = 'oct'
  190. # Binary number
  191. elif (s.endswith('b') or s.endswith('y')) and match(bin_re, s[:-1]):
  192. s = s[:-1].replace('_', '')
  193. typ = 'bin'
  194. elif (s.startswith('0b') or s.startswith('0y')) and match(bin_re, s[2:]):
  195. s = s[2:].replace('_', '')
  196. typ = 'bin'
  197. # ASCII
  198. elif s.startswith("'") and s.endswith("'") and len(s) == 3:
  199. s = ord(s[1:-1]).to_bytes(1, self.order)
  200. typ = 'ascii'
  201. # Convert with limits
  202. if typ == 'int':
  203. numb = int(s)
  204. for i in range(1, 9):
  205. if -2 ** (i * 7) < i < 2 ** (i * 8):
  206. return numb.to_bytes(i, self.order)
  207. elif typ == 'hex':
  208. numb = int(s, 16)
  209. return numb.to_bytes(int(len(s) / 2) + len(s) % 2, self.order)
  210. elif typ == 'oct':
  211. numb = int(s, 8)
  212. for i in range(1, 9):
  213. if -2 ** (i * 7) < i < 2 ** (i * 8):
  214. return numb.to_bytes(i, self.order)
  215. elif typ == 'bin':
  216. numb = int(s, 2)
  217. return numb.to_bytes(int(len(s) / 8) + len(s) % 8, self.order)
  218. else:
  219. return s
  220. @staticmethod
  221. def _hash_instr(name, operands):
  222. return hash(name) + hash(operands)
  223. def add_reg(self, name, val):
  224. self.regnames[name] = val
  225. self.regnames['$' + name] = val
  226. def add_instr(self, instr: Instruction):
  227. instr.compiler = self
  228. operands = instr.reg_operands + instr.imm_operands
  229. # ihash = self._hash_instr(instr.name, operands)
  230. if instr.name in self.instr_db:
  231. raise InstructionError(f"Instruction {instr.name} operands={operands} duplicate!")
  232. self.instr_db[instr.name] = instr
  233. for alias in instr.alias:
  234. # ahash = self._hash_instr(alias, operands)
  235. if alias.lower() in self.instr_db:
  236. raise InstructionError(f"Instruction alias {alias} operands={operands} duplicate!")
  237. self.instr_db[alias.lower()] = instr
  238. def __func(self, f, args):
  239. for arg in args:
  240. if arg == '|':
  241. pass
  242. if arg == '^':
  243. pass
  244. if arg == '&':
  245. pass
  246. if arg == '<<':
  247. pass
  248. if arg == '>>':
  249. pass
  250. if arg == '+':
  251. pass
  252. if arg == '-':
  253. pass
  254. if arg == '*':
  255. pass
  256. if arg == '/' or arg == '//':
  257. pass
  258. if arg == '%' or arg == '%%':
  259. pass
  260. def __precompile(self, line):
  261. line = line.split(';', 1)[0]
  262. if ':' in line:
  263. linespl = line.split(':', 1)
  264. line = linespl[1]
  265. label = linespl[0]
  266. if label in self.labels:
  267. raise CompilingError(f"Label {label} duplicate")
  268. self.labels[label] = (self.caddress).to_bytes(self.address_size, self.order)
  269. if line.startswith('%define'):
  270. sp = list(filter(None, line.split(' ', 3)))
  271. if len(sp) != 3:
  272. raise CompilingError(f"Invalid %define")
  273. if '(' in sp[1] and ')' in sp[1]: # Function
  274. raise CompilingError(f"%define functions not implemented")
  275. self.labels[sp[1]] = self.decode_bytes(sp[2])
  276. return
  277. instr0 = list(filter(None, line.strip().split(' ', 1)))
  278. if len(instr0) == 0:
  279. return
  280. instr = instr0[0]
  281. if len(instr0) == 1:
  282. instr0.append('')
  283. operands = list(filter(None, map(lambda x: x.strip(), instr0[1].split(','))))
  284. if instr.lower() not in self.instr_db:
  285. raise CompilingError(f"Instruction {instr} operands={operands} is not recognised!")
  286. co = CompObject(self.instr_db[instr.lower()], operands, 0)
  287. return co
  288. def compile(self, file, code):
  289. failure = False
  290. instr = []
  291. binary = []
  292. for lnum, line in enumerate(code):
  293. lnum += 1
  294. try:
  295. co = self.__precompile(line)
  296. if co is not None:
  297. co.line_num = lnum
  298. self.caddress += co.instr.length
  299. instr.append(co)
  300. except CompilingError as e:
  301. failure = True
  302. print(f"ERROR {file}:{lnum}: {e.message}")
  303. for co in instr:
  304. try:
  305. binary += co.compile()
  306. except CompilingError as e:
  307. failure = True
  308. print(f"ERROR {file}:{co.line_num}: {e.message}")
  309. except Exception:
  310. failure = True
  311. print(f"ERROR {file}:{co.line_num}: Unexpected error:")
  312. traceback.print_exc()
  313. nbin = bytearray()
  314. for b in binary:
  315. if isinstance(b, int):
  316. nbin += b.to_bytes(1, self.order)
  317. elif isinstance(b, bytes):
  318. nbin += b
  319. elif isinstance(b, str):
  320. if b in self.labels:
  321. nbin += self.labels[b]
  322. else:
  323. failure = True
  324. print(f"ERROR {file}: Unable to find label '{b}'")
  325. if failure:
  326. return None
  327. return nbin
  328. def convert_to_binary(data):
  329. a = '\n'.join([format(i, '08b') for i in data])
  330. return a.encode()
  331. def convert_to_mem(data):
  332. x = b''
  333. fa = f'0{math.ceil(int(math.log2(len(data)))/4)}x'
  334. a = [format(d, '02x') for d in data]
  335. for i in range(int(len(a) / 8) + 1):
  336. y = a[i * 8:(i + 1) * 8]
  337. if len(y) > 0:
  338. x += (' '.join(y) + ' // ' + format(i*8, fa) + '\n').encode()
  339. return x
  340. def convert_to_mif(data, depth=32, width=8):
  341. x = f'''-- auto-generated memory initialisation file
  342. DEPTH = {depth};
  343. WIDTH = {width};
  344. ADDRESS_RADIX = HEX;
  345. DATA_RADIX = HEX;
  346. CONTENT
  347. BEGIN
  348. '''.encode()
  349. addr_format = f'0{math.ceil(int(math.log2(len(data)))/4)}x'
  350. a = [format(i, '02x') for i in data]
  351. for i in range(int(len(a) / 8) + 1):
  352. y = a[i * 8:(i + 1) * 8]
  353. if len(y) > 0:
  354. x += (format(i*8, addr_format) + ' : ' + ' '.join(y) + ';\n').encode()
  355. x += b"END;"
  356. return x