import compiler
from compiler.ast import *

# Tiny positive offset passed as `adjust` to ast2str. Adding it to an
# operand's _OP_ORDER value makes an operand of the *same* precedence as its
# parent get parentheses (e.g. the right side of '-' or '/'), while operands
# that bind strictly tighter still do not.
TINY = 1e-12
5
def _node_equal(self, other):
    """
    Return whether self and other represent the same expressions.

    Unfortunately, the Node class in Python 2.3 doesn't define ==, so
    we need to write our own.

    Two nodes are equal when other is an instance of self's class and both
    getChildren() sequences have the same length and pairwise-equal
    elements. Child Nodes are compared with ==, so the check recurses.
    """
    # We're not equal if other isn't a Node, or if other is a different class.
    if not isinstance(other, Node) or not isinstance(other, self.__class__):
        return False

    self_children = list(self.getChildren())
    other_children = list(other.getChildren())
    # zip() silently stops at the shorter sequence. Without this check, nodes
    # whose children differ only by extra trailing entries (e.g. `a and b`
    # vs `a and b and c`) would compare equal.
    if len(self_children) != len(other_children):
        return False

    for self_child, other_child in zip(self_children, other_children):
        # Only __eq__ is installed on Node (no __ne__), so use `not ==`
        # rather than `!=` to get the structural comparison.
        if not self_child == other_child:
            return False

    return True

Node.__eq__ = _node_equal
24
def strip_parse(expr):
    """
    Parse expr and return the abstract syntax tree (AST) of the expression
    alone, without the enclosing wrapper nodes that compiler.parse adds.

    expr is converted with str() and stripped of surrounding whitespace
    before parsing. Only the first statement of the parsed source is used.
    """
    # NOTE(review): the def line was missing from this copy of the file; the
    # name strip_parse is reconstructed and should be checked against callers.
    source = str(expr).strip()
    module_ast = compiler.parse(source)
    first_statement = module_ast.node.nodes[0]
    return first_statement.expr
35
36
37
# Binding strength of each supported node class: smaller values bind more
# tightly (Power < Mul/Div < Add/Sub). _need_parens compares these values to
# decide on parentheses. The table is filled tier by tier, in the same order
# the classes were originally listed.
_OP_ORDER = {}
for _level, _classes in [(0, (Name, Const, CallFunc, Subscript, Slice,
                              Sliceobj)),
                         (3, (Power,)),
                         (4, (UnarySub, UnaryAdd)),
                         (5, (Mul, Div)),
                         (10, (Sub, Add)),
                         (11, (Compare, Not, And, Or)),
                         (100, (Discard,))]:
    for _cls in _classes:
        _OP_ORDER[_cls] = _level
# Keep the module namespace free of the loop temporaries.
del _level, _classes, _cls


# Default 'outer' node for ast2str. Its class has the loosest binding in
# _OP_ORDER, so a top-level expression is never wrapped in parentheses.
_FARTHEST_OUT = Discard(None)
59
60
# For each supported node class, the names of the attributes that hold its
# sub-expressions. Mapping a function over a node's children rewrites exactly
# these attributes. Classes sharing an attribute layout are listed together.
# NOTE(review): CallFunc.node, Subscript.expr and Slice.expr are not listed,
# so they are never visited when mapping. Presumably this is intentional
# (callee / sequence names); confirm.
_node_attrs = {}
for _attr_names, _classes in [((), (Name, Const)),
                              (('left', 'right'), (Add, Sub, Mul, Div, Power)),
                              (('args',), (CallFunc,)),
                              (('expr',), (UnarySub, UnaryAdd, Not)),
                              (('lower', 'upper'), (Slice,)),
                              (('nodes',), (Sliceobj, Or, And)),
                              (('subs',), (Subscript,)),
                              (('expr', 'ops'), (Compare,))]:
    for _cls in _classes:
        _node_attrs[_cls] = _attr_names
