Tracing Functions in Python

19/11/2015

Let's say we have some function in Python on which we want to do some kind of symbolic manipulation. For example, we might have a mathematical function whose derivative we want to calculate automatically.

def f(x):
    return 5*x + 10

Once the function is defined in Python, it might seem impossible to do anything with its symbolic representation without resorting to something difficult, such as manipulating the bytecode or writing some kind of preprocessor.

But there are a couple of cute tricks we can use to extract a symbolic representation of the function. They have some limitations, but they can sometimes be used to get either an abstract syntax tree of the function or what is essentially a trace of the function for a particular input, all after the function has been defined and used. This abstract syntax tree or trace is useful for many purposes: it can be examined, manipulated, and easily converted back into the original function if required.

The trick is this: we call the function we wish to construct the AST for with a proxy object which, as it passes through the function, records the operations performed on it.

To see how it works, let us create a class called Proxy which extends tuple. We're going to define the Python special methods __add__, __mul__, etc., which allow us to overload the addition and multiplication operators for objects of this type. These overloaded methods construct new Proxy objects containing the operation and each operand. The reflected versions, __radd__ and __rmul__, handle expressions such as 5*x where the Proxy is the right-hand operand.

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy(('+', self, rhs))
    def __radd__(self, lhs): return Proxy(('+', lhs, self))
    def __mul__(self, rhs):  return Proxy(('*', self, rhs))
    def __rmul__(self, lhs): return Proxy(('*', lhs, self))

Let's see what happens when we pass a value of this type to our function f as an argument.

>>> f(Proxy('x'))
('+', ('*', 5, ('x',)), 10)

Nice: it returned a tuple which looks a whole lot like the abstract syntax tree of the function (in S-expression notation). We can do the same for multi-argument functions.

>>> def f2(x, y):
...     return 3*x*x + 4*y + 2
...
>>> f2(Proxy('x'), Proxy('y'))
('+', ('+', ('*', ('*', 3, ('x',)), ('x',)), ('*', 4, ('y',))), 2)
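Since the motivating example was differentiation, here is a minimal sketch of manipulating one of these tuples symbolically. It assumes the string-tagged Proxy above, numeric constants, and single-character input names (Proxy('x') becomes the one-element tuple ('x',)). The derivative and evaluate functions are hypothetical helpers written for illustration, not part of any library. The result is not simplified, so we check it by evaluating it at a point.

```python
class Proxy(tuple):
    def __add__(self, rhs):  return Proxy(('+', self, rhs))
    def __radd__(self, lhs): return Proxy(('+', lhs, self))
    def __mul__(self, rhs):  return Proxy(('*', self, rhs))
    def __rmul__(self, lhs): return Proxy(('*', lhs, self))

def derivative(expr, var):
    """Differentiate a ('+' | '*', lhs, rhs) tree with respect to the input named var."""
    if not isinstance(expr, tuple):
        return 0                                   # a numeric constant
    if len(expr) == 1:
        return 1 if expr[0] == var else 0          # an input such as ('x',)
    op, lhs, rhs = expr
    if op == '+':
        return ('+', derivative(lhs, var), derivative(rhs, var))
    # Product rule: (u*v)' = u'*v + u*v'
    return ('+', ('*', derivative(lhs, var), rhs),
                 ('*', lhs, derivative(rhs, var)))

def evaluate(expr, env):
    """Evaluate a tree, looking inputs up by name in env."""
    if not isinstance(expr, tuple):
        return expr
    if len(expr) == 1:
        return env[expr[0]]
    op, lhs, rhs = expr
    a, b = evaluate(lhs, env), evaluate(rhs, env)
    return a + b if op == '+' else a * b

def f2(x, y):
    return 3*x*x + 4*y + 2

d_f2_dx = derivative(f2(Proxy('x'), Proxy('y')), 'x')
print(evaluate(d_f2_dx, {'x': 2, 'y': 1}))   # 12, since d/dx of 3x^2 is 6x
```

Because the derivative comes back as another tree, it could in turn be simplified, printed, or converted back into a function just like the original.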

This is basically the main concept behind the technique. These outputs use strings to represent the operations, but it would be better to get rid of them, since they can be confused with string values used by the function itself. Instead, we can tag each node with the function that performs the operation. We should also add a new type to represent inputs, so that they can't be confused with ordinary values.

import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs):  return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))

class Input(Proxy):
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return 'Input('+repr(self.name)+')'

Now it works for string operations too, without the operator tags getting mixed up with the string values.

>>> def greet_people(x, y):
...     return 'Hello ' + x + ' & ' + y + '!'
...
>>> greet_people('Bob', 'Dave')
'Hello Bob & Dave!'
>>> greet_people(Input('x'), Input('y'))
(<built-in function add>,
 (<built-in function add>,
  (<built-in function add>,
   (<built-in function add>, 'Hello ', Input('x')),
   ' & '),
  Input('y')),
 '!')
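That output is hard to read, so as an illustration of examining the tree, here is a small sketch that prints it back in infix form. The SYMBOLS table and to_infix are hypothetical helpers, and the sketch assumes every node is a binary add or mul, as in greet_people.

```python
import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs):  return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))

class Input(Proxy):
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'

# Hypothetical table mapping recorded functions to infix symbols.
SYMBOLS = {operator.add: '+', operator.mul: '*'}

def to_infix(ast):
    if isinstance(ast, Input):       # check Input first: it is also a Proxy
        return ast.name
    if not isinstance(ast, Proxy):   # a constant such as 'Hello ' or 3
        return repr(ast)
    op, lhs, rhs = ast
    return '(' + to_infix(lhs) + ' ' + SYMBOLS[op] + ' ' + to_infix(rhs) + ')'

def greet_people(x, y):
    return 'Hello ' + x + ' & ' + y + '!'

print(to_infix(greet_people(Input('x'), Input('y'))))
# (((('Hello ' + x) + ' & ') + y) + '!')
```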

We can convert from this symbolic representation back into the function representation too. It can be done with a function that looks like the following.

def ast_to_function(ast):
    if not isinstance(ast, Proxy): return lambda **kw: ast
    elif   isinstance(ast, Input): return lambda **kw: kw[ast.name]
    return lambda **kw: ast[0](*[ast_to_function(node)(**kw) for node in ast[1:]])

This looks a little complicated, so let me explain what is going on.

First, this function checks whether the input object is something other than a Proxy object (in which case it is probably something like a number or a string). If it isn't a Proxy, it returns a function that just returns that object.

Secondly, it checks whether the argument is an Input object. If it is, it returns a function which looks up the name of that input in the keyword arguments and returns the corresponding value.

Finally, if the argument is neither of these, it must be an ordinary Proxy node. In this case it returns a function which recursively converts each subexpression (every element after the first) into a function, evaluates it with the same keyword arguments, and then applies the first element of the node (which should be a function such as <built-in function add>) to the results.

Now we can convert from our AST representation back to a function representation. The only difference is that we have to give the inputs as keyword arguments with the names we gave to our Input objects. We could do positional arguments instead but we'll stick with this setup for simplicity.

>>> greet_people_ast = greet_people(Input('x'), Input('y'))
>>> greet_people_fun = ast_to_function(greet_people_ast)
>>> greet_people_fun(x='Bob', y='Dave')
'Hello Bob & Dave!'
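If positional arguments are preferred, one option is a thin wrapper that maps them onto the Input names in order. This is only a sketch: ast_to_positional_function and its names parameter are hypothetical, and zip silently ignores any extra arguments.

```python
import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))

class Input(Proxy):
    def __init__(self, name):
        self.name = name

def ast_to_function(ast):
    if not isinstance(ast, Proxy): return lambda **kw: ast
    elif   isinstance(ast, Input): return lambda **kw: kw[ast.name]
    return lambda **kw: ast[0](*[ast_to_function(node)(**kw) for node in ast[1:]])

def ast_to_positional_function(ast, names):
    """Hypothetical wrapper: bind positional arguments to the Input names, in order."""
    keyword_fun = ast_to_function(ast)
    return lambda *args: keyword_fun(**dict(zip(names, args)))

def greet_people(x, y):
    return 'Hello ' + x + ' & ' + y + '!'

greet_people_pos = ast_to_positional_function(
    greet_people(Input('x'), Input('y')), ['x', 'y'])
print(greet_people_pos('Bob', 'Dave'))   # Hello Bob & Dave!
```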

What about functions which aren't set up to take Proxy types as input (such as the built-in function pow)? We are going to have to modify these functions in some way. Luckily, in Python we can easily monkey patch them to work with our setup. Here is a function we can use to wrap them.

def make_ast_traceable(f):
    def g(*args):
        if any([isinstance(a, Proxy) for a in args]):
            return Proxy((f,) + tuple(args))
        else:
            return f(*args)
    return g

This function takes an input function and returns a similar function which first checks whether any of the arguments are Proxy objects. If so, it returns an AST node recording the call; otherwise it calls the unmodified function as normal.

All we have to do now is use this function to replace pow and it can appear in our ASTs as normal.

>>> pow = make_ast_traceable(pow)
>>> def f3(x):
...     return pow(3*x*x, 4)
...
>>> f3(Input('x'))
(<built-in function pow>,
 (<built-in function mul>,
  (<built-in function mul>, 3, Input('x')),
  Input('x')),
 4)

We'll also have to do this monkey patch treatment for any other functions we wish to appear in our AST. For example we might want to do this to the function we're actually trying to construct the AST for, so that it appears as a reference to itself upon recursion.
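Here is a sketch of the recursive case, assuming the tuple-based Proxy and Input and the make_ast_traceable function above. The power function and the power_body name are hypothetical. The trick is that the recursive call looks up the global name power at call time, so once that name points at the wrapped version, the recursion shows up as a single node instead of being unrolled.

```python
import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs):  return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))

class Input(Proxy):
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'

def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy((f,) + tuple(args))
        return f(*args)
    return g

def power(x, n):
    # n is an ordinary int, so this test is evaluated normally.
    return 1 if n == 0 else x * power(x, n - 1)

power_body = power                   # keep a reference to the undecorated function
power = make_ast_traceable(power)    # the recursive call above now finds this wrapper

ast = power_body(Input('x'), 3)
# ast is (operator.mul, Input('x'), (power_body, Input('x'), 2)):
# one level of the body, with the recursive call recorded as a node.
```

Called with plain numbers, the wrapped power still computes normally, so power(2, 3) returns 8.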

This is basically it for this technique, and it really can be used to extract ASTs for simple expressions. But there are a couple of serious limitations.

Unfortunately, it only works with straight-line expressions: it doesn't work with things like loops or conditionals. Consider a function like the following:

def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)

When we pass this function our Proxy object and it hits the if statement, Python needs to know which way to branch, but this isn't what we want. We don't want to branch at all; what we really want is to extract both branches, but there is no way for us to hook into Python and tell it this. We are forced to pick one branch or the other.
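With the Proxy class above, which defines no comparisons, Python 3 simply raises a TypeError at x > 5. Here is a sketch of what goes wrong even if we record comparisons (the __gt__ method and f4_without_pow are hypothetical additions): the if statement calls bool() on the recorded node, and since a non-empty tuple is always true, the trace silently follows the first branch.

```python
import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __mul__(self, rhs):  return Proxy((operator.mul, self, rhs))
    # Hypothetical addition: record '>' as a node as well.
    def __gt__(self, rhs):   return Proxy((operator.gt, self, rhs))

class Input(Proxy):
    def __init__(self, name):
        self.name = name

def f4_without_pow(x):
    # Same branching shape as f4; pow is dropped to keep the sketch short.
    if x > 5:          # bool() of the non-empty Proxy tuple is always True
        return x * 10
    else:
        return x + 4

ast = f4_without_pow(Input('x'))
print(ast[0] is operator.mul)   # True: only the first branch was recorded
```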

Similar problems exist for looping. What we really want to extract is the fact that a loop is taking place in the code - not to actually evaluate that loop.
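A small sketch of the loop case, again assuming the tuple-based classes (triple is a hypothetical example function). Python runs the loop itself, so all the tree records is the unrolled result: three nested additions and no node for the loop itself. If the loop bound depended on x, as in range(x), range would reject the Proxy with a TypeError.

```python
import operator

class Proxy(tuple):
    def __add__(self, rhs):  return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))

class Input(Proxy):
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'

def triple(x):
    total = 0
    for _ in range(3):      # the bound is a plain int, so Python just runs the loop
        total = total + x
    return total

ast = triple(Input('x'))
# (add, (add, (add, 0, Input('x')), Input('x')), Input('x'))
```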

Is there anything we can do about this? Not really. There is no way to extract the AST in this case, but we can do something else instead: we can extract a trace of the function for a particular input.

To do this, we pair our Proxy object with some input value. Whenever we encounter an operation, we perform it on our value and record it in the Proxy. Then, in situations like this where Python needs an actual value to proceed, we can hand over the value we've been recording and continue the trace.

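So we'll need to edit our Proxy object to let it store a value. This version no longer extends tuple; instead it keeps the value and the trace as two separate attributes. We'll also define the comparison operators and the __bool__ method, which Python calls when an if statement needs to decide which branch to take. Here is the updated Proxy class.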

import operator

class Proxy:

    def __init__(self, value, nodes):
        self.value = value   # the concrete value computed so far
        self.nodes = nodes   # the trace: (operator, lhs, rhs)

    @classmethod
    def op(cls, lhs, rhs, op):
        # Apply op to the underlying values and record the operation in the trace.
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @classmethod
    def value(cls, obj):
        # Unwrap a Proxy to its value; any other object is used as it is.
        # Instances store their own `value` attribute, which shadows this
        # classmethod on the instance, so always call it as Proxy.value(obj).
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs):     return Proxy.op(self, rhs, operator.add)
    def __sub__(self, rhs):     return Proxy.op(self, rhs, operator.sub)
    def __mul__(self, rhs):     return Proxy.op(self, rhs, operator.mul)
    def __truediv__(self, rhs): return Proxy.op(self, rhs, operator.truediv)

    def __radd__(self, lhs):     return Proxy.op(lhs, self, operator.add)
    def __rsub__(self, lhs):     return Proxy.op(lhs, self, operator.sub)
    def __rmul__(self, lhs):     return Proxy.op(lhs, self, operator.mul)
    def __rtruediv__(self, lhs): return Proxy.op(lhs, self, operator.truediv)

    # Python has no reflected comparisons such as __rlt__: for `5 < p` it
    # calls p.__gt__(5) instead, so these four cover both operand orders.
    def __lt__(self, rhs): return Proxy.op(self, rhs, operator.lt)
    def __gt__(self, rhs): return Proxy.op(self, rhs, operator.gt)
    def __le__(self, rhs): return Proxy.op(self, rhs, operator.le)
    def __ge__(self, rhs): return Proxy.op(self, rhs, operator.ge)

    # Called by `if`, `while`, etc.; the recorded value decides the branch.
    def __bool__(self): return bool(self.value)

    def __repr__(self):
        return 'Proxy(value='+repr(self.value)+', trace='+repr(self.nodes)+')'

We had better update our Input class to store a value too.

class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)  # an input has a value but no trace of its own
        self.name = name
    def __repr__(self):
        return 'Input(name='+repr(self.name)+', value='+repr(self.value)+')'

Now there are two key functions in the new version of Proxy. The first is Proxy.value, which extracts the value from a Proxy object or, when passed any other object, returns that object unchanged. The second is Proxy.op, which applies an operator to the values of its operands while recording the operation in the nodes field.

We need to update our make_ast_traceable function too. This function is a bit like Proxy.op. If passed any Proxy objects it extracts their values and performs the given function on them, while simultaneously recording the trace in the nodes field.

def make_ast_traceable(f):
    def g(*args):
        if any([isinstance(a, Proxy) for a in args]):
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args)) 
        else:
            return f(*args)
    return g

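And let us replace pow again. This time we wrap the original built-in, taken from the builtins module, rather than the pow we wrapped earlier, so that the old tuple-building wrapper doesn't end up nested inside the new one.

import builtins
pow = make_ast_traceable(builtins.pow)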

Now we can extract different traces of the function for different input values.

>>> f4(Input('x', 3))
Proxy(value=343, trace=
  (<built-in function pow>, Proxy(value=7, trace=
   (<built-in function add>, Input(name='x', value=3), 4)), 3))
>>> f4(Input('x', 11))
Proxy(value=12100, trace=
  (<built-in function pow>, Proxy(value=110, trace=
   (<built-in function mul>, Input(name='x', value=11), 10)), 2))

As a side effect, we also get all the intermediate values at the different stages of evaluation, which is kind of cool. With a few more simple changes, this method can be extended to work with loops and all sorts of other structures too.
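As an illustration, here is a sketch that walks a trace and lists those intermediate values. It repeats the value-carrying Proxy from above, trimmed to the operators f4 needs, and uses an Input that calls Proxy.__init__ so that it has an empty nodes field; intermediate_values is a hypothetical helper.

```python
import builtins
import operator

class Proxy:
    def __init__(self, value, nodes):
        self.value = value
        self.nodes = nodes

    @classmethod
    def op(cls, lhs, rhs, op):
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @classmethod
    def value(cls, obj):
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs):  return Proxy.op(self, rhs, operator.add)
    def __mul__(self, rhs):  return Proxy.op(self, rhs, operator.mul)
    def __gt__(self, rhs):   return Proxy.op(self, rhs, operator.gt)
    def __bool__(self):      return bool(self.value)

class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)   # an input carries a value but no trace
        self.name = name

def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args))
        return f(*args)
    return g

pow = make_ast_traceable(builtins.pow)

def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)

def intermediate_values(p):
    """Hypothetical helper: (function name, value) for each traced step, innermost first."""
    steps = []
    def walk(node):
        if not isinstance(node, Proxy) or node.nodes is None:
            return                      # constants and inputs are not steps
        for arg in node.nodes[1:]:
            walk(arg)
        steps.append((node.nodes[0].__name__, node.value))
    walk(p)
    return steps

print(intermediate_values(f4(Input('x', 11))))   # [('mul', 110), ('pow', 12100)]
```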

Even so, this method still has limitations. It can't properly record destructive (in-place) operations, and it generally doesn't work for operations which don't flow through the whole program in some way. The fact that you have to execute the function to extract the trace might also be a problem in some situations.

It might seem like the trace is less useful than the AST, but it still has many purposes. For example, if we wish to calculate the gradient of a function, we usually only need it at one particular point, so calculating the derivative for a single trace is perfectly fine.
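To make that concrete, here is a minimal sketch of differentiating a single trace by walking it backwards from the output and applying the chain rule at each recorded node (reverse-mode accumulation). This is not how Autograd is implemented. The derivative_at function is hypothetical; it only knows about add, sub, mul and pow with a constant exponent, and it follows every path through the trace separately, which is fine for small examples but can blow up when sub-expressions are heavily shared.

```python
import builtins
import operator

class Proxy:
    def __init__(self, value, nodes):
        self.value = value
        self.nodes = nodes

    @classmethod
    def op(cls, lhs, rhs, op):
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @classmethod
    def value(cls, obj):
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs):  return Proxy.op(self, rhs, operator.add)
    def __radd__(self, lhs): return Proxy.op(lhs, self, operator.add)
    def __sub__(self, rhs):  return Proxy.op(self, rhs, operator.sub)
    def __rsub__(self, lhs): return Proxy.op(lhs, self, operator.sub)
    def __mul__(self, rhs):  return Proxy.op(self, rhs, operator.mul)
    def __rmul__(self, lhs): return Proxy.op(lhs, self, operator.mul)
    def __gt__(self, rhs):   return Proxy.op(self, rhs, operator.gt)
    def __bool__(self):      return bool(self.value)

class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)
        self.name = name

def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args))
        return f(*args)
    return g

pow = make_ast_traceable(builtins.pow)

def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)

def derivative_at(f, x0):
    """Hypothetical helper: df/dx at x0, found by walking one trace backwards."""
    x = Input('x', x0)
    total = 0

    def backprop(node, g):
        # g is d(output)/d(node) along the path currently being followed.
        nonlocal total
        if node is x:
            total += g
            return
        if not isinstance(node, Proxy) or node.nodes is None:
            return                      # constants do not depend on x
        fn, *args = node.nodes
        a = [Proxy.value(arg) for arg in args]
        if fn is operator.add:
            backprop(args[0], g)
            backprop(args[1], g)
        elif fn is operator.sub:
            backprop(args[0], g)
            backprop(args[1], -g)
        elif fn is operator.mul:
            backprop(args[0], g * a[1])
            backprop(args[1], g * a[0])
        elif fn is builtins.pow and not isinstance(args[1], Proxy):
            # Constant exponent: d(b**n)/db = n * b**(n - 1)
            backprop(args[0], g * a[1] * a[0] ** (a[1] - 1))
        else:
            raise NotImplementedError(fn)

    backprop(f(x), 1)
    return total

print(derivative_at(f4, 3))    # 147, i.e. 3*(x+4)**2 at x = 3
print(derivative_at(f4, 11))   # 2200, i.e. 200*x at x = 11
```

The comparison node recorded by x > 5 never lies on a path from the output, so it contributes nothing, and each input gets the derivative of the branch it actually took.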

This use case is exactly how I first came across this idea. It is used in the very cool library NumPy Autograd by Dougal Maclaurin, David Duvenaud, and Matthew Johnson to calculate automatic derivatives of NumPy code for use in numerical optimisation.

And it really works: you can use it to do machine learning in NumPy without ever touching Theano or any other heavyweight tools. The fact that it works with conditionals and loops actually puts it ahead of most of the more strait-laced alternatives for calculating automatic derivatives. In Theano and SymPy, you have to input your expressions in special, limited forms.

For the full details of the method, including a long explanation of the strengths and limitations please check out the library's GitHub repo.

If you're interested in this idea, I think it probably has other interesting applications too. For example, it is no coincidence that the traces this method produces resemble the traces used in tracing JIT compilers. Perhaps this method could be used to create a simple JIT compiler in Python for numerical code? Anyhow, if you come up with any other uses of this method, I'd be really interested to hear about them.
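As a toy illustration of the first half of that idea, here is a sketch that turns a recorded trace back into Python source and compiles it with eval. It reuses the value-carrying Proxy, Input and wrapped pow from above; to_source and compile_trace are hypothetical helpers. A real tracing JIT would also insert guards to check that later inputs take the same branches. This sketch has none, so the function compiled from the x = 11 trace gives 900 for x = 3, where f4 gives 343.

```python
import builtins
import operator

class Proxy:
    def __init__(self, value, nodes):
        self.value = value
        self.nodes = nodes

    @classmethod
    def op(cls, lhs, rhs, op):
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @classmethod
    def value(cls, obj):
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs):  return Proxy.op(self, rhs, operator.add)
    def __mul__(self, rhs):  return Proxy.op(self, rhs, operator.mul)
    def __gt__(self, rhs):   return Proxy.op(self, rhs, operator.gt)
    def __bool__(self):      return bool(self.value)

class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)
        self.name = name

def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args))
        return f(*args)
    return g

pow = make_ast_traceable(builtins.pow)

def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)

SYMBOLS = {operator.add: '+', operator.mul: '*'}

def to_source(node, env):
    """Hypothetical helper: render a trace as a Python expression string."""
    if isinstance(node, Input):
        return node.name
    if not isinstance(node, Proxy):
        return repr(node)
    fn, *args = node.nodes
    parts = [to_source(arg, env) for arg in args]
    if fn in SYMBOLS:
        return '(' + (' ' + SYMBOLS[fn] + ' ').join(parts) + ')'
    env[fn.__name__] = fn               # make other traced functions visible to eval
    return fn.__name__ + '(' + ', '.join(parts) + ')'

def compile_trace(trace, input_name):
    """Turn a single-input trace into a plain Python function (no guards)."""
    env = {}
    source = 'lambda ' + input_name + ': ' + to_source(trace, env)
    return eval(source, env), source

fast_f4, source = compile_trace(f4(Input('x', 11)), 'x')
print(source)        # lambda x: pow((x * 10), 2)
print(fast_f4(12))   # 14400, computed without creating any Proxy objects
```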

That's about it - Happy Hacking!