Package SloppyCell :: Package ExprManip :: Module Simplify
[hide private]

Source Code for Module SloppyCell.ExprManip.Simplify

  1  """ 
  2  Functions for simplifying python math expressions. 
  3  """ 
  4  from compiler.ast import * 
  5  import operator 
  6  import sets 
  7   
  8  import AST 
  9   
 10  # Constants for comparison 
 11  _ZERO = Const(0) 
 12  _ONE = Const(1) 
 13   
def simplify_expr(expr):
    """
    Simplify a python math expression.

    expr is parsed with AST.strip_parse, the resulting tree is run through
    _simplify_ast, and the simplified tree is rendered with AST.ast2str,
    whose result is returned.
    """
    simplified_tree = _simplify_ast(AST.strip_parse(expr))
    return AST.ast2str(simplified_tree)
20
def _net_term_counts(plus_terms, minus_terms):
    """
    Return the net count of each distinct term as a list of (term, count).

    Every occurrence of a term in plus_terms adds one to its count and every
    occurrence in minus_terms subtracts one. Terms are identified by
    str(term); that single key is used both for counting and for
    de-duplication, so a term can never be counted twice. The list is ordered
    by first appearance in plus_terms + minus_terms, and when several terms
    share a key the last one seen is returned as the representative. Entries
    whose net count is 0 are included; callers skip them.
    """
    counts = {}
    representatives = {}
    order = []
    for sign, terms in ((1, plus_terms), (-1, minus_terms)):
        for term in terms:
            key = str(term)
            if key not in counts:
                counts[key] = 0
                order.append(key)
            counts[key] += sign
            representatives[key] = term
    return [(representatives[key], counts[key]) for key in order]


