# Symbolic differentiation of expressions represented as Python 2
# compiler.ast trees (the `compiler` package does not exist in Python 3).
from compiler.ast import *
import copy
import cPickle
import logging
import os
logger = logging.getLogger('ExprManip.Differentiation')

import AST
from AST import strip_parse
import Simplify
import Substitution

# Shared leaf results: _diff_ast returns these exact objects for variables
# and constants, and compares argument derivatives against _ZERO to drop
# chain-rule terms that vanish.
_ZERO, _ONE = Const(0), Const(1)

# Memo of derivatives already computed, mapping
# '<expr>__derivWRT__<wrt>' -> derivative string. The load/save helpers in
# this module replace it from, and write it to, a pickle file.
__deriv_saved = {}

 18      """ 
 19      Load up a pickled dictionary of saved derivatives. 
 20      """ 
 21      global __deriv_saved 
 22      try: 
 23          f = file(filename, 'rb') 
 24      except IOError: 
 25           
 26          return 
 27      try: 
 28          __deriv_saved = cPickle.load(f) 
 29          return 
 30      except: 
 31           
 32          logger.warn('Failed to load saved derivative file %s. Trying to delete it to ' 
 33                      'avoid future problems.' % filename) 
 34          try: 
 35               
 36              f.close() 
 37              os.remove(filename) 
 38          except: 
 39              logger.warn('File removal failed. Please delete %s manually.' % filename) 
  40   
 42      f = file(filename, 'wb') 
 43      cPickle.dump(__deriv_saved, f, 2) 
 44      f.close() 
  45   
 47      """ 
 48      Return the derivative of the expression with respect to a given variable. 
 49      """ 
 50      logger.debug('Taking derivative of %s wrt %s' % (expr, wrt)) 
 51      key = '%s__derivWRT__%s' % (expr, wrt) 
 52      if __deriv_saved.has_key(key): 
 53          deriv = __deriv_saved[key] 
 54          logger.debug('Found saved result %s.' % deriv) 
 55          return deriv 
 56   
 57      ast = AST.strip_parse(expr) 
 58      deriv = _diff_ast(ast, wrt) 
 59      deriv = Simplify._simplify_ast(deriv) 
 60      deriv  = AST.ast2str(deriv) 
 61      __deriv_saved[key] = deriv 
 62      logger.debug('Computed result %s.' % deriv) 
 63      return deriv 
  64   
 65   
 66   
 67   
 68   
 69  _KNOWN_FUNCS = {('acos', 1): ('1/sqrt(1-arg0**2)',), 
 70                  ('asin', 1): ('1/sqrt(1-arg0**2)',), 
 71                  ('atan', 1): ('1/(1+arg0**2)',), 
 72                  ('cos', 1): ('-sin(arg0)',), 
 73                  ('cosh', 1): ('sinh(arg0)',), 
 74                  ('exp', 1): ('exp(arg0)',), 
 75                  ('log', 1): ('1/arg0',), 
 76                  ('log10', 1): ('1/(log(10)*arg0)',), 
 77                  ('sin', 1): ('cos(arg0)',), 
 78                  ('sinh', 1): ('cosh(arg0)',), 
 79                  ('arcsinh', 1): ('1/sqrt(1+arg0**2)',), 
 80                  ('arccosh', 1): ('1/sqrt(arg0**2 - 1.)',), 
 81                  ('arctanh', 1): ('1/(1.-arg0**2)',), 
 82                  ('sqrt', 1): ('1/(2*sqrt(arg0))',), 
 83                  ('tan', 1): ('1/cos(arg0)**2',), 
 84                  ('tanh', 1): ('1/cosh(arg0)**2',), 
 85                  ('pow', 2): ('arg1 * arg0**(arg1-1)',  
 86                               'log(arg0) * arg0**arg1') 
 87                  } 
 88  for key, terms in _KNOWN_FUNCS.items(): 
 89      _KNOWN_FUNCS[key] = [strip_parse(term) for term in terms] 
 90   
 92      """ 
 93      Return an AST that is the derivative of ast with respect the variable with 
 94      name 'wrt'. 
 95      """ 
 96   
 97       
 98       
 99      if isinstance(ast, Name): 
