miasm2.expression.simplifications_common module
# ----------------------------- #
# Common simplifications passes #
# ----------------------------- #

from miasm2.expression.modint import mod_size2int, mod_size2uint
from miasm2.expression.expression import ExprInt, ExprSlice, ExprMem, ExprCond, ExprOp, ExprCompose
from miasm2.expression.expression_helper import parity, op_propag_cst, merge_sliceto_slice


def simp_cst_propagation(e_s, expr):
    """This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: simplification engine, called on sub-expressions built here
    @expr: ExprOp to simplify
    Return the simplified expression (a new ExprOp built from the remaining
    arguments when no rule returns earlier).
    """

    # merge associatif op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)

    # A op 0 => A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()

    # A - 0 => A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 => A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod', 'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Pairwise cancellation / deduplication of arguments. An element removed
    # at index j is not skipped: j is only advanced when nothing was deleted.
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op
        else:
            # Do not consider this case, too tricky (overflow on addition /
            # substraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # The two shifts must go in opposite directions: compare the operator
    # names (comparing op_name with the expression itself is always unequal).
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) with mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                continue
            if stop > final_size:
                tmp = tmp[:tmp.size - (stop - final_size)]
                stop = final_size
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                continue
            if start < 0:
                tmp = tmp[-start:]
                start = 0
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            bound = list(bounds)[0]
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)


def simp_cond_op_int(e_s, expr):
    """Extract conditions from operations.

    Applies when exactly one distinct ExprCond appears among the arguments
    (possibly several times):
    x?a:b + e => x?(a+e):(b+e)
    """
    if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
        return expr
    if len(expr.args) < 2:
        return expr
    conds = set()
    for arg in expr.args:
        if arg.is_cond():
            conds.add(arg)
    if len(conds) != 1:
        return expr
    cond = list(conds).pop()

    args1, args2 = [], []
    for arg in expr.args:
        if arg.is_cond():
            args1.append(arg.src1)
            args2.append(arg.src2)
        else:
            args1.append(arg)
            args2.append(arg)

    return ExprCond(cond.cond,
                    ExprOp(expr.op, *args1),
                    ExprOp(expr.op, *args2))


def simp_cond_factor(e_s, expr):
    "Merge similar conditions"
    if not expr.op in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
        return expr
    if len(expr.args) < 2:
        return expr

    if expr.op in ['>>', '<<', 'a>>']:
        assert len(expr.args) == 2

    # Note: the following code is correct for non-commutative operation only if
    # there is 2 arguments. Otherwise, the order is not conserved

    # Regroup sub-expression by similar conditions
    conds = {}
    not_conds = []
    multi_cond = False
    for arg in expr.args:
        if not arg.is_cond():
            not_conds.append(arg)
            continue
        cond = arg.cond
        if not cond in conds:
            conds[cond] = []
        else:
            multi_cond = True
        conds[cond].append(arg)
    if not multi_cond:
        return expr

    # Rebuild the new expression
    c_out = not_conds
    for cond, vals in conds.items():
        new_src1 = [x.src1 for x in vals]
        new_src2 = [x.src2 for x in vals]
        src1 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src1))
        src2 = e_s.expr_simp_wrapper(ExprOp(expr.op, *new_src2))
        c_out.append(ExprCond(cond, src1, src2))

    if len(c_out) == 1:
        new_e = c_out[0]
    else:
        new_e = ExprOp(expr.op, *c_out)
    return new_e


def simp_slice(e_s, expr):
    "Slice optimization"

    # slice(A, 0, a.size) => A
    if expr.start == 0 and expr.stop == expr.arg.size:
        return expr.arg
    # Slice(int) => int
    if expr.arg.is_int():
        total_bit = expr.stop - expr.start
        mask = (1 << (expr.stop - expr.start)) - 1
        return ExprInt(int((expr.arg.arg >> expr.start) & mask), total_bit)
    # Slice(Slice(A, x), y) => Slice(A, z)
    if expr.arg.is_slice():
        if expr.stop - expr.start > expr.arg.stop - expr.arg.start:
            raise ValueError('slice in slice: getting more val', str(expr))

        return ExprSlice(expr.arg.arg, expr.start + expr.arg.start,
                         expr.start + expr.arg.start + (expr.stop - expr.start))
    if expr.arg.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        for index, arg in expr.arg.iter_args():
            if index <= expr.start and index + arg.size >= expr.stop:
                return arg[expr.start - index:expr.stop - index]
        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        out = []
        for index, arg in expr.arg.iter_args():
            # arg is before slice start
            if expr.start >= index + arg.size:
                continue
            # arg is after slice stop
            elif expr.stop <= index:
                continue
            # arg is fully included in slice
            elif expr.start <= index and index + arg.size <= expr.stop:
                out.append(arg)
                continue
            # arg is truncated at start
            if expr.start > index:
                slice_start = expr.start - index
            else:
                # arg is not truncated at start
                slice_start = 0
            # a is truncated at stop
            if expr.stop < index + arg.size:
                slice_stop = arg.size + expr.stop - (index + arg.size) - slice_start
            else:
                slice_stop = arg.size
            out.append(arg[slice_start:slice_stop])

        return ExprCompose(*out)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # XXXX todo hum, is it safe?
    if (expr.arg.is_mem() and
            expr.start == 0 and
            expr.arg.size > expr.stop and expr.stop % 8 == 0):
        return ExprMem(expr.arg.arg, size=expr.stop)
    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if expr.arg.is_op("&") and expr.arg.args[-1].is_int():
        tmp = e_s.expr_simp_wrapper(expr.arg.args[-1][expr.start:expr.stop])
        if tmp.is_int(0):
            return tmp
    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y])
    if (expr.arg.is_cond() and
            (expr.arg.src1.is_int() or expr.arg.src1.is_compose()) and
            (expr.arg.src2.is_int() or expr.arg.src2.is_compose())):
        src1 = expr.arg.src1[expr.start:expr.stop]
        src2 = expr.arg.src2[expr.start:expr.stop]
        return ExprCond(expr.arg.cond, src1, src2)

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    if expr.start == 0 and expr.arg.is_op("*") and expr.arg.args[-1].is_int():
        args = [e_s.expr_simp_wrapper(a[expr.start:expr.stop]) for a in expr.arg.args]
        return ExprOp(expr.arg.op, *args)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (expr.arg.is_op() and expr.arg.op in [">>", "<<"] and
            expr.arg.args[1].is_int()):
        arg, shift = expr.arg.args
        shift = int(shift)
        if expr.arg.op == ">>":
            if shift + expr.stop <= arg.size:
                return arg[expr.start + shift:expr.stop + shift]
        elif expr.arg.op == "<<":
            if expr.start - shift >= 0:
                return arg[expr.start - shift:expr.stop - shift]
        else:
            raise ValueError('Bad case')

    return expr


