"""
Functions for simplifying Python math expressions.
"""
4 from compiler.ast import *
5 import operator
6 import sets
7
8 import AST
9
# Shared constant nodes: _simplify_ast compares candidate nodes against these
# with == and also returns them directly as simplified results.
_ZERO, _ONE = Const(0), Const(1)

def simplify_expr(expr):
    """
    Return a simplified version of the expression.

    The expression is parsed with AST.strip_parse, simplified with
    _simplify_ast, and converted back with AST.ast2str; the return type is
    whatever AST.ast2str produces (presumably a string -- confirm there).
    """
    return AST.ast2str(_simplify_ast(AST.strip_parse(expr)))

def _simplify_ast(ast):
    """
    Return a simplified ast.

    Current simplifications:
        Special cases for zeros and ones, and combining of constants, in
        addition, subtraction, multiplication, division.
        Repeated terms are counted: x + x -> 2*x, x*x -> x**2.
        x - x = 0
        --x = x
        +x = x
        Powers of two constants are folded: 2**3 -> 8.

    NOTE(review): earlier docs said constants are only combined left to
    right (1+1+x -> 2+x, but x+1+1 -> x+1+1). Whether that still holds
    depends on how AST._collect_pos_neg / AST._collect_num_denom flatten
    the tree -- confirm.

    Term matching uses == (list.count) and groups terms by str(); this
    presumably relies on AST nodes comparing structurally -- confirm in the
    AST module.

    Lists and tuples are simplified element by element into a new list or
    tuple. Nodes whose class is a key of AST._node_attrs have the listed
    child attributes simplified *in place* and the same node is returned.
    Anything else is returned unchanged.
    """
    if isinstance(ast, (Name, Const)):
        return ast
    elif isinstance(ast, (Add, Sub)):
        return _simplify_sum(ast)
    elif isinstance(ast, (Mul, Div)):
        return _simplify_product(ast)
    elif isinstance(ast, Power):
        return _simplify_power(ast)
    elif isinstance(ast, UnarySub):
        return _simplify_negation(ast)
    elif isinstance(ast, UnaryAdd):
        # +x is just x.
        return _simplify_ast(ast.expr)
    elif isinstance(ast, list):
        return [_simplify_ast(elem) for elem in ast]
    elif isinstance(ast, tuple):
        return tuple(_simplify_ast(list(ast)))
    elif ast.__class__ in AST._node_attrs:
        # Generic node: simplify each registered child attribute in place.
        for attr_name in AST._node_attrs[ast.__class__]:
            attr = getattr(ast, attr_name)
            if isinstance(attr, list):
                for ii, elem in enumerate(attr):
                    attr[ii] = _simplify_ast(elem)
            else:
                setattr(ast, attr_name, _simplify_ast(attr))
        return ast
    else:
        return ast


def _simplify_sum(ast):
    """
    Simplify an Add/Sub tree: fold constants, cancel and count repeated terms.

    Returns _ZERO if every term cancels.
    """
    # Split the sum into positive and negative terms and simplify each.
    pos, neg = [], []
    AST._collect_pos_neg(ast, pos, neg)
    pos = [_simplify_ast(term) for term in pos]
    neg = [_simplify_ast(term) for term in neg]

    # Fold all constant terms into a single number.
    values = [term.value for term in pos if isinstance(term, Const)] +\
             [-term.value for term in neg if isinstance(term, Const)]
    value = sum(values)
    pos = [term for term in pos if not isinstance(term, Const)]
    neg = [term for term in neg if not isinstance(term, Const)]

    # A negated term on one side is a plain term on the other side.
    new_pos, new_neg = [], []
    for term in pos:
        if isinstance(term, UnarySub):
            new_neg.append(term.expr)
        else:
            new_pos.append(term)
    for term in neg:
        if isinstance(term, UnarySub):
            new_pos.append(term.expr)
        else:
            new_neg.append(term)
    pos, neg = new_pos, new_neg

    # The folded constant goes last, on whichever side matches its sign.
    if value > 0:
        pos.append(Const(value))
    elif value < 0:
        neg.append(Const(abs(value)))

    # Net multiplicity of each distinct term, keyed by its string form.
    term_counts = [(term, pos.count(term) - neg.count(term))
                   for term in pos + neg]
    term_counts = dict([(str(term), (term, count))
                        for term, count in term_counts])

    # The first term with a nonzero count starts the output expression.
    ii = 0
    for ii, term in enumerate(pos + neg):
        ast_out, count = term_counts[str(term)]
        if count != 0:
            break
    else:
        # Everything cancelled.
        return _ZERO

    # Zero the count so the term isn't emitted again below.
    term_counts[str(term)] = (ast_out, 0)
    if abs(count) != 1:
        ast_out = Mul((Const(abs(count)), ast_out))
    if count < 0:
        ast_out = UnarySub(ast_out)

    # Append the remaining terms, each exactly once, with its multiplicity.
    for term in (pos + neg)[ii:]:
        term, count = term_counts[str(term)]
        if count == 0:
            continue
        term_counts[str(term)] = (term, 0)
        if abs(count) != 1:
            term = Mul((Const(abs(count)), term))
        if count > 0:
            ast_out = Add((ast_out, term))
        else:
            ast_out = Sub((ast_out, term))

    return ast_out


