1 from compiler.ast import *
2 import copy
3 import cPickle
4 import logging
5 import os
6 logger = logging.getLogger('ExprManip.Differentiation')
7
8 import AST
9 from AST import strip_parse
10 import Simplify
11 import Substitution
12
# Shared constant nodes: the differentiator hands these singletons back for
# derivatives that are identically 1 or 0, and compares intermediate results
# against _ZERO to drop vanishing chain-rule terms.
_ZERO, _ONE = Const(0), Const(1)
15
# Memo of already-computed derivatives. Keys have the form
# '<expr>__derivWRT__<wrt>'; values are simplified derivative strings.
__deriv_saved = {}


def load_derivs(filename):
    """
    Load up a pickled dictionary of saved derivatives.

    On success the module-level derivative cache is replaced by the unpickled
    object. If the file cannot be opened (typically because it does not exist
    yet) the cache is left untouched. If it opens but cannot be unpickled, a
    warning is logged and the file is deleted (best effort) so the corrupt
    file is not retried on every run.

    NOTE(review): cPickle.load can execute arbitrary code embedded in the
    file; only point this at files this program wrote itself.
    """
    global __deriv_saved
    try:
        f = open(filename, 'rb')
    except IOError:
        # Most likely the file doesn't exist yet; nothing to load.
        return
    try:
        try:
            loaded = cPickle.load(f)
        finally:
            # Release the handle on success as well as failure, and before any
            # attempt to delete the file (open files can't be removed on some
            # platforms).
            f.close()
    except Exception:
        logger.warning('Failed to load saved derivative file %s. Trying to delete it to '
                       'avoid future problems.', filename)
        try:
            # The file is unreadable as a pickle, so remove it.
            os.remove(filename)
        except OSError:
            logger.warning('File removal failed. Please delete %s manually.',
                           filename)
        return
    __deriv_saved = loaded
40
def save_derivs(filename):
    """
    Pickle the module-level derivative cache to filename (pickle protocol 2),
    overwriting any existing file.
    """
    f = open(filename, 'wb')
    try:
        cPickle.dump(__deriv_saved, f, 2)
    finally:
        # Close the handle even if pickling fails part-way through.
        f.close()
def diff_expr(expr, wrt):
    """
    Return the derivative of the expression with respect to a given variable.

    expr is parsed with AST.strip_parse (presumably a source string -- TODO
    confirm accepted types); wrt is the name of the variable to differentiate
    with respect to. The derivative AST is simplified and rendered back to a
    string with AST.ast2str, and that string is returned.

    Results are memoized in the module-level cache under the key
    '<expr>__derivWRT__<wrt>', so repeated requests skip the parse /
    differentiate / simplify pipeline.
    """
    # Hand the arguments to the logger rather than pre-formatting them, so the
    # message is only built when debug logging is actually enabled.
    logger.debug('Taking derivative of %s wrt %s', expr, wrt)
    key = '%s__derivWRT__%s' % (expr, wrt)
    if key in __deriv_saved:
        deriv = __deriv_saved[key]
        logger.debug('Found saved result %s.', deriv)
        return deriv

    ast = AST.strip_parse(expr)
    deriv = _diff_ast(ast, wrt)
    deriv = Simplify._simplify_ast(deriv)
    deriv = AST.ast2str(deriv)
    __deriv_saved[key] = deriv
    logger.debug('Computed result %s.', deriv)
    return deriv
64
65
66
67
68
# Analytic derivative templates for known functions.
#
# Keys are (function name, number of arguments). Each value holds one template
# per argument: the partial derivative of the function with respect to that
# argument, written in terms of placeholders arg0, arg1, ... which _diff_ast
# replaces with the actual call arguments before applying the chain rule.
_KNOWN_FUNCS = {('acos', 1): ('-1/sqrt(1-arg0**2)',),
                ('asin', 1): ('1/sqrt(1-arg0**2)',),
                ('atan', 1): ('1/(1+arg0**2)',),
                ('cos', 1): ('-sin(arg0)',),
                ('cosh', 1): ('sinh(arg0)',),
                ('exp', 1): ('exp(arg0)',),
                ('log', 1): ('1/arg0',),
                ('log10', 1): ('1/(log(10)*arg0)',),
                ('sin', 1): ('cos(arg0)',),
                ('sinh', 1): ('cosh(arg0)',),
                ('arcsinh', 1): ('1/sqrt(1+arg0**2)',),
                ('arccosh', 1): ('1/sqrt(arg0**2 - 1.)',),
                ('arctanh', 1): ('1/(1.-arg0**2)',),
                ('sqrt', 1): ('1/(2*sqrt(arg0))',),
                ('tan', 1): ('1/cos(arg0)**2',),
                ('tanh', 1): ('1/cosh(arg0)**2',),
                # d/dx pow(x, y) = y*x**(y-1);  d/dy pow(x, y) = log(x)*x**y
                ('pow', 2): ('arg1 * arg0**(arg1-1)',
                             'log(arg0) * arg0**arg1')
                }