def simp_compose(e_s, expr):
    "Common simplifications on ExprCompose"
    args = merge_sliceto_slice(expr)
    out = []
    # compose of compose
    for arg in args:
        if arg.is_compose():
            out += arg.args
        else:
            out.append(arg)
    args = out
    # Compose(a) with a.size = compose.size => a
    if len(args) == 1 and args[0].size == expr.size:
        return args[0]

    # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z)
    if len(args) == 2 and args[1].is_int(0):
        if (args[0].is_slice() and
                args[0].stop == args[0].arg.size and
                args[0].size + args[1].size == args[0].arg.size):
            new_expr = args[0].arg >> ExprInt(args[0].start, args[0].arg.size)
            return new_expr

    # {@X[base + i] 0 X, @Y[base + i + X] X (X + Y)} => @(X+Y)[base + i]
    for i, arg in enumerate(args[:-1]):
        nxt = args[i + 1]
        if arg.is_mem() and nxt.is_mem():
            gap = e_s(nxt.arg - arg.arg)
            if gap.is_int() and arg.size % 8 == 0 and int(gap) == arg.size / 8:
                args = args[:i] + [ExprMem(arg.arg,
                                           arg.size + nxt.size)] + args[i + 2:]
                return ExprCompose(*args)

    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
    conds = set(arg.cond for arg in expr.args if arg.is_cond())
    if len(conds) == 1:
        cond = list(conds)[0]
        args1, args2 = [], []
        for arg in expr.args:
            if arg.is_cond():
                args1.append(arg.src1)
                args2.append(arg.src2)
            else:
                args1.append(arg)
                args2.append(arg)
        arg1 = e_s(ExprCompose(*args1))
        arg2 = e_s(ExprCompose(*args2))
        return ExprCond(cond, arg1, arg2)
    return ExprCompose(*args)


def simp_cond(e_s, expr):
    """Common simplifications on ExprCond.

    Rules are tried in order; the first applicable one gives the result and
    @expr is returned unchanged when none applies.

    @e_s: simplification engine (not used by this pass)
    @expr: ExprCond to simplify
    """
    cond = expr.cond

    # eval exprcond src1/src2 with satifiable/unsatisfiable condition
    # propagation
    if (not cond.is_int()) and cond.size == 1:
        src1 = expr.src1.replace_expr({cond: ExprInt(1, 1)})
        src2 = expr.src2.replace_expr({cond: ExprInt(0, 1)})
        if src1 != expr.src1 or src2 != expr.src2:
            return ExprCond(cond, src1, src2)

    # -A ? B:C => A ? B:C
    if cond.is_op('-') and len(cond.args) == 1:
        return ExprCond(cond.args[0], expr.src1, expr.src2)
    # a?x:x
    if expr.src1 == expr.src2:
        return expr.src1
    # int ? A:B => A or B
    if cond.is_int():
        if cond.arg == 0:
            return expr.src2
        return expr.src1
    # a?(a?b:c):x => a?b:x
    if expr.src1.is_cond() and cond == expr.src1.cond:
        return ExprCond(cond, expr.src1.src1, expr.src2)
    # a?x:(a?b:c) => a?x:c
    if expr.src2.is_cond() and cond == expr.src2.cond:
        return ExprCond(cond, expr.src1, expr.src2.src2)
    # a|...|int ? b:c => b with int != 0
    # The constant is looked up in last position, like the other passes of
    # this module do, so '|' operations with more than two arguments match.
    if (cond.is_op('|') and
            cond.args[-1].is_int() and
            cond.args[-1].arg != 0):
        return expr.src1
    # (C?int1:int2)?(A:B) =>
    if (cond.is_cond() and
            cond.src1.is_int() and
            cond.src2.is_int()):
        int1 = cond.src1.arg.arg
        int2 = cond.src2.arg.arg
        if int1 and int2:
            return expr.src1
        if int1 == 0 and int2 == 0:
            return expr.src2
        if int1 == 0 and int2:
            return ExprCond(cond.cond, expr.src2, expr.src1)
        if int1 and int2 == 0:
            return ExprCond(cond.cond, expr.src1, expr.src2)
        return expr
    if cond.is_compose():
        # {0, X, 0}?(A:B) => X?(A:B)
        args = [arg for arg in cond.args if not arg.is_int(0)]
        if len(args) == 1:
            return ExprCond(args.pop(), expr.src1, expr.src2)
        elif len(args) < len(cond.args):
            return ExprCond(ExprCompose(*args), expr.src1, expr.src2)
    return expr


def simp_mem(e_s, expr):
    "Common simplifications on ExprMem"

    # @32[x?a:b] => x?@32[a]:@32[b]
    if expr.arg.is_cond():
        cond = expr.arg
        ret = ExprCond(cond.cond,
                       ExprMem(cond.src1, expr.size),
                       ExprMem(cond.src2, expr.size))
        return ret
    return expr


def test_cc_eq_args(expr, *sons_op):
    """Return True if @expr is an operation whose i-th argument is the
    operation named sons_op[i], all these arguments having the same operands.
    """
    if not expr.is_op():
        return False
    if len(expr.args) != len(sons_op):
        return False
    all_args = set()
    for i, arg in enumerate(expr.args):
        if not arg.is_op(sons_op[i]):
            return False
        all_args.add(arg.args)
    return len(all_args) == 1