def _simplify_product(ast):
    """
    Simplify a Mul/Div tree: fold constants and signs, and turn repeated
    factors into powers.

    Returns _ZERO if the folded constant is zero. A literal zero divisor is
    not folded; see below.
    """
    # Split the product into numerator and denominator factors and simplify each.
    num, denom = [], []
    AST._collect_num_denom(ast, num, denom)
    num = [_simplify_ast(term) for term in num]
    denom = [_simplify_ast(term) for term in denom]

    # Folding a literal zero divisor would evaluate 1./0 and raise
    # ZeroDivisionError inside the simplifier. Leave such a quotient
    # unfolded (operands simplified) so the problem surfaces only when the
    # expression is actually evaluated.
    if any(isinstance(term, Const) and term.value == 0 for term in denom):
        return ast.__class__((_simplify_ast(ast.left),
                              _simplify_ast(ast.right)))

    # Fold all constant factors into a single number.
    values = [term.value for term in num if isinstance(term, Const)] +\
             [1./term.value for term in denom if isinstance(term, Const)]
    value = reduce(operator.mul, values + [1])

    # Anything times zero is zero.
    if not value:
        return _ZERO

    num = [term for term in num if not isinstance(term, Const)]
    denom = [term for term in denom if not isinstance(term, Const)]

    # Strip negations from factors, remembering how many there were.
    num_neg = 0
    for list_of_terms in [num, denom]:
        for ii, term in enumerate(list_of_terms):
            if isinstance(term, UnarySub):
                list_of_terms[ii] = term.expr
                num_neg += 1

    # Only the folded constant's magnitude is kept as a factor; its sign
    # joins the negation count.
    if abs(value) != 1:
        num.append(Const(abs(value)))
    if value < 0:
        num_neg += 1
    make_neg = num_neg % 2

    # Net multiplicity of each distinct factor, keyed by its string form.
    term_counts = [(term, num.count(term) - denom.count(term))
                   for term in num + denom]
    term_counts = dict([(str(term), (term, count))
                        for term, count in term_counts])

    # Emit each factor once, raised to its net multiplicity.
    nums, denoms = [], []
    for term in num + denom:
        term, count = term_counts[str(term)]
        term_counts[str(term)] = (term, 0)
        if abs(count) > 1:
            term = Power((term, Const(abs(count))))
        if count > 0:
            nums.append(term)
        elif count < 0:
            denoms.append(term)

    out = AST._make_product(nums)
    if denoms:
        out = Div((out, AST._make_product(denoms)))

    if make_neg:
        out = UnarySub(out)

    return out


def _simplify_power(ast):
    """
    Simplify a Power node.

    x**0 -> 1, 0**x -> 0, 1**x -> 1, x**1 -> x; constant ** constant is folded
    when Python can evaluate it.
    """
    power = _simplify_ast(ast.right)
    base = _simplify_ast(ast.left)

    if power == _ZERO:
        return _ONE
    if base == _ZERO or base == _ONE or power == _ONE:
        return base
    if isinstance(base, Const) and isinstance(power, Const):
        try:
            return Const(base.value**power.value)
        except (ValueError, ZeroDivisionError, OverflowError):
            # In Python 2 a negative base with a fractional exponent raises
            # ValueError, and 0.0 to a negative power raises
            # ZeroDivisionError. A too-large float result raises
            # OverflowError. Keep the power symbolic rather than crash.
            pass

    return Power((base, power))


def _simplify_negation(ast):
    """
    Simplify a UnarySub node: --x -> x, and fold negated constants.
    """
    simple_expr = _simplify_ast(ast.expr)
    if isinstance(simple_expr, UnarySub):
        # Double negation cancels.
        return _simplify_ast(simple_expr.expr)
    elif isinstance(simple_expr, Const):
        # Return a plain 0 for a zero constant instead of negating it.
        if simple_expr.value == 0:
            return Const(0)
        return Const(-simple_expr.value)
    return UnarySub(simple_expr)