# Parse every template string into an AST once, at import time. The generator
# expression keeps the loop variables out of the module namespace.
_KNOWN_FUNCS = dict((key, [strip_parse(term) for term in terms])
                    for key, terms in _KNOWN_FUNCS.items())
def _diff_ast(ast, wrt):
    """
    Return an AST that is the derivative of ast with respect to the variable
    with name 'wrt'.

    Handles Name, Const, Add, Sub, Mul, Div, Power, CallFunc, UnarySub and
    UnaryAdd nodes. The result is not simplified; the shared _ZERO / _ONE
    nodes are returned for trivially constant derivatives. Any other node type
    falls through and yields None.
    """
    if isinstance(ast, Name):
        # d(wrt)/d(wrt) = 1; every other variable is treated as a constant.
        if ast.name == wrt:
            return _ONE
        return _ZERO
    elif isinstance(ast, Const):
        return _ZERO
    elif isinstance(ast, (Add, Sub)):
        # Linearity: differentiate each operand and keep the same operator.
        return ast.__class__((_diff_ast(ast.left, wrt),
                              _diff_ast(ast.right, wrt)))
    elif isinstance(ast, (Mul, Div)):
        # Flatten the product/quotient into numerator and denominator factors.
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)

        num = AST._make_product(nums)
        num_d = _product_deriv(nums, wrt)
        if not denoms:
            # A pure product: the product-rule result is the whole answer.
            return num_d

        denom = AST._make_product(denoms)
        denom_d = _product_deriv(denoms, wrt)

        # Quotient rule: (n/d)' = n'/d - n*d'/d**2
        term1 = Div((num_d, denom))
        term2 = Div((Mul((UnarySub(num), denom_d)), Power((denom, Const(2)))))
        return Add((term1, term2))

    elif isinstance(ast, Power):
        # Rewrite a**b as pow(a, b) so the pow template covers dependence on
        # both the base and the exponent.
        return _diff_ast(CallFunc(Name('pow'), [ast.left, ast.right]), wrt)

    elif isinstance(ast, CallFunc):
        func_name = AST.ast2str(ast.node)
        args = ast.args
        nargs = len(args)
        args_d = [_diff_ast(arg, wrt) for arg in args]

        if (func_name, nargs) in _KNOWN_FUNCS:
            # Deep copy: the substitution below rewrites the templates in
            # place, and the shared table must stay pristine.
            form = copy.deepcopy(_KNOWN_FUNCS[(func_name, nargs)])
        else:
            # Unknown function f: represent its partial derivative with respect
            # to argument ii symbolically as f_ii(arg0, arg1, ...). Each
            # template gets its own placeholder nodes so that substituting
            # into one template cannot leak into another.
            form = [CallFunc(Name('%s_%i' % (func_name, ii)),
                             [Name('arg%i' % jj) for jj in range(nargs)])
                    for ii in range(nargs)]

        # Map every placeholder to its argument and substitute them in a
        # single pass. Substituting one placeholder at a time would re-match
        # placeholder names occurring inside arguments already inserted
        # (e.g. a user variable called 'arg1').
        placeholders = dict(('arg%i' % ii, arg) for ii, arg in enumerate(args))

        # Chain rule: sum over arguments of (partial wrt arg) * (arg)'.
        outs = []
        for arg_d, arg_form_d in zip(args_d, form):
            if arg_d == _ZERO:
                # This argument doesn't depend on wrt; its term vanishes.
                continue
            Substitution._sub_subtrees_for_vars(arg_form_d, placeholders)
            outs.append(Mul((arg_form_d, arg_d)))

        if not outs:
            return _ZERO
        ret = outs[0]
        for term in outs[1:]:
            ret = Add((ret, term))
        return ret

    elif isinstance(ast, UnarySub):
        return UnarySub(_diff_ast(ast.expr, wrt))

    elif isinstance(ast, UnaryAdd):
        return UnaryAdd(_diff_ast(ast.expr, wrt))

    # NOTE(review): other node types (e.g. Compare, Subscript, Getattr) are not
    # handled and yield None; downstream code presumably cannot make sense of
    # that -- confirm whether raising would be preferable for all callers.
    return None
def _product_deriv(terms, wrt):
    """
    Return an AST expressing the derivative of the product of all the terms.

    Applies the product rule: for factors t_0 ... t_{n-1} the result is the
    sum over i of (product of the other factors) * t_i'. A single factor is
    differentiated directly. An empty list (an empty product, i.e. the
    constant 1) yields _ZERO instead of raising IndexError.

    terms must be a list (it is sliced and concatenated with a list).
    """
    if not terms:
        return _ZERO
    if len(terms) == 1:
        return _diff_ast(terms[0], wrt)

    deriv_terms = []
    for ii, term in enumerate(terms):
        term_d = _diff_ast(term, wrt)
        others = terms[:ii] + terms[ii+1:]
        deriv_terms.append(AST._make_product(others + [term_d]))

    # Fold into a nested sum with each later term wrapped outermost.
    total = deriv_terms[0]
    for term in deriv_terms[1:]:
        total = Add((term, total))
    return total
195