def simp_cc_conds(expr_simp, expr):
    """Rewrite a "CC_*" operation applied to "FLAG_*" operations as a plain
    comparison of the flags' operands.

    A negated comparison is expressed as (comparison ? 0 : 1). Expressions
    matching no known pattern are returned unchanged.

    @expr_simp: simplification engine (not used by this pass)
    @expr: expression to rewrite
    """

    def flags_are(*flag_ops):
        return test_cc_eq_args(expr, *flag_ops)

    def compare(cmp_op):
        # Comparison of the operands of the first flag
        return ExprOp(cmp_op, *expr.args[0].args)

    def negate(test):
        return ExprCond(test, ExprInt(0, 1), ExprInt(1, 1))

    def is_null():
        operand = expr.args[0].args[0]
        return ExprOp("==", operand, ExprInt(0, operand.size))

    def sign_and_eq_with_null_overflow():
        # (FLAG_SIGN_SUB(X, Y), 0, FLAG_EQ_CMP(X, Y))
        return (len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    if expr.is_op("CC_U>=") and flags_are("FLAG_SUB_CF"):
        return negate(compare("<u"))
    if expr.is_op("CC_U<") and flags_are("FLAG_SUB_CF"):
        return compare("<u")
    if expr.is_op("CC_NEG") and flags_are("FLAG_SIGN_SUB"):
        return compare("<s")
    if expr.is_op("CC_POS") and flags_are("FLAG_SIGN_SUB"):
        return negate(compare("<s"))
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ"):
        return is_null()
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ"):
        return negate(is_null())
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ_CMP"):
        return negate(compare("=="))
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ_CMP"):
        return compare("==")
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ_AND"):
        return compare("&")
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ_AND"):
        return negate(compare("&"))
    if expr.is_op("CC_S>") and (
            flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP") or
            sign_and_eq_with_null_overflow()):
        return negate(compare("<=s"))
    if expr.is_op("CC_S>=") and flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return negate(compare("<s"))
    if expr.is_op("CC_S<") and flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return compare("<s")
    if expr.is_op("CC_S<=") and (
            flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP") or
            sign_and_eq_with_null_overflow()):
        return compare("<=s")
    if expr.is_op("CC_U<=") and flags_are("FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return compare("<=u")
    if expr.is_op("CC_U>") and flags_are("FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return negate(compare("<=u"))
    if expr.is_op("CC_S<") and flags_are("FLAG_SIGN_ADD", "FLAG_ADD_OF"):
        arg0, arg1 = expr.args[0].args
        return ExprOp("<s", arg0, -arg1)
    return expr


def simp_cond_flag(expr_simp, expr):
    """FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B"""
    cond = expr.cond
    if cond.is_op("FLAG_EQ_CMP"):
        return ExprCond(ExprOp("==", *cond.args), expr.src1, expr.src2)
    return expr


def simp_cond_int(expr_simp, expr):
    """Simplify equalities against an integer in the condition of an ExprCond.

    ({X, 0} == int) ? A:B  =>  (X == int[:X.size]) ? A:B, or B when int does
                               not fit in X.size bits
    (X + int1 == int2) ? A:B  =>  (X == int2 - int1) ? A:B

    @expr_simp: simplification engine, applied to the rewritten expression
    @expr: ExprCond to simplify
    """
    cond = expr.cond
    if (cond.is_op('==') and
            cond.args[1].is_int() and
            cond.args[0].is_compose() and
            len(cond.args[0].args) == 2 and
            cond.args[0].args[1].is_int(0)):
        # ({X, 0} == int) => X == int[:]
        src = cond.args[0].args[0]
        int_val = int(cond.args[1])
        if int_val >> src.size:
            # The upper part of the compose is null while the constant has
            # bits set there: the equality never holds. Truncating the
            # constant here would wrongly make it satisfiable.
            return expr.src2
        new_int = ExprInt(int_val, src.size)
        return expr_simp(
            ExprCond(ExprOp("==", src, new_int), expr.src1, expr.src2))

    # Only '==' is handled: with wrap-around arithmetic, moving the constant
    # across '<u', '<=u', '<s' or '<=s' does not give an equivalent test
    # (on 8 bits, 0xFF + 1 <u 5 holds whereas 0xFF <u 4 does not).
    if (cond.is_op('==') and
            cond.args[1].is_int() and
            cond.args[0].is_op("+") and
            cond.args[0].args[-1].is_int()):
        # X + int1 == int2 => X == int2-int1
        left, right = cond.args
        left, int_diff = left.args[:-1], left.args[-1]
        if len(left) == 1:
            left = left[0]
        else:
            left = ExprOp('+', *left)
        new_int = expr_simp(right - int_diff)
        return expr_simp(
            ExprCond(ExprOp('==', left, new_int), expr.src1, expr.src2))
    return expr


def simp_cmp_int_arg(expr_simp, expr):
    """
    (0x10 <= R0) ? A:B
    =>
    (R0 < 0x10) ? B:A
    """
    cond = expr.cond
    if not cond.is_op():
        return expr
    op = cond.op
    if op not in ['==', '<s', '<=s', '<u', '<=u']:
        return expr
    arg1, arg2 = cond.args
    if arg2.is_int():
        return expr
    if not arg1.is_int():
        return expr
    src1, src2 = expr.src1, expr.src2
    if op == "==":
        return ExprCond(ExprOp('==', arg2, arg1), src1, src2)

    arg1, arg2 = arg2, arg1
    src1, src2 = src2, src1
    if op == '<s':
        op = '<=s'
    elif op == '<=s':
        op = '<s'
    elif op == '<u':
        op = '<=u'
    elif op == '<=u':
        op = '<u'
    return ExprCond(ExprOp(op, arg1, arg2), src1, src2)


def simp_subwc_cf(expr_s, expr):
    """SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})"""
    if not expr.is_op('FLAG_SUBWC_CF'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    op1 = ExprCompose(expr.args[0], op3.args[0])
    op2 = ExprCompose(expr.args[1], op3.args[1])

    return ExprOp("FLAG_SUB_CF", op1, op2)


def simp_subwc_of(expr_s, expr):
    """SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D})"""
    if not expr.is_op('FLAG_SUBWC_OF'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    op1 = ExprCompose(expr.args[0], op3.args[0])
    op2 = ExprCompose(expr.args[1], op3.args[1])

    return ExprOp("FLAG_SUB_OF", op1, op2)


def simp_sign_subwc_cf(expr_s, expr):
    """SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D})"""
    if not expr.is_op('FLAG_SIGN_SUBWC'):
        return expr
    op3 = expr.args[2]
    if not op3.is_op("FLAG_SUB_CF"):
        return expr

    op1 = ExprCompose(expr.args[0], op3.args[0])
    op2 = ExprCompose(expr.args[1], op3.args[1])

    return ExprOp("FLAG_SIGN_SUB", op1, op2)


def simp_zeroext_eq_cst(expr_s, expr):
    """A.zeroExt(X) == int => A == int[:A.size]

    When int cannot be represented on A.size bits, the equality is always
    false and ExprInt(0, 1) is returned.
    """
    if not expr.is_op("=="):
        return expr
    arg1, arg2 = expr.args
    if not arg2.is_int():
        return expr
    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
        return expr
    src = arg1.args[0]
    # Values representable on src.size bits are 0 .. (1 << src.size) - 1
    if int(arg2) >= (1 << src.size):
        # Always false
        return ExprInt(0, 1)
    return ExprOp("==", src, ExprInt(int(arg2), src.size))
Module variables
var mod_size2int
var mod_size2uint
var op_propag_cst
Functions
def simp_cc_conds(
expr_simp, expr)
def simp_cc_conds(expr_simp, expr):
    """Rewrite a "CC_*" operation applied to "FLAG_*" operations as a plain
    comparison of the flags' operands.

    A negated comparison is expressed as (comparison ? 0 : 1). Expressions
    matching no known pattern are returned unchanged.

    @expr_simp: simplification engine (not used by this pass)
    @expr: expression to rewrite
    """

    def flags_are(*flag_ops):
        return test_cc_eq_args(expr, *flag_ops)

    def compare(cmp_op):
        # Comparison of the operands of the first flag
        return ExprOp(cmp_op, *expr.args[0].args)

    def negate(test):
        return ExprCond(test, ExprInt(0, 1), ExprInt(1, 1))

    def is_null():
        operand = expr.args[0].args[0]
        return ExprOp("==", operand, ExprInt(0, operand.size))

    def sign_and_eq_with_null_overflow():
        # (FLAG_SIGN_SUB(X, Y), 0, FLAG_EQ_CMP(X, Y))
        return (len(expr.args) == 3 and
                expr.args[0].is_op("FLAG_SIGN_SUB") and
                expr.args[2].is_op("FLAG_EQ_CMP") and
                expr.args[0].args == expr.args[2].args and
                expr.args[1].is_int(0))

    # Unsigned comparisons from the carry flag of a subtraction
    if expr.is_op("CC_U>=") and flags_are("FLAG_SUB_CF"):
        return negate(compare("<u"))
    if expr.is_op("CC_U<") and flags_are("FLAG_SUB_CF"):
        return compare("<u")

    # Sign of a subtraction
    if expr.is_op("CC_NEG") and flags_are("FLAG_SIGN_SUB"):
        return compare("<s")
    if expr.is_op("CC_POS") and flags_are("FLAG_SIGN_SUB"):
        return negate(compare("<s"))

    # Equality tests
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ"):
        return is_null()
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ"):
        return negate(is_null())
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ_CMP"):
        return negate(compare("=="))
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ_CMP"):
        return compare("==")
    if expr.is_op("CC_NE") and flags_are("FLAG_EQ_AND"):
        return compare("&")
    if expr.is_op("CC_EQ") and flags_are("FLAG_EQ_AND"):
        return negate(compare("&"))

    # Signed comparisons from sign / overflow / equality flags
    if expr.is_op("CC_S>") and (
            flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP") or
            sign_and_eq_with_null_overflow()):
        return negate(compare("<=s"))
    if expr.is_op("CC_S>=") and flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return negate(compare("<s"))
    if expr.is_op("CC_S<") and flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF"):
        return compare("<s")
    if expr.is_op("CC_S<=") and (
            flags_are("FLAG_SIGN_SUB", "FLAG_SUB_OF", "FLAG_EQ_CMP") or
            sign_and_eq_with_null_overflow()):
        return compare("<=s")

    # Unsigned comparisons from carry and equality flags
    if expr.is_op("CC_U<=") and flags_are("FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return compare("<=u")
    if expr.is_op("CC_U>") and flags_are("FLAG_SUB_CF", "FLAG_EQ_CMP"):
        return negate(compare("<=u"))

    # Signed comparison from the flags of an addition
    if expr.is_op("CC_S<") and flags_are("FLAG_SIGN_ADD", "FLAG_ADD_OF"):
        arg0, arg1 = expr.args[0].args
        return ExprOp("<s", arg0, -arg1)

    return expr
def simp_cmp_int_arg(
expr_simp, expr)
(0x10 <= R0) ? A:B => (R0 < 0x10) ? B:A
def simp_cmp_int_arg(expr_simp, expr):
    """Put the integer operand of a comparison on the right-hand side.

    (0x10 <= R0) ? A:B
    =>
    (R0 < 0x10) ? B:A

    An equality only has its operands swapped. An inequality is replaced by
    its complementary test on swapped operands, the sources of the ExprCond
    being exchanged accordingly. @expr is returned unchanged when the
    condition is not a supported comparison, or when the integer is not on
    the left only.

    @expr_simp: simplification engine (not used by this pass)
    @expr: ExprCond to rewrite
    """
    cond = expr.cond
    if not cond.is_op():
        return expr

    cmp_op = cond.op
    if cmp_op not in ['==', '<s', '<=s', '<u', '<=u']:
        return expr

    left, right = cond.args
    if right.is_int() or not left.is_int():
        return expr

    if cmp_op == "==":
        return ExprCond(ExprOp('==', right, left), expr.src1, expr.src2)

    # (int < X) is the negation of (X <= int), and so on
    complementary = {'<s': '<=s', '<=s': '<s', '<u': '<=u', '<=u': '<u'}
    swapped_test = ExprOp(complementary[cmp_op], right, left)
    return ExprCond(swapped_test, expr.src2, expr.src1)
def simp_compose(
e_s, expr)
Common simplifications on ExprCompose
def simp_compose(e_s, expr):
    """Common simplifications on ExprCompose.

    @e_s: simplification engine, called on the sub-expressions built here
    @expr: ExprCompose to simplify
    """
    # Merge adjacent slices, then flatten compose of compose
    parts = []
    for part in merge_sliceto_slice(expr):
        if part.is_compose():
            parts.extend(part.args)
        else:
            parts.append(part)

    # Compose(a) with a.size = compose.size => a
    if len(parts) == 1 and parts[0].size == expr.size:
        return parts[0]

    # {(X[z:], 0, X.size-z), (0, X.size-z, X.size)} => (X >> z)
    if len(parts) == 2:
        low, high = parts
        if (high.is_int(0) and
                low.is_slice() and
                low.stop == low.arg.size and
                low.size + high.size == low.arg.size):
            return low.arg >> ExprInt(low.start, low.arg.size)

    # {@X[base + i] 0 X, @Y[base + i + X] X (X + Y)} => @(X+Y)[base + i]
    # Only the first pair of adjacent memory accesses is merged per call.
    for pos in range(len(parts) - 1):
        current, following = parts[pos], parts[pos + 1]
        if not (current.is_mem() and following.is_mem()):
            continue
        gap = e_s(following.arg - current.arg)
        if gap.is_int() and current.size % 8 == 0 and int(gap) == current.size / 8:
            merged = ExprMem(current.arg, current.size + following.size)
            return ExprCompose(*(parts[:pos] + [merged] + parts[pos + 2:]))

    # {a, x?b:d, x?c:e, f} => x?{a, b, c, f}:{a, d, e, f}
    # This rule works on the original arguments of @expr.
    conditions = set(part.cond for part in expr.args if part.is_cond())
    if len(conditions) == 1:
        shared_cond = list(conditions)[0]
        when_true = [part.src1 if part.is_cond() else part for part in expr.args]
        when_false = [part.src2 if part.is_cond() else part for part in expr.args]
        return ExprCond(shared_cond,
                        e_s(ExprCompose(*when_true)),
                        e_s(ExprCompose(*when_false)))

    return ExprCompose(*parts)
def simp_cond(
e_s, expr)
Common simplifications on ExprCond
def simp_cond(e_s, expr):
    """Common simplifications on ExprCond.

    Rules are tried in order; the first applicable one gives the result and
    @expr is returned unchanged when none applies.

    @e_s: simplification engine (not used by this pass)
    @expr: ExprCond to simplify
    """
    cond = expr.cond

    # eval exprcond src1/src2 with satifiable/unsatisfiable condition
    # propagation
    if (not cond.is_int()) and cond.size == 1:
        src1 = expr.src1.replace_expr({cond: ExprInt(1, 1)})
        src2 = expr.src2.replace_expr({cond: ExprInt(0, 1)})
        if src1 != expr.src1 or src2 != expr.src2:
            return ExprCond(cond, src1, src2)

    # -A ? B:C => A ? B:C
    if cond.is_op('-') and len(cond.args) == 1:
        return ExprCond(cond.args[0], expr.src1, expr.src2)
    # a?x:x
    if expr.src1 == expr.src2:
        return expr.src1
    # int ? A:B => A or B
    if cond.is_int():
        if cond.arg == 0:
            return expr.src2
        return expr.src1
    # a?(a?b:c):x => a?b:x
    if expr.src1.is_cond() and cond == expr.src1.cond:
        return ExprCond(cond, expr.src1.src1, expr.src2)
    # a?x:(a?b:c) => a?x:c
    if expr.src2.is_cond() and cond == expr.src2.cond:
        return ExprCond(cond, expr.src1, expr.src2.src2)
    # a|...|int ? b:c => b with int != 0
    # The constant is looked up in last position, like the other passes of
    # this module do, so '|' operations with more than two arguments match.
    if (cond.is_op('|') and
            cond.args[-1].is_int() and
            cond.args[-1].arg != 0):
        return expr.src1
    # (C?int1:int2)?(A:B) =>
    if (cond.is_cond() and
            cond.src1.is_int() and
            cond.src2.is_int()):
        int1 = cond.src1.arg.arg
        int2 = cond.src2.arg.arg
        if int1 and int2:
            return expr.src1
        if int1 == 0 and int2 == 0:
            return expr.src2
        if int1 == 0 and int2:
            return ExprCond(cond.cond, expr.src2, expr.src1)
        if int1 and int2 == 0:
            return ExprCond(cond.cond, expr.src1, expr.src2)
        return expr
    if cond.is_compose():
        # {0, X, 0}?(A:B) => X?(A:B)
        args = [arg for arg in cond.args if not arg.is_int(0)]
        if len(args) == 1:
            return ExprCond(args.pop(), expr.src1, expr.src2)
        elif len(args) < len(cond.args):
            return ExprCond(ExprCompose(*args), expr.src1, expr.src2)
    return expr
def simp_cond_factor(
e_s, expr)
Merge similar conditions
def simp_cond_factor(e_s, expr):
    """Merge similar conditions.

    Arguments of @expr which are ExprCond sharing the same condition are
    grouped into a single ExprCond whose sources are the operation applied
    to the grouped sources. @expr is returned unchanged when its operator is
    not supported, when it has fewer than two arguments or when no condition
    appears at least twice.

    @e_s: simplification engine; its expr_simp_wrapper is applied to the
          sources built here
    @expr: ExprOp to simplify
    """
    if expr.op not in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
        return expr
    if len(expr.args) < 2:
        return expr

    # Note: the following code is correct for non-commutative operation only
    # if there is 2 arguments. Otherwise, the order is not conserved
    if expr.op in ['>>', '<<', 'a>>']:
        assert len(expr.args) == 2

    # Regroup sub-expression by similar conditions
    groups = {}
    unconditional = []
    for arg in expr.args:
        if arg.is_cond():
            groups.setdefault(arg.cond, []).append(arg)
        else:
            unconditional.append(arg)

    if not any(len(members) > 1 for members in groups.values()):
        return expr

    # Rebuild the new expression
    rebuilt = unconditional
    for cond, members in groups.items():
        true_sources = [member.src1 for member in members]
        false_sources = [member.src2 for member in members]
        when_true = e_s.expr_simp_wrapper(ExprOp(expr.op, *true_sources))
        when_false = e_s.expr_simp_wrapper(ExprOp(expr.op, *false_sources))
        rebuilt.append(ExprCond(cond, when_true, when_false))

    if len(rebuilt) == 1:
        return rebuilt[0]
    return ExprOp(expr.op, *rebuilt)
def simp_cond_flag(
expr_simp, expr)
def simp_cond_flag(expr_simp, expr):
    """Replace a FLAG_EQ_CMP condition by an explicit equality

    FLAG_EQ_CMP(X, Y)?A:B => (X == Y)?A:B
    Any other condition leaves @expr unchanged.
    """
    flag = expr.cond
    if not flag.is_op("FLAG_EQ_CMP"):
        return expr
    equality = ExprOp("==", *flag.args)
    return ExprCond(equality, expr.src1, expr.src2)
def simp_cond_int(
expr_simp, expr)
def simp_cond_int(expr_simp, expr):
    """Simplify an ExprCond whose condition compares against a constant

     - ({X, 0} == int) ? A:B   => (X == int[:X.size]) ? A:B
                                  or B if int does not fit in X.size bits
     - (X + int1 == int2) ? A:B => (X == int2 - int1) ? A:B

    @expr_simp: callable used to simplify the rebuilt expressions
    @expr: the ExprCond to simplify
    Returns the simplified expression, or @expr when no rule applies.
    """
    if (expr.cond.is_op('==') and
          expr.cond.args[1].is_int() and
          expr.cond.args[0].is_compose() and
          len(expr.cond.args[0].args) == 2 and
          expr.cond.args[0].args[1].is_int(0)):
        # ({X, 0} == int) => X == int[:]
        src = expr.cond.args[0].args[0]
        int_val = int(expr.cond.args[1])
        if int_val >> src.size:
            # The constant has bits set above X while the upper part of the
            # compose is null: the equality never holds. Truncating the
            # constant here would turn a false test into a satisfiable one.
            return expr.src2
        new_int = ExprInt(int_val, src.size)
        expr = expr_simp(ExprCond(ExprOp("==", src, new_int),
                                  expr.src1,
                                  expr.src2))
    elif (expr.cond.is_op('==') and
          expr.cond.args[1].is_int() and
          expr.cond.args[0].is_op("+") and
          expr.cond.args[0].args[-1].is_int()):
        # X + int1 == int2 => X == int2-int1
        # Only valid for '==': moving a constant across an ordering
        # comparison ('<s', '<=s', '<u', '<=u') is wrong as soon as the
        # addition wraps around (ie. X + 1 <u 1 holds for X = -1 only,
        # whereas X <u 0 never holds).
        left, right = expr.cond.args
        left, int_diff = left.args[:-1], left.args[-1]
        if len(left) == 1:
            left = left[0]
        else:
            left = ExprOp('+', *left)
        new_int = expr_simp(right - int_diff)
        expr = expr_simp(ExprCond(ExprOp(expr.cond.op, left, new_int),
                                  expr.src1,
                                  expr.src2))
    return expr
def simp_cond_op_int(
e_s, expr)
Extract conditions from operations
def simp_cond_op_int(e_s, expr):
    """Extract conditions from operations

    x?a:b + x?a:b + e => x?(a+a+e):(b+b+e)
    The rewrite is only done when exactly one distinct ExprCond appears
    among the operands of @expr; otherwise @expr is returned unchanged.
    """
    operator = expr.op
    if operator not in ["+", "|", "^", "&", "*", '<<', '>>', 'a>>']:
        return expr
    if len(expr.args) < 2:
        return expr

    # Distinct conditional operands (whole ExprCond, not only the condition)
    distinct = set(operand for operand in expr.args if operand.is_cond())
    if len(distinct) != 1:
        return expr
    (ternary,) = distinct

    # Build the operand list of each outcome, keeping the original order
    when_true = []
    when_false = []
    for operand in expr.args:
        if operand.is_cond():
            when_true.append(operand.src1)
            when_false.append(operand.src2)
        else:
            when_true.append(operand)
            when_false.append(operand)

    return ExprCond(ternary.cond,
                    ExprOp(operator, *when_true),
                    ExprOp(operator, *when_false))
def simp_cst_propagation(
e_s, expr)
This passe includes: - Constant folding - Common logical identities - Common binary identities
def simp_cst_propagation(e_s, expr):
    """This pass includes:
     - Constant folding
     - Common logical identities
     - Common binary identities

    @e_s: simplification engine, called to simplify sub-expressions
    @expr: the ExprOp to simplify
    Returns a simplified expression; when no identity applies, an ExprOp
    rebuilt from the (possibly folded) arguments is returned.
    Raises ValueError if a '-' operation has more than 2 arguments.
    """

    # merge associatif op
    args = list(expr.args)
    op_name = expr.op
    # simpl integer manip
    # int OP int => int
    # TODO: <<< >>> << >> are architecture dependant
    if op_name in op_propag_cst:
        # Fold the trailing constants two by two
        while (len(args) >= 2 and
               args[-1].is_int() and
               args[-2].is_int()):
            int2 = args.pop()
            int1 = args.pop()
            if op_name == '+':
                out = int1.arg + int2.arg
            elif op_name == '*':
                out = int1.arg * int2.arg
            elif op_name == '**':
                out = int1.arg ** int2.arg
            elif op_name == '^':
                out = int1.arg ^ int2.arg
            elif op_name == '&':
                out = int1.arg & int2.arg
            elif op_name == '|':
                out = int1.arg | int2.arg
            elif op_name == '>>':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg >> int2.arg
            elif op_name == '<<':
                if int(int2) > int1.size:
                    out = 0
                else:
                    out = int1.arg << int2.arg
            elif op_name == 'a>>':
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                if tmp2 > int1.size:
                    # Everything is shifted out: only the sign remains
                    is_signed = int(int1) & (1 << (int1.size - 1))
                    if is_signed:
                        out = -1
                    else:
                        out = 0
                else:
                    out = mod_size2uint[int1.arg.size](tmp1 >> tmp2)
            elif op_name == '>>>':
                shifter = int2.arg % int2.size
                out = (int1.arg >> shifter) | (int1.arg << (int2.size - shifter))
            elif op_name == '<<<':
                shifter = int2.arg % int2.size
                out = (int1.arg << shifter) | (int1.arg >> (int2.size - shifter))
            elif op_name == '/':
                out = int1.arg / int2.arg
            elif op_name == '%':
                out = int1.arg % int2.arg
            elif op_name == 'idiv':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)
            elif op_name == 'imod':
                assert int2.arg.arg
                tmp1 = mod_size2int[int1.arg.size](int1.arg)
                tmp2 = mod_size2int[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'umod':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 % tmp2)
            elif op_name == 'udiv':
                assert int2.arg.arg
                tmp1 = mod_size2uint[int1.arg.size](int1.arg)
                tmp2 = mod_size2uint[int2.arg.size](int2.arg)
                out = mod_size2uint[int1.arg.size](tmp1 / tmp2)

            args.append(ExprInt(out, int1.size))

    # cnttrailzeros(int) => int
    if op_name == "cnttrailzeros" and args[0].is_int():
        i = 0
        while args[0].arg & (1 << i) == 0 and i < args[0].size:
            i += 1
        return ExprInt(i, args[0].size)

    # cntleadzeros(int) => int
    if op_name == "cntleadzeros" and args[0].is_int():
        if args[0].arg == 0:
            return ExprInt(args[0].size, args[0].size)
        i = args[0].size - 1
        while args[0].arg & (1 << i) == 0:
            i -= 1
        return ExprInt(expr.size - (i + 1), args[0].size)

    # -(-(A)) => A
    if (op_name == '-' and len(args) == 1 and args[0].is_op('-') and
            len(args[0].args) == 1):
        return args[0].args[0]

    # -(int) => -int
    if op_name == '-' and len(args) == 1 and args[0].is_int():
        return ExprInt(-int(args[0]), expr.size)

    # A op 0 =>A
    if op_name in ['+', '|', "^", "<<", ">>", "<<<", ">>>"] and len(args) > 1:
        if args[-1].is_int(0):
            args.pop()

    # A - 0 =>A
    if op_name == '-' and len(args) > 1 and args[-1].is_int(0):
        assert len(args) == 2  # Op '-' with more than 2 args: SantityCheckError
        return args[0]

    # A * 1 =>A
    if op_name == "*" and len(args) > 1 and args[-1].is_int(1):
        args.pop()

    # for cannon form
    # A * -1 => - A
    if op_name == "*" and len(args) > 1 and args[-1] == args[-1].mask:
        args.pop()
        args[-1] = - args[-1]

    # op A => A
    if op_name in ['+', '*', '^', '&', '|', '>>', '<<',
                   'a>>', '<<<', '>>>', 'idiv', 'imod', 'umod',
                   'udiv'] and len(args) == 1:
        return args[0]

    # A-B => A + (-B)
    if op_name == '-' and len(args) > 1:
        if len(args) > 2:
            raise ValueError(
                'sanity check fail on expr -: should have one or 2 args ' +
                '%r %s' % (expr, expr))
        return ExprOp('+', args[0], -args[1])

    # A op 0 => 0
    if op_name in ['&', "*"] and args[-1].is_int(0):
        return ExprInt(0, expr.size)

    # - (A + B +...) => -A + -B + -C
    if op_name == '-' and len(args) == 1 and args[0].is_op('+'):
        args = [-a for a in args[0].args]
        return ExprOp('+', *args)

    # -(a?int1:int2) => (a?-int1:-int2)
    if (op_name == '-' and len(args) == 1 and
            args[0].is_cond() and
            args[0].src1.is_int() and args[0].src2.is_int()):
        int1 = args[0].src1
        int2 = args[0].src2
        int1 = ExprInt(-int1.arg, int1.size)
        int2 = ExprInt(-int2.arg, int2.size)
        return ExprCond(args[0].cond, int1, int2)

    # Compare every pair of operands; 'args' is edited in place, so 'j' is
    # only incremented when nothing has been removed at this position
    i = 0
    while i < len(args) - 1:
        j = i + 1
        while j < len(args):
            # A ^ A => 0
            if op_name == '^' and args[i] == args[j]:
                args[i] = ExprInt(0, args[i].size)
                del args[j]
                continue
            # A + (- A) => 0
            if op_name == '+' and args[j].is_op("-"):
                if len(args[j].args) == 1 and args[i] == args[j].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # (- A) + A => 0
            if op_name == '+' and args[i].is_op("-"):
                if len(args[i].args) == 1 and args[j] == args[i].args[0]:
                    args[i] = ExprInt(0, args[i].size)
                    del args[j]
                    continue
            # A | A => A
            if op_name == '|' and args[i] == args[j]:
                del args[j]
                continue
            # A & A => A
            if op_name == '&' and args[i] == args[j]:
                del args[j]
                continue
            j += 1
        i += 1

    if op_name in ['|', '&', '%', '/', '**'] and len(args) == 1:
        return args[0]

    # A <<< A.size => A
    if (op_name in ['<<<', '>>>'] and
            args[1].is_int() and
            args[1].arg == args[0].size):
        return args[0]

    # (A <<< X) <<< Y => A <<< (X+Y) (or <<< >>>) if X + Y does not overflow
    if (op_name in ['<<<', '>>>'] and
            args[0].is_op() and
            args[0].op in ['<<<', '>>>']):
        A = args[0].args[0]
        X = args[0].args[1]
        Y = args[1]
        if op_name != args[0].op and e_s(X - Y) == ExprInt(0, X.size):
            # Opposite rotations of the same amount cancel each other
            return args[0].args[0]
        elif X.is_int() and Y.is_int():
            new_X = int(X) % expr.size
            new_Y = int(Y) % expr.size
            if op_name == args[0].op:
                rot = (new_X + new_Y) % expr.size
                op = op_name
            else:
                rot = new_Y - new_X
                op = op_name
                if rot < 0:
                    # The inner rotation wins: rotate the other way
                    rot = - rot
                    op = {">>>": "<<<", "<<<": ">>>"}[op_name]
            args = [A, ExprInt(rot, expr.size)]
            op_name = op
        else:
            # Do not consider this case, too tricky (overflow on addition /
            # substraction)
            pass

    # A >> X >> Y  =>  A >> (X+Y) if X + Y does not overflow
    # To be sure, only consider the simplification when X.msb and Y.msb are 0
    if (op_name in ['<<', '>>'] and
            args[0].is_op(op_name)):
        X = args[0].args[1]
        Y = args[1]
        if (e_s(X.msb()) == ExprInt(0, 1) and
                e_s(Y.msb()) == ExprInt(0, 1)):
            args = [args[0].args[0], X + Y]

    # ((var >> int1) << int1) => var & mask
    # ((var << int1) >> int1) => var & mask
    # The two shifts must go in opposite directions: the operator names are
    # compared (comparing op_name with the expression itself is always
    # "different" and would let same-direction shifts through).
    if (op_name in ['<<', '>>'] and
            args[0].is_op() and
            args[0].op in ['<<', '>>'] and
            op_name != args[0].op):
        var = args[0].args[0]
        int1 = args[0].args[1]
        int2 = args[1]
        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
            if op_name == '>>':
                # Upper bits have been cleared
                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
            else:
                # Lower bits have been cleared
                mask = ExprInt(
                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
                    expr.size
                )
            ret = var & mask
            return ret

    # ((A & A.mask)
    if op_name == "&" and args[-1] == expr.mask:
        return ExprOp('&', *args[:-1])

    # ((A | A.mask)
    if op_name == "|" and args[-1] == expr.mask:
        return args[-1]

    # ! (!X + int) => X - int
    # TODO

    # ((A & mask) >> shift) whith mask < 2**shift => 0
    if op_name == ">>" and args[1].is_int() and args[0].is_op("&"):
        if (args[0].args[1].is_int() and
                2 ** args[1].arg > args[0].args[1].arg):
            return ExprInt(0, args[0].size)

    # parity(int) => int
    if op_name == 'parity' and args[0].is_int():
        return ExprInt(parity(int(args[0])), 1)

    # (-a) * b * (-c) * (-d) => (-a) * b * c * d
    if op_name == "*" and len(args) > 1:
        new_args = []
        counter = 0
        for arg in args:
            if arg.is_op('-') and len(arg.args) == 1:
                new_args.append(arg.args[0])
                counter += 1
            else:
                new_args.append(arg)
        if counter % 2:
            # Odd number of negated operands: a single negation remains
            return -ExprOp(op_name, *new_args)
        args = new_args

    # A << int with A ExprCompose => move index
    if (op_name == "<<" and args[0].is_compose() and
            args[1].is_int() and int(args[1]) != 0):
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index + shift, index + shift + arg.size))
        # filter out expression
        filter_args = []
        min_index = final_size
        for tmp, start, stop in new_args:
            if start >= final_size:
                # Fully shifted out
                continue
            if stop > final_size:
                # Partially shifted out: drop the upper bits
                tmp = tmp[:tmp.size - (stop - final_size)]
                stop = final_size
            filter_args.append(tmp)
            min_index = min(start, min_index)
        # create entry 0
        assert min_index != 0
        tmp = ExprInt(0, min_index)
        args = [tmp] + filter_args
        return ExprCompose(*args)

    # A >> int with A ExprCompose => move index
    if op_name == ">>" and args[0].is_compose() and args[1].is_int():
        final_size = args[0].size
        shift = int(args[1])
        new_args = []
        # shift indexes
        for index, arg in args[0].iter_args():
            new_args.append((arg, index - shift, index + arg.size - shift))
        # filter out expression
        filter_args = []
        max_index = 0
        for tmp, start, stop in new_args:
            if stop <= 0:
                # Fully shifted out
                continue
            if start < 0:
                # Partially shifted out: drop the lower bits
                tmp = tmp[-start:]
                start = 0
            filter_args.append(tmp)
            max_index = max(stop, max_index)
        # create entry 0
        tmp = ExprInt(0, final_size - max_index)
        args = filter_args + [tmp]
        return ExprCompose(*args)

    # Compose(a) OP Compose(b) with a/b same bounds => Compose(a OP b)
    if op_name in ['|', '&', '^'] and all([arg.is_compose() for arg in args]):
        bounds = set()
        for arg in args:
            bound = tuple([tmp.size for tmp in arg.args])
            bounds.add(bound)
        if len(bounds) == 1:
            # Every compose is split the same way: operate part by part
            new_args = [[tmp] for tmp in args[0].args]
            for sub_arg in args[1:]:
                for i, tmp in enumerate(sub_arg.args):
                    new_args[i].append(tmp)
            args = []
            for i, arg in enumerate(new_args):
                args.append(ExprOp(op_name, *arg))
            return ExprCompose(*args)

    return ExprOp(op_name, *args)
def simp_mem(
e_s, expr)
Common simplifications on ExprMem
def simp_mem(e_s, expr):
    """Common simplifications on ExprMem

    @32[x?a:b] => x?@32[a]:@32[b]
    A memory access whose address is not an ExprCond is returned unchanged.
    """
    address = expr.arg
    if not address.is_cond():
        return expr

    # One memory access per outcome, each with the size of the original one
    access_true = ExprMem(address.src1, expr.size)
    access_false = ExprMem(address.src2, expr.size)
    return ExprCond(address.cond, access_true, access_false)
def simp_sign_subwc_cf(
expr_s, expr)
def simp_sign_subwc_cf(expr_s, expr):
    """Merge a FLAG_SIGN_SUBWC fed by a FLAG_SUB_CF carry

    SIGN_SUBWC(A, B, SUB_CF(C, D)) => SIGN_SUB({A, C}, {B, D})
    Any other expression is returned unchanged.
    """
    if expr.is_op('FLAG_SIGN_SUBWC'):
        carry = expr.args[2]
        if carry.is_op("FLAG_SUB_CF"):
            # Compose built from the SUBWC operand, then the SUB_CF operand
            left = ExprCompose(expr.args[0], carry.args[0])
            right = ExprCompose(expr.args[1], carry.args[1])
            return ExprOp("FLAG_SIGN_SUB", left, right)
    return expr
def simp_slice(
e_s, expr)
Slice optimization
def simp_slice(e_s, expr):
    """Slice optimization

    Rules are tried in order, the first matching one returns:
     - slice(A, 0, A.size)         => A
     - Slice(int)                  => int
     - Slice(Slice(A, x), y)       => Slice(A, z)
     - Slice(Compose(...), x)      => Slice of one part, or Compose of the
                                      (possibly truncated) parts
     - ExprMem(x, size)[:a]        => ExprMem(x, a)
     - (a & int)[x:y]              => 0 if int[x:y] == 0
     - (a?b:c)[x:y]                => a?b[x:y]:c[x:y] (b, c int or compose)
     - (a * int)[0:y]              => a[0:y] * int[0:y]
     - (a >> int)[x:y], (a << int)[x:y] => slice of a when it stays in bounds

    Raises ValueError when a slice of slice is wider than the inner slice.
    """
    low, high = expr.start, expr.stop
    inner = expr.arg

    # slice(A, 0, a.size) => A
    if low == 0 and high == inner.size:
        return inner

    # Slice(int) => int
    if inner.is_int():
        width = high - low
        bits = (inner.arg >> low) & ((1 << width) - 1)
        return ExprInt(int(bits), width)

    # Slice(Slice(A, x), y) => Slice(A, z)
    if inner.is_slice():
        if high - low > inner.stop - inner.start:
            raise ValueError('slice in slice: getting more val', str(expr))
        base = low + inner.start
        return ExprSlice(inner.arg, base, base + (high - low))

    if inner.is_compose():
        # Slice(Compose(A), x) => Slice(A, y)
        for offset, part in inner.iter_args():
            if offset <= low and offset + part.size >= high:
                return part[low - offset:high - offset]

        # Slice(Compose(A, B, C), x) => Compose(A, B, C) with truncated A/B/C
        pieces = []
        for offset, part in inner.iter_args():
            part_end = offset + part.size
            # part is entirely before the slice start, or after its stop
            if low >= part_end or high <= offset:
                continue
            # part is fully included in slice
            if low <= offset and part_end <= high:
                pieces.append(part)
                continue
            # part is truncated at start (0 if it is not)
            if low > offset:
                cut_low = low - offset
            else:
                cut_low = 0
            # part is truncated at stop
            if high < part_end:
                cut_high = high - offset - cut_low
            else:
                cut_high = part.size
            pieces.append(part[cut_low:cut_high])
        return ExprCompose(*pieces)

    # ExprMem(x, size)[:A] => ExprMem(x, a)
    # XXXX todo hum, is it safe?
    if (inner.is_mem() and low == 0 and
            inner.size > high and high % 8 == 0):
        return ExprMem(inner.arg, size=high)

    # distributivity of slice and &
    # (a & int)[x:y] => 0 if int[x:y] == 0
    if inner.is_op("&") and inner.args[-1].is_int():
        sliced_cst = e_s.expr_simp_wrapper(inner.args[-1][low:high])
        if sliced_cst.is_int(0):
            return sliced_cst

    # distributivity of slice and exprcond
    # (a?int1:int2)[x:y] => (a?int1[x:y]:int2[x:y])
    # (a?compose1:compose2)[x:y] => (a?compose1[x:y]:compose2[x:y])
    if inner.is_cond():
        true_src, false_src = inner.src1, inner.src2
        if ((true_src.is_int() or true_src.is_compose()) and
                (false_src.is_int() or false_src.is_compose())):
            return ExprCond(inner.cond,
                            true_src[low:high],
                            false_src[low:high])

    # (a * int)[0:y] => (a[0:y] * int[0:y])
    if low == 0 and inner.is_op("*") and inner.args[-1].is_int():
        factors = [e_s.expr_simp_wrapper(factor[low:high])
                   for factor in inner.args]
        return ExprOp(inner.op, *factors)

    # (a >> int)[x:y] => a[x+int:y+int] with int+y <= a.size
    # (a << int)[x:y] => a[x-int:y-int] with x-int >= 0
    if (inner.is_op() and inner.op in [">>", "<<"] and
            inner.args[1].is_int()):
        shifted, amount = inner.args
        amount = int(amount)
        if inner.op == ">>":
            if amount + high <= shifted.size:
                return shifted[low + amount:high + amount]
        elif inner.op == "<<":
            if low - amount >= 0:
                return shifted[low - amount:high - amount]
        else:
            raise ValueError('Bad case')

    return expr
def simp_subwc_cf(
expr_s, expr)
def simp_subwc_cf(expr_s, expr):
    """Merge a FLAG_SUBWC_CF fed by a FLAG_SUB_CF carry

    SUBWC_CF(A, B, SUB_CF(C, D)) => SUB_CF({A, C}, {B, D})
    Any other expression is returned unchanged.
    """
    if expr.is_op('FLAG_SUBWC_CF'):
        carry = expr.args[2]
        if carry.is_op("FLAG_SUB_CF"):
            # Compose built from the SUBWC operand, then the SUB_CF operand
            left = ExprCompose(expr.args[0], carry.args[0])
            right = ExprCompose(expr.args[1], carry.args[1])
            return ExprOp("FLAG_SUB_CF", left, right)
    return expr
def simp_subwc_of(
expr_s, expr)
def simp_subwc_of(expr_s, expr):
    """Merge a FLAG_SUBWC_OF fed by a FLAG_SUB_CF carry

    SUBWC_OF(A, B, SUB_CF(C, D)) => SUB_OF({A, C}, {B, D})
    Any other expression is returned unchanged.
    """
    if expr.is_op('FLAG_SUBWC_OF'):
        carry = expr.args[2]
        if carry.is_op("FLAG_SUB_CF"):
            # Compose built from the SUBWC operand, then the SUB_CF operand
            left = ExprCompose(expr.args[0], carry.args[0])
            right = ExprCompose(expr.args[1], carry.args[1])
            return ExprOp("FLAG_SUB_OF", left, right)
    return expr
def simp_zeroext_eq_cst(
expr_s, expr)
def simp_zeroext_eq_cst(expr_s, expr):
    """Compare the source of a zero extension instead of its result

    A.zeroExt(X) == int => A == int[:A.size]
    When the constant cannot be represented on A.size bits, the equality
    never holds and ExprInt(0, 1) is returned.

    @expr_s: simplification engine (not used by this pass)
    @expr: the expression to simplify
    Returns the simplified expression, or @expr when the pattern does not
    match.
    """
    if not expr.is_op("=="):
        return expr
    arg1, arg2 = expr.args
    if not arg2.is_int():
        return expr
    if not (arg1.is_op() and arg1.op.startswith("zeroExt")):
        return expr
    src = arg1.args[0]
    # The greatest value of a zero extended A is (1 << A.size) - 1, so the
    # bound itself is already out of reach (and would be truncated to 0)
    if int(arg2) >= (1 << src.size):
        # Always false
        return ExprInt(0, 1)
    return ExprOp("==", src, ExprInt(int(arg2), src.size))
def test_cc_eq_args(
expr, *sons_op)
def test_cc_eq_args(expr, *sons_op):
    """Check that the sons of @expr are given operations on the same args

    Returns True when @expr is an operation with exactly one argument per
    name in @sons_op, the i-th argument being an operation named
    sons_op[i], and all those arguments sharing the same 'args'.
    Returns False otherwise.
    """
    if not expr.is_op():
        return False
    if len(expr.args) != len(sons_op):
        return False

    # Both sequences have the same length here, so zip drops nothing
    shared_args = set()
    for son, wanted_op in zip(expr.args, sons_op):
        if not son.is_op(wanted_op):
            return False
        shared_args.add(son.args)
    return len(shared_args) == 1