Package SloppyCell :: Package ExprManip :: Module Differentiation
[hide private]

Source Code for Module SloppyCell.ExprManip.Differentiation

  1  from compiler.ast import * 
  2  import copy 
  3  import cPickle 
  4  import logging 
  5  import os 
  6  logger = logging.getLogger('ExprManip.Differentiation') 
  7   
  8  import AST 
  9  from AST import strip_parse 
 10  import Simplify 
 11  import Substitution 
 12   
 # Shared constant nodes for the literals 0 and 1. _diff_ast returns these
 # same objects as the derivative of a constant or of a variable name, and
 # later tests argument derivatives against _ZERO to skip vanishing terms.
 13  _ZERO = Const(0) 
 14  _ONE = Const(1) 
 15   
 # Cache of derivatives already computed by diff_expr. Keys are strings of
 # the form '<expr>__derivWRT__<wrt>'; values are the derivative results
 # stored by diff_expr. load_derivs replaces it and save_derivs pickles it.
 16  __deriv_saved = {} 
def load_derivs(filename):
    """
    Load up a pickled dictionary of saved derivatives.

    On success the module-level derivative cache is replaced by the object
    unpickled from the file. A file that cannot be opened is silently
    ignored, since that usually just means it has not been written yet. A
    file that opens but cannot be unpickled is reported through the module
    logger, the cache is left untouched, and a best-effort attempt is made to
    delete the file so it does not cause the same failure again.

    filename: Path of the pickle file to read.

    Always returns None.
    """
    global __deriv_saved
    try:
        f = open(filename, 'rb')
    except IOError:
        # This failure probably indicates that the file doesn't exist.
        return

    # NOTE(review): unpickling can execute arbitrary code. Only point this at
    #  cache files that come from a trusted source.
    try:
        try:
            loaded = cPickle.load(f)
        finally:
            # Close the handle on every path. It used to stay open after a
            #  successful load, and it must be closed before the file can be
            #  removed on some platforms.
            f.close()
    except Exception:
        # For some reason, pulling the data from the file failed. Catching
        #  Exception (rather than everything) keeps Ctrl-C or interpreter
        #  exit from being mistaken for a corrupt file.
        logger.warning('Failed to load saved derivative file %s. Trying to '
                       'delete it to avoid future problems.', filename)
        try:
            # Let's try to remove it...
            os.remove(filename)
        except OSError:
            logger.warning('File removal failed. Please delete %s manually.',
                           filename)
        return

    __deriv_saved = loaded
40
def save_derivs(filename):
    """
    Pickle the dictionary of saved derivatives to a file.

    The module-level derivative cache is written with pickle protocol 2. Any
    existing file at that path is overwritten. Errors from opening or writing
    the file propagate to the caller.

    filename: Path of the pickle file to write.
    """
    f = open(filename, 'wb')
    try:
        cPickle.dump(__deriv_saved, f, 2)
    finally:
        # Release the handle even if pickling fails part way through.
        f.close()
45
def diff_expr(expr, wrt):
    """
    Return the derivative of the expression with respect to a given variable.

    Results are memoized in the module-level derivative cache under the key
    '<expr>__derivWRT__<wrt>', so a repeated request returns the stored
    result without re-parsing the expression.

    expr: The expression to differentiate. It is handed to AST.strip_parse
          and its '%s' form is used in the cache key.
    wrt:  Name of the variable to differentiate with respect to.

    Returns the simplified derivative as produced by AST.ast2str.
    """
    logger.debug('Taking derivative of %s wrt %s', expr, wrt)
    key = '%s__derivWRT__%s' % (expr, wrt)
    # 'in' rather than dict.has_key, which is deprecated and no longer
    #  exists in Python 3.
    if key in __deriv_saved:
        deriv = __deriv_saved[key]
        logger.debug('Found saved result %s.', deriv)
        return deriv

    ast = AST.strip_parse(expr)
    deriv = _diff_ast(ast, wrt)
    deriv = Simplify._simplify_ast(deriv)
    deriv = AST.ast2str(deriv)
    __deriv_saved[key] = deriv
    logger.debug('Computed result %s.', deriv)
    return deriv
# This dictionary stores how to differentiate various functions. The keys are
#  (function name, # of arguments). The values are tuples of strings which
#  give the partial derivative wrt each argument. It is important to use arg#
#  to denote the arguments.
# Note that d/dx acos(x) is the negative of d/dx asin(x).
_KNOWN_FUNCS = {('acos', 1): ('-1/sqrt(1-arg0**2)',),
                ('asin', 1): ('1/sqrt(1-arg0**2)',),
                ('atan', 1): ('1/(1+arg0**2)',),
                ('cos', 1): ('-sin(arg0)',),
                ('cosh', 1): ('sinh(arg0)',),
                ('exp', 1): ('exp(arg0)',),
                ('log', 1): ('1/arg0',),
                ('log10', 1): ('1/(log(10)*arg0)',),
                ('sin', 1): ('cos(arg0)',),
                ('sinh', 1): ('cosh(arg0)',),
                ('arcsinh', 1): ('1/sqrt(1+arg0**2)',),
                ('arccosh', 1): ('1/sqrt(arg0**2 - 1.)',),
                ('arctanh', 1): ('1/(1.-arg0**2)',),
                ('sqrt', 1): ('1/(2*sqrt(arg0))',),
                ('tan', 1): ('1/cos(arg0)**2',),
                ('tanh', 1): ('1/cosh(arg0)**2',),
                ('pow', 2): ('arg1 * arg0**(arg1-1)',
                             'log(arg0) * arg0**arg1')
                }