100          if ast.name == wrt: 
101              return _ONE 
102          else: 
103              return _ZERO 
104      elif isinstance(ast, Const): 
105          return _ZERO 
106      elif isinstance(ast, Add) or isinstance(ast, Sub): 
107           
108           
109          return ast.__class__((_diff_ast(ast.left, wrt),  
110                                _diff_ast(ast.right, wrt))) 
111      elif isinstance(ast, Mul) or isinstance(ast, Div): 
112           
113          nums, denoms = [], [] 
114          AST._collect_num_denom(ast, nums, denoms) 
115   
116           
117          num = AST._make_product(nums) 
118           
119          num_d = _product_deriv(nums, wrt) 
120          if not denoms: 
121               
122              return num_d 
123   
124          denom = AST._make_product(denoms) 
125          denom_d = _product_deriv(denoms, wrt) 
126   
127           
128          term1 = Div((num_d, denom)) 
129          term2 = Div((Mul((UnarySub(num), denom_d)), Power((denom, Const(2))))) 
130          return Add((term1, term2)) 
131   
132      elif isinstance(ast, Power): 
133           
134          ast =  CallFunc(Name('pow'), [ast.left, ast.right]) 
135          return _diff_ast(ast, wrt) 
136   
137      elif isinstance(ast, CallFunc): 
138          func_name = AST.ast2str(ast.node) 
139          args = ast.args 
140          args_d = [_diff_ast(arg, wrt) for arg in args] 
141   
142          if _KNOWN_FUNCS.has_key((func_name, len(args))): 
143              form = copy.deepcopy(_KNOWN_FUNCS[(func_name, len(args))]) 
144          else: 
145               
146               
147              args_expr = [Name('arg%i' % ii) for ii in range(len(args))] 
148              form = [CallFunc(Name('%s_%i' % (func_name, ii)), args_expr) for 
149                      ii in range(len(args))] 
150   
151           
152           
153          outs = [] 
154          for arg_d, arg_form_d in zip(args_d, form): 
155               
156              if arg_d == _ZERO: 
157                  continue 
158              for ii, arg in enumerate(args): 
159                  Substitution._sub_subtrees_for_vars(arg_form_d,  
160                                                      {'arg%i'%ii:arg}) 
161              outs.append(Mul((arg_form_d, arg_d))) 
162   
163           
164          if not outs: 
165              return _ZERO 
166          else: 
167               
168              ret = outs[0] 
169              for term in outs[1:]: 
170                  ret = Add((ret, term)) 
171              return ret 
172   
173      elif isinstance(ast, UnarySub): 
174          return UnarySub(_diff_ast(ast.expr, wrt)) 
175   
176      elif isinstance(ast, UnaryAdd): 
177          return UnaryAdd(_diff_ast(ast.expr, wrt)) 
 178   
def _product_deriv(terms, wrt):
    """
    Return an AST expressing the derivative of the product of all the terms.

    Applies the product rule: d(t0*t1*...*tn) = sum_i (prod_{j != i} tj) * ti'.

    terms -- list of factor ASTs (as collected by AST._collect_num_denom).
    wrt   -- name of the variable to differentiate with respect to.

    An empty list has no factor that depends on wrt, so it is treated as a
    constant and _ZERO is returned.
    """
    if not terms:
        return _ZERO
    if len(terms) == 1:
        return _diff_ast(terms[0], wrt)

    deriv_terms = []
    for ii, term in enumerate(terms):
        term_d = _diff_ast(term, wrt)
        other_terms = terms[:ii] + terms[ii+1:]
        deriv_terms.append(AST._make_product(other_terms + [term_d]))

    # Fold into a sum. Each new term is placed on the left of the running total.
    total = deriv_terms[0]
    for term in deriv_terms[1:]:
        total = Add((term, total))

    return total