Miasm2
 All Classes Namespaces Files Functions Variables Typedefs Properties Macros
symbexec.py
Go to the documentation of this file.
1 import miasm2.expression.expression as m2_expr
2 from miasm2.expression.modint import int32
3 from miasm2.expression.simplifications import expr_simp
4 from miasm2.core import asmbloc
5 import logging
6 
7 
# Module-level logger for symbolic execution. A StreamHandler (stderr by
# default) is attached, formatting records as "LEVEL: message".
log = logging.getLogger("symbexec")
log.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter("%(levelname)-5s: %(message)s"))
log.addHandler(console_handler)
13 
14 
class symbols():
    """Store mapping expressions to their current symbolic values.

    Non-memory keys (e.g. identifiers) live in ``symbols_id``. Memory keys
    (``ExprMem``) live in ``symbols_mem``, indexed by their address
    expression and holding a ``(mem_expr, value)`` pair, so that lookups can
    also check the access size.
    """

    def __init__(self, init=None):
        """Create a store, optionally pre-filled from the mapping *init*."""
        if init is None:
            init = {}
        self.symbols_id = {}
        self.symbols_mem = {}
        for k, v in init.items():
            self[k] = v

    def __contains__(self, a):
        """Return True if *a* is stored; memory keys must also match in size."""
        if not isinstance(a, m2_expr.ExprMem):
            return a in self.symbols_id
        if a.arg not in self.symbols_mem:
            return False
        return self.symbols_mem[a.arg][0].size == a.size

    def __getitem__(self, a):
        """Return the value bound to *a*.

        Raises KeyError if *a* is unknown or, for memory, if the entry stored
        at the same address has a different size.
        """
        if not isinstance(a, m2_expr.ExprMem):
            return self.symbols_id[a]
        if a.arg not in self.symbols_mem:
            raise KeyError(a)
        mem, value = self.symbols_mem[a.arg]
        if mem.size != a.size:
            raise KeyError(a)
        return value

    def __setitem__(self, a, v):
        """Bind *a* to *v*; a memory key replaces any entry at its address."""
        if not isinstance(a, m2_expr.ExprMem):
            self.symbols_id[a] = v
            return
        self.symbols_mem[a.arg] = (a, v)

    def __iter__(self):
        """Yield non-memory keys first, then the stored ExprMem keys."""
        for a in self.symbols_id:
            yield a
        for mem, _ in self.symbols_mem.values():
            yield mem

    def __delitem__(self, a):
        """Remove *a*; memory keys are removed by address, size unchecked."""
        if not isinstance(a, m2_expr.ExprMem):
            del self.symbols_id[a]
        else:
            del self.symbols_mem[a.arg]

    def items(self):
        """Return a list of ``(key, value)`` pairs, memory entries last."""
        # list() keeps this working on Python 3, where dict.items() is a view
        # that cannot be concatenated with a list.
        return list(self.symbols_id.items()) + list(self.symbols_mem.values())

    def keys(self):
        """Return a list of all keys, memory expressions last."""
        return (list(self.symbols_id.keys()) +
                [mem for mem, _ in self.symbols_mem.values()])

    def copy(self):
        """Return a new store with copies of both dicts (values are shared)."""
        p = symbols()
        p.symbols_id = dict(self.symbols_id)
        p.symbols_mem = dict(self.symbols_mem)
        return p

    def inject_info(self, info):
        """Return a new store where *info* is substituted (replace_expr) in
        every key and value, each result being simplified by expr_simp."""
        s = symbols()
        for k, v in self.items():
            k = expr_simp(k.replace_expr(info))
            v = expr_simp(v.replace_expr(info))
            s[k] = v
        return s