# Replace each derivative string by its parsed AST, once, at import time.
#  Iterate over a snapshot of the items since the values are rewritten while
#  looping.
for key, terms in list(_KNOWN_FUNCS.items()):
    _KNOWN_FUNCS[key] = [strip_parse(term) for term in terms]
def _diff_ast(ast, wrt):
    """
    Return an AST that is the derivative of ast with respect the variable with
    name 'wrt'.

    The result is deliberately unsimplified; callers are expected to run it
    through the simplifier.

    ast: The expression tree to differentiate. Name, Const, Add, Sub, Mul,
         Div, Power, CallFunc, UnarySub and UnaryAdd nodes are handled.
    wrt: Name of the variable to differentiate with respect to.

    Functions listed in _KNOWN_FUNCS use the stored partial derivatives. For
    any other function f, the partial derivative with respect to argument i
    is written as a call to a function named 'f_i' with the same arguments.

    Returns None for any node type not listed above.
    """

    # For now, the strategy is to return the most general forms, and let
    #  the simplifier take care of the special cases.
    if isinstance(ast, Name):
        if ast.name == wrt:
            return _ONE
        else:
            return _ZERO
    elif isinstance(ast, Const):
        return _ZERO
    elif isinstance(ast, (Add, Sub)):
        # Just take the derivative of the arguments. The call to ast.__class__
        #  lets us use the same code from Add and Sub.
        return ast.__class__((_diff_ast(ast.left, wrt),
                              _diff_ast(ast.right, wrt)))
    elif isinstance(ast, (Mul, Div)):
        # Collect all the numerators and denominators together
        nums, denoms = [], []
        AST._collect_num_denom(ast, nums, denoms)

        # Collect the numerator terms into a single AST
        num = AST._make_product(nums)
        # Take the derivative of the numerator terms as a product
        num_d = _product_deriv(nums, wrt)
        if not denoms:
            # If there is no denominator
            return num_d

        denom = AST._make_product(denoms)
        denom_d = _product_deriv(denoms, wrt)

        # Derivative of x/y is x'/y + -x*y'/y**2
        term1 = Div((num_d, denom))
        term2 = Div((Mul((UnarySub(num), denom_d)), Power((denom, Const(2)))))
        return Add((term1, term2))

    elif isinstance(ast, Power):
        # Use the derivative of the 'pow' function
        return _diff_ast(CallFunc(Name('pow'), [ast.left, ast.right]), wrt)

    elif isinstance(ast, CallFunc):
        func_name = AST.ast2str(ast.node)
        args = ast.args
        args_d = [_diff_ast(arg, wrt) for arg in args]

        func_key = (func_name, len(args))
        # 'in' rather than dict.has_key, which is deprecated and no longer
        #  exists in Python 3.
        if func_key in _KNOWN_FUNCS:
            # Work on a copy: the substitution below is applied to the form
            #  itself, and the stored templates must stay pristine.
            form = copy.deepcopy(_KNOWN_FUNCS[func_key])
        else:
            # If this isn't a known function, our form is
            #  (f_0(args), f_1(args), ...)
            args_expr = [Name('arg%i' % ii) for ii in range(len(args))]
            form = [CallFunc(Name('%s_%i' % (func_name, ii)), args_expr) for
                    ii in range(len(args))]

        # We build up the terms in our derivative
        #  f_0(x,y)*x' + f_1(x,y)*y', etc.
        outs = []
        for arg_d, arg_form_d in zip(args_d, form):
            # We skip arguments with 0 derivative
            if arg_d == _ZERO:
                continue
            for ii, arg in enumerate(args):
                # NOTE(review): the return value is discarded, so this relies
                #  on the helper rewriting arg_form_d in place -- confirm
                #  against Substitution._sub_subtrees_for_vars.
                Substitution._sub_subtrees_for_vars(arg_form_d,
                                                    {'arg%i' % ii: arg})
            outs.append(Mul((arg_form_d, arg_d)))

        # If all arguments had zero derivative
        if not outs:
            return _ZERO
        else:
            # We add up all our terms
            ret = outs[0]
            for term in outs[1:]:
                ret = Add((ret, term))
            return ret

    elif isinstance(ast, UnarySub):
        return UnarySub(_diff_ast(ast.expr, wrt))

    elif isinstance(ast, UnaryAdd):
        return UnaryAdd(_diff_ast(ast.expr, wrt))

    # Any other node type is not handled. This used to fall off the end of
    #  the function; the None result is kept but is now explicit.
    return None
178
def _product_deriv(terms, wrt):
    """
    Return an AST expressing the derivative of the product of all the terms.

    Applies the product rule: the result is the sum, over each term, of that
    term's derivative multiplied by all the other terms.

    terms: List of ASTs that are multiplied together.
    wrt:   Name of the variable to differentiate with respect to.

    An empty list is treated as the empty product, whose derivative is
    _ZERO. A single term is differentiated directly.
    """
    if not terms:
        # Previously this case raised IndexError further down.
        return _ZERO
    if len(terms) == 1:
        return _diff_ast(terms[0], wrt)

    deriv_terms = []
    for ii, term in enumerate(terms):
        term_d = _diff_ast(term, wrt)
        other_terms = terms[:ii] + terms[ii+1:]
        deriv_terms.append(AST._make_product(other_terms + [term_d]))

    # Accumulate with each new term on the left, which keeps the shape of the
    #  resulting tree unchanged. The accumulator is not called 'sum' so the
    #  builtin is not shadowed.
    total = deriv_terms[0]
    for term in deriv_terms[1:]:
        total = Add((term, total))

    return total
195