def _simplify_ast(ast):
    """
    Return a simplified version of a compiler.ast node.

    Current simplifications:
      - Addition/subtraction: terms are flattened with AST._collect_pos_neg
        and simplified. All constant terms returned by the collector are
        summed into one. Negated terms move to the other side. Identical
        terms (identified by str) are combined with an integer coefficient,
        so x - x -> 0 and x + x -> 2*x.
        NOTE(review): an older version of this docstring said only
        constants applied left to right were combined (1+1+x -> 2+x, but
        x+1+1 -> x+1+1). Whether that still holds depends on how
        AST._collect_pos_neg flattens the expression. TODO confirm.
      - Multiplication/division: factors are flattened with
        AST._collect_num_denom and simplified. Constant factors are
        multiplied into one, and a zero product gives 0. Signs of negated
        factors are collected into at most one leading negation. Identical
        factors are combined into integer powers, so x*x -> x**2, and
        factors appearing in both numerator and denominator cancel.
      - Powers: x**0 -> 1 (checked first, so 0**0 -> 1), 0**x -> 0,
        1**x -> 1, x**1 -> x, and constant**constant is evaluated.
      - --x -> x, negated constants are folded, +x -> x.
      - Lists and tuples are simplified element-wise.

    Any other node type listed in AST._node_attrs has its child attributes
    simplified IN PLACE, and the same node object is returned. Anything else
    is returned unchanged.

    Raises ZeroDivisionError if a constant zero appears in a denominator,
    because constant divisors are folded with 1./value.
    """
    if isinstance(ast, Name) or isinstance(ast, Const):
        return ast
    elif isinstance(ast, Add) or isinstance(ast, Sub):
        # We collect positive and negative terms and simplify each of them
        pos, neg = [], []
        AST._collect_pos_neg(ast, pos, neg)
        pos = [_simplify_ast(term) for term in pos]
        neg = [_simplify_ast(term) for term in neg]

        # We collect and sum the constant values
        values = [term.value for term in pos if isinstance(term, Const)] +\
                [-term.value for term in neg if isinstance(term, Const)]
        value = sum(values)

        # Remove the constants from our pos and neg lists
        pos = [term for term in pos if not isinstance(term, Const)]
        neg = [term for term in neg if not isinstance(term, Const)]

        # A negated term on one side is its operand on the other side:
        # a + (-b) -> a - b and a - (-b) -> a + b.
        new_pos, new_neg = [], []
        for term in pos:
            if isinstance(term, UnarySub):
                new_neg.append(term.expr)
            else:
                new_pos.append(term)
        for term in neg:
            if isinstance(term, UnarySub):
                new_pos.append(term.expr)
            else:
                new_neg.append(term)
        pos, neg = new_pos, new_neg

        # Append the constant value sum to pos or neg
        if value > 0:
            pos.append(Const(value))
        elif value < 0:
            neg.append(Const(abs(value)))

        # Build the sum from each distinct term with a non-zero net count,
        # in order of first appearance. A count of n becomes n*term. The
        # first emitted term carries its own sign; later ones are added or
        # subtracted.
        ast_out = None
        for term, count in _net_term_counts(pos, neg):
            if count == 0:
                continue
            if abs(count) != 1:
                term = Mul((Const(abs(count)), term))
            if ast_out is None:
                if count < 0:
                    term = UnarySub(term)
                ast_out = term
            elif count > 0:
                ast_out = Add((ast_out, term))
            else:
                ast_out = Sub((ast_out, term))

        if ast_out is None:
            # There were no terms at all, or every term cancelled
            return _ZERO
        return ast_out
    elif isinstance(ast, Mul) or isinstance(ast, Div):
        # We collect numerator and denominator terms and simplify each of them
        num, denom = [], []
        AST._collect_num_denom(ast, num, denom)
        num = [_simplify_ast(term) for term in num]
        denom = [_simplify_ast(term) for term in denom]

        # We collect the constant factors and take their product.
        # A constant 0 in the denominator raises ZeroDivisionError here.
        values = [term.value for term in num if isinstance(term, Const)] +\
                [1./term.value for term in denom if isinstance(term, Const)]
        value = reduce(operator.mul, values + [1])

        # If our value is 0, the expression is 0
        if not value:
            return _ZERO

        # Remove the constants from our num and denom lists
        num = [term for term in num if not isinstance(term, Const)]
        denom = [term for term in denom if not isinstance(term, Const)]

        # Here we count all the negative (UnarySub) elements of our expression.
        # We also remove the UnarySubs from their arguments. We'll correct
        # for it at the end.
        num_neg = 0
        for list_of_terms in [num, denom]:
            for ii, term in enumerate(list_of_terms):
                if isinstance(term, UnarySub):
                    list_of_terms[ii] = term.expr
                    num_neg += 1

        # Append the magnitude of the constant product to the numerator
        # (omitted when it is 1); its sign joins the negation count.
        if abs(value) != 1:
            num.append(Const(abs(value)))
        if value < 0:
            num_neg += 1

        make_neg = num_neg % 2

        # Combine repeated factors into integer powers. Terms are walked in
        # their original order so we rearrange as little as possible.
        nums, denoms = [], []
        for term, count in _net_term_counts(num, denom):
            if abs(count) > 1:
                term = Power((term, Const(abs(count))))
            if count > 0:
                nums.append(term)
            elif count < 0:
                denoms.append(term)

        # We return the product of the numerator terms over the product of the
        # denominator terms
        out = AST._make_product(nums)
        if denoms:
            denom = AST._make_product(denoms)
            out = Div((out, denom))

        if make_neg:
            out = UnarySub(out)

        return out
    elif isinstance(ast, Power):
        # Simplify both operands first (exponent, then base, as before).
        power = _simplify_ast(ast.right)
        base = _simplify_ast(ast.left)

        if power == _ZERO:
            # Anything, including 0, to the 0th power is 1, so this
            # test should come first
            return _ONE
        if base == _ZERO or base == _ONE or power == _ONE:
            return base
        elif isinstance(base, Const) and isinstance(power, Const):
            return Const(base.value**power.value)
        # Getting here implies that no simplifications are possible, so just
        # return with simplified arguments
        return Power((base, power))
    elif isinstance(ast, UnarySub):
        simple_expr = _simplify_ast(ast.expr)
        if isinstance(simple_expr, UnarySub):
            # Case --x
            return _simplify_ast(simple_expr.expr)
        elif isinstance(simple_expr, Const):
            if simple_expr.value == 0:
                # Return a plain zero rather than negating it
                return Const(0)
            else:
                return Const(-simple_expr.value)
        else:
            return UnarySub(simple_expr)
    elif isinstance(ast, UnaryAdd):
        simple_expr = _simplify_ast(ast.expr)
        return simple_expr
    elif isinstance(ast, list):
        simple_list = [_simplify_ast(elem) for elem in ast]
        return simple_list
    elif isinstance(ast, tuple):
        return tuple(_simplify_ast(list(ast)))
    elif ast.__class__ in AST._node_attrs:
        # Handle node types with no special cases by simplifying each child
        # attribute. Note that this modifies the node in place.
        for attr_name in AST._node_attrs[ast.__class__]:
            attr = getattr(ast, attr_name)
            if isinstance(attr, list):
                for ii, elem in enumerate(attr):
                    attr[ii] = _simplify_ast(elem)
            else:
                setattr(ast, attr_name, _simplify_ast(attr))
        return ast
    else:
        return ast
232