# Keep the module namespace free of the loop temporaries.
del _attr_names, _classes, _cls
79
80
def ast2str(ast, outer=_FARTHEST_OUT, adjust=0):
    """
    Return the string representation of an AST.

    outer: The AST's 'parent' node, used to determine whether or not to
        enclose the result in parentheses. The default of _FARTHEST_OUT will
        never enclose the result in parentheses.

    adjust: A numerical value to adjust the priority of this ast for
        particular cases. For example, the denominator of a '/' needs
        parentheses in more cases than does the numerator.

    Only the node types handled below are supported. Any other node raises
    an error.
    """
    if isinstance(ast, Name):
        out = ast.name
    elif isinstance(ast, Const):
        out = str(ast.value)
    elif isinstance(ast, Add):
        out = '%s + %s' % (ast2str(ast.left, ast),
                           ast2str(ast.right, ast))
    elif isinstance(ast, Sub):
        # The subtrahend needs parentheses even at equal precedence:
        # a - (b + c) != a - b + c.
        out = '%s - %s' % (ast2str(ast.left, ast),
                           ast2str(ast.right, ast, adjust=TINY))
    elif isinstance(ast, Mul):
        out = '%s*%s' % (ast2str(ast.left, ast),
                         ast2str(ast.right, ast))
    elif isinstance(ast, Div):
        # Likewise the denominator: a/(b*c) != a/b*c.
        out = '%s/%s' % (ast2str(ast.left, ast),
                         ast2str(ast.right, ast, adjust=TINY))
    elif isinstance(ast, Power):
        # '**' is right-associative, so it is the base that needs
        # parentheses at equal precedence: (a**b)**c != a**b**c.
        out = '%s**%s' % (ast2str(ast.left, ast, adjust=TINY),
                          ast2str(ast.right, ast))
    elif isinstance(ast, UnarySub):
        out = '-%s' % ast2str(ast.expr, ast)
    elif isinstance(ast, UnaryAdd):
        out = '+%s' % ast2str(ast.expr, ast)
    elif isinstance(ast, CallFunc):
        # Arguments are comma-separated and never need parentheses. The
        # callee does when it is a compound expression, e.g. (f + g)(x).
        args = [ast2str(arg) for arg in ast.args]
        out = '%s(%s)' % (ast2str(ast.node, ast), ', '.join(args))
    elif isinstance(ast, Subscript):
        # A compound base must be parenthesized: (a + b)[0], not a + b[0].
        subs = [ast2str(sub) for sub in ast.subs]
        out = '%s[%s]' % (ast2str(ast.expr, ast), ', '.join(subs))
    elif isinstance(ast, Slice):
        # The compiler package stores an omitted bound (x[:n], x[n:]) as
        # None. Render it as an empty string instead of recursing on None.
        lower = upper = ''
        if ast.lower is not None:
            lower = ast2str(ast.lower)
        if ast.upper is not None:
            upper = ast2str(ast.upper)
        out = '%s[%s:%s]' % (ast2str(ast.expr, ast), lower, upper)
    elif isinstance(ast, Sliceobj):
        nodes = [ast2str(node) for node in ast.nodes]
        out = ':'.join(nodes)
    elif isinstance(ast, Compare):
        # With adjust=6+TINY, any operand whose _OP_ORDER is 5 or more
        # (Mul/Div and anything looser) is parenthesized.
        expr = ast2str(ast.expr, ast, adjust=6+TINY)
        out_l = [expr]
        for op, val in ast.ops:
            out_l.append(op)
            out_l.append(ast2str(val, ast, adjust=6+TINY))
        out = ' '.join(out_l)
    elif isinstance(ast, And):
        # TINY parenthesizes operands at the same level (Compare/Not/And/Or).
        nodes = [ast2str(node, ast, adjust=TINY) for node in ast.nodes]
        out = ' and '.join(nodes)
    elif isinstance(ast, Or):
        nodes = [ast2str(node, ast, adjust=TINY) for node in ast.nodes]
        out = ' or '.join(nodes)
    elif isinstance(ast, Not):
        out = 'not %s' % ast2str(ast.expr, ast, adjust=TINY)

    if _need_parens(outer, ast, adjust):
        return '(%s)' % out
    return out