81 
82 
84 
85  def __init__(self, ir_arch, known_symbols,
86  func_read=None,
87  func_write=None,
88  sb_expr_simp=expr_simp):
89  self.symbols = symbols()
90  for k, v in known_symbols.items():
91  self.symbols[k] = v
92  self.func_read = func_read
93  self.func_write = func_write
94  self.ir_arch = ir_arch
95  self.expr_simp = sb_expr_simp
96 
97  def find_mem_by_addr(self, e):
98  if e in self.symbols.symbols_mem:
99  return self.symbols.symbols_mem[e][0]
100  return None
101 
102  def eval_ExprId(self, e, eval_cache=None):
103  if eval_cache is None:
104  eval_cache = {}
105  if isinstance(e.name, asmbloc.asm_label) and e.name.offset is not None:
106  return m2_expr.ExprInt_from(e, e.name.offset)
107  if not e in self.symbols:
108  # raise ValueError('unknown symbol %s'% e)
109  return e
110  return self.symbols[e]
111 
112  def eval_ExprInt(self, e, eval_cache=None):
113  return e
114 
115  def eval_ExprMem(self, e, eval_cache=None):
116  if eval_cache is None:
117  eval_cache = {}
118  a_val = self.expr_simp(self.eval_expr(e.arg, eval_cache))
119  if a_val != e.arg:
120  a = self.expr_simp(m2_expr.ExprMem(a_val, size=e.size))
121  else:
122  a = e
123  if a in self.symbols:
124  return self.symbols[a]
125  tmp = None
126  # test if mem lookup is known
127  if a_val in self.symbols.symbols_mem:
128  tmp = self.symbols.symbols_mem[a_val][0]
129  if tmp is None:
130 
131  v = self.find_mem_by_addr(a_val)
132  if not v:
133  out = []
134  ov = self.get_mem_overlapping(a, eval_cache)
135  off_base = 0
136  ov.sort()
137  # ov.reverse()
138  for off, x in ov:
139  # off_base = off * 8
140  # x_size = self.symbols[x].size
141  if off >= 0:
142  m = min(a.size - off * 8, x.size)
143  ee = m2_expr.ExprSlice(self.symbols[x], 0, m)
144  ee = self.expr_simp(ee)
145  out.append((ee, off_base, off_base + m))
146  off_base += m
147  else:
148  m = min(a.size - off * 8, x.size)
149  ee = m2_expr.ExprSlice(self.symbols[x], -off * 8, m)
150  ff = self.expr_simp(ee)
151  new_off_base = off_base + m + off * 8
152  out.append((ff, off_base, new_off_base))
153  off_base = new_off_base
154  if out:
155  missing_slice = self.rest_slice(out, 0, a.size)
156  for sa, sb in missing_slice:
157  ptr = self.expr_simp(
158  a_val + m2_expr.ExprInt_from(a_val, sa / 8)
159  )
160  mm = m2_expr.ExprMem(ptr, size=sb - sa)
161  mm.is_term = True
162  mm.is_simp = True
163  out.append((mm, sa, sb))
164  out.sort(key=lambda x: x[1])
165  # for e, sa, sb in out:
166  # print str(e), sa, sb
167  ee = m2_expr.ExprSlice(m2_expr.ExprCompose(out), 0, a.size)
168  ee = self.expr_simp(ee)
169  return ee
170  if self.func_read and isinstance(a.arg, m2_expr.ExprInt):
171  return self.func_read(a)
172  else:
173  # XXX hack test
174  a.is_term = True
175  return a
176  # bigger lookup
177  if a.size > tmp.size:
178  rest = a.size
179  ptr = a_val
180  out = []
181  ptr_index = 0
182  while rest:
183  v = self.find_mem_by_addr(ptr)
184  if v is None:
185  # raise ValueError("cannot find %s in mem"%str(ptr))
186  val = m2_expr.ExprMem(ptr, 8)
187  v = val
188  diff_size = 8
189  elif rest >= v.size:
190  val = self.symbols[v]
191  diff_size = v.size
192  else:
193  diff_size = rest
194  val = self.symbols[v][0:diff_size]
195  val = (val, ptr_index, ptr_index + diff_size)
196  out.append(val)
197  ptr_index += diff_size
198  rest -= diff_size
199  ptr = self.expr_simp(
200  self.eval_expr(
201  m2_expr.ExprOp('+', ptr,
202  m2_expr.ExprInt_from(ptr, v.size / 8)),
203  eval_cache)
204  )
205  e = self.expr_simp(m2_expr.ExprCompose(out))
206  return e
207  # part lookup
208  tmp = self.expr_simp(m2_expr.ExprSlice(self.symbols[tmp], 0, a.size))
209  return tmp
210 
211  def eval_expr_visit(self, e, eval_cache=None):
212  if eval_cache is None:
213  eval_cache = {}
214  # print 'visit', e, e.is_term
215  if e.is_term:
216  return e
217  if e in eval_cache:
218  return eval_cache[e]
219  c = e.__class__
220  deal_class = {m2_expr.ExprId: self.eval_ExprId,
221  m2_expr.ExprInt: self.eval_ExprInt,
222  m2_expr.ExprMem: self.eval_ExprMem,
223  }
224  # print 'eval', e
225  if c in deal_class:
226  e = deal_class[c](e, eval_cache)
227  # print "ret", e
228  if not (isinstance(e, m2_expr.ExprId) or isinstance(e,
229  m2_expr.ExprInt)):
230  e.is_term = True
231  return e
232 
233  def eval_expr(self, e, eval_cache=None):
234  if eval_cache is None:
235  eval_cache = {}
236  r = e.visit(lambda x: self.eval_expr_visit(x, eval_cache))
237  return r
238 
239  def modified_regs(self, init_state=None):
240  if init_state is None:
241  init_state = self.ir_arch.arch.regs.regs_init
242  ids = self.symbols.symbols_id.keys()
243  ids.sort()
244  for i in ids:
245  if i in init_state and \
246  i in self.symbols.symbols_id and \
247  self.symbols.symbols_id[i] == init_state[i]:
248  continue
249  yield i
250 
251  def modified_mems(self, init_state=None):
252  mems = self.symbols.symbols_mem.values()
253  mems.sort()
254  for m, _ in mems:
255  yield m
256 
257  def modified(self, init_state=None):
258  for r in self.modified_regs(init_state):
259  yield r
260  for m in self.modified_mems(init_state):
261  yield m
262 
263  def dump_id(self):
264  ids = self.symbols.symbols_id.keys()
265  ids.sort()
266  for i in ids:
267  if i in self.ir_arch.arch.regs.regs_init and \
268  i in self.symbols.symbols_id and \
269  self.symbols.symbols_id[i] == self.ir_arch.arch.regs.regs_init[i]:
270  continue
271  print i, self.symbols.symbols_id[i]
272 
273  def dump_mem(self):
274  mems = self.symbols.symbols_mem.values()
275  mems.sort()
276  for m, v in mems:
277  print m, v
278 
279  def rest_slice(self, slices, start, stop):
280  o = []
281  last = start
282  for _, a, b in slices:
283  if a == last:
284  last = b
285  continue
286  o.append((last, a))
287  last = b
288  if last != stop:
289  o.append((b, stop))
290  return o
291 
292  def substract_mems(self, a, b):
293  ex = b.arg - a.arg
294  ex = self.expr_simp(self.eval_expr(ex, {}))
295  if not isinstance(ex, m2_expr.ExprInt):
296  return None
297  ptr_diff = int(int32(ex.arg))
298  out = []
299  if ptr_diff < 0:
300  # [a ]
301  #[b ]XXX
302  sub_size = b.size + ptr_diff * 8
303  if sub_size >= a.size:
304  pass
305  else:
306  ex = m2_expr.ExprOp('+', a.arg,
307  m2_expr.ExprInt_from(a.arg, sub_size / 8))
308  ex = self.expr_simp(self.eval_expr(ex, {}))
309 
310  rest_ptr = ex
311  rest_size = a.size - sub_size
312 
313  val = self.symbols[a][sub_size:a.size]
314  out = [(m2_expr.ExprMem(rest_ptr, rest_size), val)]
315  else:
316  #[a ]
317  # XXXX[b ]YY
318 
319  #[a ]
320  # XXXX[b ]
321 
322  out = []
323  # part X
324  if ptr_diff > 0:
325  val = self.symbols[a][0:ptr_diff * 8]
326  out.append((m2_expr.ExprMem(a.arg, ptr_diff * 8), val))
327  # part Y
328  if ptr_diff * 8 + b.size < a.size:
329 
330  ex = m2_expr.ExprOp('+', b.arg,
331  m2_expr.ExprInt_from(b.arg, b.size / 8))
332  ex = self.expr_simp(self.eval_expr(ex, {}))
333 
334  rest_ptr = ex
335  rest_size = a.size - (ptr_diff * 8 + b.size)
336  val = self.symbols[a][ptr_diff * 8 + b.size:a.size]
337  out.append((m2_expr.ExprMem(ex, val.size), val))
338  return out
339 
340  # give mem stored overlapping requested mem ptr
341  def get_mem_overlapping(self, e, eval_cache=None):
342  if eval_cache is None:
343  eval_cache = {}
344  if not isinstance(e, m2_expr.ExprMem):
345  raise ValueError('mem overlap bad arg')
346  ov = []
347  # suppose max mem size is 64 bytes, compute all reachable addresses
348  to_test = []
349  base_ptr = self.expr_simp(e.arg)
350  for i in xrange(-7, e.size / 8):
351  ex = self.expr_simp(
352  self.eval_expr(base_ptr + m2_expr.ExprInt_from(e.arg, i),
353  eval_cache))
354  to_test.append((i, ex))
355 
356  for i, x in to_test:
357  if not x in self.symbols.symbols_mem:
358  continue
359  ex = self.expr_simp(self.eval_expr(e.arg - x, eval_cache))
360  if not isinstance(ex, m2_expr.ExprInt):
361  raise ValueError('ex is not ExprInt')
362  ptr_diff = int32(ex.arg)
363  if ptr_diff >= self.symbols.symbols_mem[x][1].size / 8:
364  # print "too long!"
365  continue
366  ov.append((i, self.symbols.symbols_mem[x][0]))
367  return ov
368 
369  def eval_ir_expr(self, exprs):
370  pool_out = {}
371 
372  eval_cache = dict(self.symbols.items())
373 
374  for e in exprs:
375  if not isinstance(e, m2_expr.ExprAff):
376  raise TypeError('not affect', str(e))
377 
378  src = self.eval_expr(e.src, eval_cache)
379  if isinstance(e.dst, m2_expr.ExprMem):
380  a = self.eval_expr(e.dst.arg, eval_cache)
381  a = self.expr_simp(a)
382  # search already present mem
383  tmp = None
384  # test if mem lookup is known
385  tmp = m2_expr.ExprMem(a, e.dst.size)
386  dst = tmp
387  if self.func_write and isinstance(dst.arg, m2_expr.ExprInt):
388  self.func_write(self, dst, src, pool_out)
389  else:
390  pool_out[dst] = src
391 
392  elif isinstance(e.dst, m2_expr.ExprId):
393  pool_out[e.dst] = src
394  else:
395  raise ValueError("affected zarb", str(e.dst))
396 
397  return pool_out.items()
398 
399  def eval_ir(self, ir):
400  mem_dst = []
401  # src_dst = [(x.src, x.dst) for x in ir]
402  src_dst = self.eval_ir_expr(ir)
403  eval_cache = dict(self.symbols.items())
404  for dst, src in src_dst:
405  if isinstance(dst, m2_expr.ExprMem):
406  mem_overlap = self.get_mem_overlapping(dst, eval_cache)
407  for _, base in mem_overlap:
408  diff_mem = self.substract_mems(base, dst)
409  del self.symbols[base]
410  for new_mem, new_val in diff_mem:
411  new_val.is_term = True
412  self.symbols[new_mem] = new_val
413  src_o = self.expr_simp(src)
414  # print 'SRCo', src_o
415  # src_o.is_term = True
416  self.symbols[dst] = src_o
417  if isinstance(dst, m2_expr.ExprMem):
418  mem_dst.append(dst)
419  return mem_dst
420 
421  def emulbloc(self, bloc_ir, step=False):
422  for ir in bloc_ir.irs:
423  self.eval_ir(ir)
424  if step:
425  print '_' * 80
426  self.dump_id()
427  eval_cache = dict(self.symbols.items())
428  return self.eval_expr(self.ir_arch.IRDst, eval_cache)
429 
430  def emul_ir_bloc(self, myir, ad, step=False):
431  b = myir.get_bloc(ad)
432  if b is not None:
433  ad = self.emulbloc(b, step=step)
434  return ad
435 
436  def emul_ir_blocs(self, myir, ad, lbl_stop=None, step=False):
437  while True:
438  b = myir.get_bloc(ad)
439  if b is None:
440  break
441  if b.label == lbl_stop:
442  break
443  ad = self.emulbloc(b, step=step)
444  return ad
445 
446  def del_mem_above_stack(self, sp):
447  sp_val = self.symbols[sp]
448  for mem_ad, (mem, _) in self.symbols.symbols_mem.items():
449  # print mem_ad, sp_val
450  diff = self.eval_expr(mem_ad - sp_val, {})
451  diff = expr_simp(diff)
452  if not isinstance(diff, m2_expr.ExprInt):
453  continue
454  m = expr_simp(diff.msb())
455  if m.arg == 1:
456  del self.symbols[mem]
457