def _need_parens(outer, inner, adjust):
    """
    Return whether or not the inner AST needs parentheses when enclosed by the
    outer.

    Smaller _OP_ORDER values bind more tightly. Inner therefore needs
    parentheses exactly when its order plus adjust exceeds the order of outer.

    adjust: A numerical value to adjust the priority of this ast for
        particular cases. For example, the denominator of a '/' needs
        parentheses in more cases than does the numerator.
    """
    return _OP_ORDER[outer.__class__] < _OP_ORDER[inner.__class__] + adjust
162
def _collect_num_denom(ast, nums, denoms):
    """
    Split a chain of Mul/Div nodes into numerator and denominator factors.

    Each factor in the numerator is appended to nums and each factor in the
    denominator to denoms, left to right. A node that is neither Mul nor Div
    is a single numerator factor.
    """
    if not (isinstance(ast, Mul) or isinstance(ast, Div)):
        nums.append(ast)
        return

    # The left operand keeps the current orientation. Recursing on a leaf
    # simply appends it to the first list.
    _collect_num_denom(ast.left, nums, denoms)

    if isinstance(ast, Mul):
        _collect_num_denom(ast.right, nums, denoms)
    else:
        # Dividing by a sub-chain swaps its numerator and denominator:
        # a/(b/c) == a*c/b.
        _collect_num_denom(ast.right, denoms, nums)
194
def _collect_pos_neg(ast, poss, negs):
    """
    Split a chain of Add/Sub nodes into positive and negative terms.

    Each term added is appended to poss and each term subtracted to negs,
    left to right. A node that is neither Add nor Sub is a single positive
    term.
    """
    if not (isinstance(ast, Add) or isinstance(ast, Sub)):
        poss.append(ast)
        return

    # The left operand keeps the current sign. Recursing on a leaf simply
    # appends it to the first list.
    _collect_pos_neg(ast.left, poss, negs)

    if isinstance(ast, Add):
        _collect_pos_neg(ast.right, poss, negs)
    else:
        # Subtracting a sub-chain flips the sign of each of its terms:
        # a - (b - c) == a - b + c.
        _collect_pos_neg(ast.right, negs, poss)
223
def _make_product(terms):
    """
    Return an AST expressing the product of all the terms.

    Terms are combined left to right with Mul nodes, as (((t0*t1)*t2)*...).
    A single term is returned unchanged, and an empty sequence yields
    Const(1).
    """
    # NOTE(review): the def line was missing from this copy of the file; the
    # name _make_product is reconstructed and should be checked against callers.
    if not terms:
        return Const(1)

    result = terms[0]
    for factor in terms[1:]:
        result = Mul((result, factor))
    return result
235
def _ast_map(func, ast, *args):
    """
    Apply func(child, *args) to the immediate children of ast and return ast.

    - list: each element is replaced in place and the same list is returned.
    - tuple: func is called once on a list copy of the tuple, and its result
      is converted back to a tuple. func must therefore accept lists, as it
      already must for list-valued node attributes.
    - node class listed in _node_attrs: each listed attribute is replaced via
      setattr and the same node is returned.
    - anything else is returned unchanged.
    """
    # NOTE(review): the def line was missing from this copy of the file; the
    # name _ast_map is reconstructed and should be checked against callers.
    if isinstance(ast, list):
        for position, child in enumerate(ast):
            ast[position] = func(child, *args)
        return ast

    if isinstance(ast, tuple):
        return tuple(func(list(ast), *args))

    child_attr_names = _node_attrs.get(ast.__class__)
    if child_attr_names is not None:
        for name in child_attr_names:
            setattr(ast, name, func(getattr(ast, name), *args))
    return ast
249