Tracing Functions in Python
Created on Nov. 19, 2015, 11:02 a.m.
Let's say we have some function in Python on which we want to do some kind of symbolic manipulation. For example, we might have a mathematical function whose derivative we wish to calculate automatically.
```python
def f(x):
    return 5*x + 10
```
Once the function is defined in Python, it might seem impossible to do anything with its symbolic representation without doing something difficult, such as manipulating the bytecode or writing some kind of preprocessor.
But there are a couple of cute tricks we can use to extract a symbolic representation of the function. They have some limitations, but they can sometimes be used to get either an abstract syntax tree (AST) of the function or what is essentially a trace of the function for a particular input, all after the function has been defined and used. This abstract syntax tree or trace is useful for many purposes: it can be examined, manipulated, and easily converted back into the original function if required.
The trick is this - we call the function whose AST we want to construct with a proxy object which, as it passes through the function, records the operations performed on it.
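To see how it works, let us create a class called Proxy which extends tuple. We're going to define Python special methods such as __add__ and __mul__, which allow us to overload the addition and multiplication operators for objects of this type, along with their reflected versions __radd__ and __rmul__, which Python calls when the Proxy is the right-hand operand (as in 5*x). These overloaded methods construct new Proxy objects containing the operator followed by each operand.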
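```python
class Proxy(tuple):
    # Each operator returns a new node: (operator symbol, left operand, right operand)
    def __add__(self, rhs):
        return Proxy(('+', self, rhs))

    def __radd__(self, lhs):  # called for lhs + proxy
        return Proxy(('+', lhs, self))

    def __mul__(self, rhs):
        return Proxy(('*', self, rhs))

    def __rmul__(self, lhs):  # called for lhs * proxy
        return Proxy(('*', lhs, self))
```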
Let's see what happens when we pass a value of this type to our function:
```python
>>> f(Proxy('x'))
('+', ('*', 5, ('x',)), 10)
```
Nice - it returned a tuple which looks a whole lot like the abstract syntax tree of the function (in S-expression notation). We can do the same for multi-argument functions.
```python
>>> def f2(x, y):
...     return 3*x*x + 4*y + 2
...
>>> f2(Proxy('x'), Proxy('y'))
('+', ('+', ('*', ('*', 3, ('x',)), ('x',)), ('*', 4, ('y',))), 2)
```
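Because the result is just nested tuples, examining it needs nothing special. As a small sketch, here is a hypothetical to_infix helper (not part of the technique itself) that turns these string-tagged trees back into readable infix expressions. It assumes single-character input names, since Proxy('x') is stored as the one-element tuple ('x',) while a longer name would be split into its characters.

```python
class Proxy(tuple):
    # The string-tagged Proxy from above
    def __add__(self, rhs): return Proxy(('+', self, rhs))
    def __radd__(self, lhs): return Proxy(('+', lhs, self))
    def __mul__(self, rhs): return Proxy(('*', self, rhs))
    def __rmul__(self, lhs): return Proxy(('*', lhs, self))


def to_infix(node):
    """Render a traced S-expression as a fully parenthesised infix string."""
    if isinstance(node, tuple) and len(node) == 3:
        # An operation node: (operator symbol, left operand, right operand)
        op, lhs, rhs = node
        return '(' + to_infix(lhs) + ' ' + op + ' ' + to_infix(rhs) + ')'
    if isinstance(node, tuple) and len(node) == 1:
        # An input such as ('x',): print just its name
        return node[0]
    # Anything else is a constant
    return repr(node)


def f2(x, y):
    return 3*x*x + 4*y + 2


print(to_infix(f2(Proxy('x'), Proxy('y'))))
```

For f2 this prints ((((3 * x) * x) + (4 * y)) + 2), which makes the left-to-right grouping of Python's operators easy to see.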
This is basically the main concept behind the technique. These outputs use strings to represent the operations, but it would be better to get rid of these and tag each node with the function that performs the operation instead. We should also add a new type to represent inputs so that they don't get confused with actual string objects.
```python
import operator

class Proxy(tuple):
    # Nodes are now tagged with the function that performs the operation
    def __add__(self, rhs):
        return Proxy((operator.add, self, rhs))

    def __radd__(self, lhs):
        return Proxy((operator.add, lhs, self))

    def __mul__(self, rhs):
        return Proxy((operator.mul, self, rhs))

    def __rmul__(self, lhs):
        return Proxy((operator.mul, lhs, self))

class Input(Proxy):
    # A named input, so it can't be confused with a string constant
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'
```
Now it works for functions on strings too.
```python
>>> def greet_people(x, y):
...     return 'Hello ' + x + ' & ' + y + '!'
...
>>> greet_people('Bob', 'Dave')
'Hello Bob & Dave!'
>>> greet_people(Input('x'), Input('y'))
(<built-in function add>, (<built-in function add>, (<built-in function add>, (<built-in function add>, 'Hello ', Input('x')), ' & '), Input('y')), '!')
```
We can convert from this symbolic representation back into the function representation too. It can be done with a function that looks like the following.
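```python
def ast_to_function(ast):
    if not isinstance(ast, Proxy):
        # A constant: return it unchanged
        return lambda **kw: ast
    elif isinstance(ast, Input):
        # An input: look its value up by name in the keyword arguments
        return lambda **kw: kw[ast.name]
    # An operation node (fn, arg, ...): evaluate the children, then apply fn
    return lambda **kw: ast[0](*[ast_to_function(node)(**kw) for node in ast[1:]])
```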
This looks a little complicated, so let me explain what is going on.
First, this function checks whether the input object is something other than a Proxy object (in which case it is probably something like a number). If it isn't a Proxy, then it returns a function that just returns that object.
Secondly, it checks whether the argument is an Input object. If it is, then it returns a function which looks up the name of that input in the keyword arguments and returns the corresponding value.
Finally, if the argument is neither of these, it must be a Proxy representing an operation. In this case it returns a function which converts each of the subexpressions stored in the Proxy into a function, evaluates them, and applies the first element of the Proxy (which should be a function such as <built-in function add>) to the results.
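The same three cases can also be written as an eager evaluator that walks the tree once for a given set of input values, rather than building up a chain of lambdas. This is a sketch: the name evaluate and passing the inputs as a dictionary are our own choices, and the operator-tagged Proxy and Input classes are repeated so the block runs on its own.

```python
import operator


class Proxy(tuple):
    def __add__(self, rhs): return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs): return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))


class Input(Proxy):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'


def evaluate(ast, env):
    """Evaluate an AST directly, looking inputs up by name in env."""
    if not isinstance(ast, Proxy):
        return ast                    # a constant such as 5 or 'Hello '
    if isinstance(ast, Input):
        return env[ast.name]          # an input: substitute its value
    fn, *children = ast               # an operation node: (fn, arg, arg, ...)
    return fn(*[evaluate(child, env) for child in children])


def f(x):
    return 5*x + 10


def greet_people(x, y):
    return 'Hello ' + x + ' & ' + y + '!'


print(evaluate(f(Input('x')), {'x': 2}))
print(evaluate(greet_people(Input('x'), Input('y')), {'x': 'Bob', 'y': 'Dave'}))
```

The closure-building version is handy when the same AST will be called many times; the eager version is simpler to read and debug.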
Now we can convert from our AST representation back to a function representation. The only difference is that we have to give the inputs as keyword arguments, using the names we gave to our Input objects. We could use positional arguments instead, but we'll stick with this setup for simplicity.
```python
>>> greet_people_ast = greet_people(Input('x'), Input('y'))
>>> greet_people_fun = ast_to_function(greet_people_ast)
>>> greet_people_fun(x='Bob', y='Dave')
'Hello Bob & Dave!'
```
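Returning to the motivating example of derivatives, here is a minimal sketch of symbolic differentiation on these ASTs. The hypothetical derivative function knows only the sum and product rules (so it supports only operator.add and operator.mul nodes) and does no simplification. Because it builds its result out of Proxy objects, ast_to_function (repeated here with the classes, so the block runs on its own) can turn the derivative back into a callable.

```python
import operator


class Proxy(tuple):
    def __add__(self, rhs): return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs): return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))


class Input(Proxy):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'


def ast_to_function(ast):
    if not isinstance(ast, Proxy):
        return lambda **kw: ast
    elif isinstance(ast, Input):
        return lambda **kw: kw[ast.name]
    return lambda **kw: ast[0](*[ast_to_function(node)(**kw) for node in ast[1:]])


def derivative(ast, name):
    """Build the AST of d(ast)/d(name); only add and mul nodes are supported."""
    if isinstance(ast, Input):
        return 1 if ast.name == name else 0
    if not isinstance(ast, Proxy):
        return 0  # constants have zero derivative
    fn, lhs, rhs = ast
    dl, dr = derivative(lhs, name), derivative(rhs, name)
    if fn is operator.add:
        return dl + dr  # sum rule
    if fn is operator.mul:
        return dl * rhs + lhs * dr  # product rule
    raise NotImplementedError(fn)


def f2(x, y):
    return 3*x*x + 4*y + 2


f2_ast = f2(Input('x'), Input('y'))
df_dx = ast_to_function(derivative(f2_ast, 'x'))
df_dy = ast_to_function(derivative(f2_ast, 'y'))
print(df_dx(x=2, y=1), df_dy(x=2, y=1))
```

The derivative of 3*x*x + 4*y + 2 with respect to x is 6*x and with respect to y is 4, so this prints 12 4. The unsimplified tree still contains terms like 0*x, which is why real symbolic tools also simplify their output.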
What about functions which aren't set up to take Proxy types as input (such as the built-in function pow)? We are going to have to modify these functions in some way. Luckily, in Python we can easily monkey patch them to work with our setup. Here is a function we can use to do that.
```python
def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            # Record the call as an AST node instead of evaluating it
            return Proxy((f,) + tuple(args))
        else:
            return f(*args)
    return g
```
This function takes some input function and returns a similar function which first checks whether any of the arguments are Proxy types. If so, it returns an AST node for the call; otherwise it calls the unmodified function as normal.
All we have to do now is use this function to replace pow, and it can appear in our ASTs as normal.
```python
>>> pow = make_ast_traceable(pow)
>>> def f3(x):
...     return pow(3*x*x, 4)
...
>>> f3(Input('x'))
(<built-in function pow>, (<built-in function mul>, (<built-in function mul>, 3, Input('x')), Input('x')), 4)
```
We'll also have to do this monkey patch treatment for any other functions we wish to appear in our AST. For example we might want to do this to the function we're actually trying to construct the AST for, so that it appears as a reference to itself upon recursion.
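For example, decorating a hypothetical helper square with make_ast_traceable makes each call on a Proxy show up as a single node holding the original, undecorated function, instead of being expanded into multiplications. Called with ordinary numbers it behaves exactly as before. This is a sketch reusing the classes above, repeated so the block runs on its own.

```python
import operator


class Proxy(tuple):
    def __add__(self, rhs): return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs): return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))


class Input(Proxy):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'


def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy((f,) + tuple(args))
        return f(*args)
    return g


@make_ast_traceable
def square(v):
    # Hypothetical helper; the AST node stores this undecorated function
    return v * v


def f_square(x):
    return 2*square(x) + 1


tree = f_square(Input('x'))
print(tree)         # the inner node is (square, Input('x')), not a multiplication
print(f_square(3))  # plain numbers still evaluate normally
```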
This is basically it for this technique - and it really can be used to extract ASTs for simple expressions. But there are a couple of serious limitations to this method.
Unfortunately this method only works with simple expressions - it doesn't work with things like looping or conditionals. Consider a function like the following:
```python
def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)
```
When we pass this function our Proxy object and it hits the if statement, the function needs to know which way to branch - but this isn't what we want. We don't want to branch at all; what we really want is to extract both branches, but there is no way for us to hook into Python and tell it this. We are forced to pick one branch or the other.
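To see the problem concretely, here is a sketch using the tuple-based Proxy and Input from earlier. In Python 3, a comparison such as x > 5 between a tuple and an int simply raises TypeError. If we added a hypothetical __gt__ that records the comparison as a node instead (ComparableInput below), the if statement would call bool() on a non-empty tuple, which is always true, so only the first branch would ever be traced. The function branchy is a simplified stand-in for f4 that avoids pow.

```python
import operator


class Proxy(tuple):
    def __add__(self, rhs): return Proxy((operator.add, self, rhs))
    def __radd__(self, lhs): return Proxy((operator.add, lhs, self))
    def __mul__(self, rhs): return Proxy((operator.mul, self, rhs))
    def __rmul__(self, lhs): return Proxy((operator.mul, lhs, self))


class Input(Proxy):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Input(' + repr(self.name) + ')'


class ComparableInput(Input):
    # Hypothetical: record '>' as a node instead of failing
    def __gt__(self, rhs):
        return Proxy((operator.gt, self, rhs))


def branchy(x):
    if x > 5:
        return x * 10
    else:
        return x + 4


try:
    branchy(Input('x'))
except TypeError as err:
    print('comparison failed:', err)

# The recorded comparison is a non-empty tuple, so bool() is True and
# only the first branch ends up in the result
print(branchy(ComparableInput('x')))
```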
Similar problems exist for looping. What we really want to extract is the fact that a loop is taking place in the code - not to actually evaluate that loop.
Is there anything we can do about this? Not really - there is no way to extract the AST in this case - but we can do something else instead. We can extract a trace of the function for a particular input.
To do this we pair our Proxy object with some input value. Whenever we encounter an operation, we perform it on our value and record it in the Proxy. Then, in situations like this where we need to actually give a value to proceed, we can return the value we've been recording and continue the trace.
So we'll need to edit our Proxy object to let it store a value. We'll also define the comparison operators and the __bool__ method for it, which Python calls whenever an if or while statement needs a truth value. Here is the new version.
```python
import operator

class Proxy:
    def __init__(self, value, nodes):
        self.value = value  # the concrete value computed so far
        self.nodes = nodes  # the trace: (operation, operand, operand)

    @classmethod
    def op(cls, lhs, rhs, op):
        # Apply op to the underlying values and record it in the trace
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @staticmethod
    def value(obj):
        # Unwrap a Proxy to its value; pass any other object through unchanged.
        # (On instances, the value attribute set in __init__ shadows this.)
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs): return Proxy.op(self, rhs, operator.add)
    def __sub__(self, rhs): return Proxy.op(self, rhs, operator.sub)
    def __mul__(self, rhs): return Proxy.op(self, rhs, operator.mul)
    def __truediv__(self, rhs): return Proxy.op(self, rhs, operator.truediv)

    def __radd__(self, lhs): return Proxy.op(lhs, self, operator.add)
    def __rsub__(self, lhs): return Proxy.op(lhs, self, operator.sub)
    def __rmul__(self, lhs): return Proxy.op(lhs, self, operator.mul)
    def __rtruediv__(self, lhs): return Proxy.op(lhs, self, operator.truediv)

    # Python has no reflected comparison methods: for 5 < x it calls x.__gt__(5)
    def __lt__(self, rhs): return Proxy.op(self, rhs, operator.lt)
    def __gt__(self, rhs): return Proxy.op(self, rhs, operator.gt)
    def __le__(self, rhs): return Proxy.op(self, rhs, operator.le)
    def __ge__(self, rhs): return Proxy.op(self, rhs, operator.ge)

    def __bool__(self):
        # Called by if/while: branch on the recorded value
        return bool(self.value)

    def __repr__(self):
        return 'Proxy(value=' + repr(self.value) + ', trace=' + repr(self.nodes) + ')'
```
We had better update our Input class to store a value too.
```python
class Input(Proxy):
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        return 'Input(name=' + repr(self.name) + ', value=' + repr(self.value) + ')'
```
Now there are two key functions in the new version of Proxy. The first is Proxy.value, which extracts the value from a Proxy object or, when passed a different object, returns that object as it is. The second key function is Proxy.op, which applies an operator to the values of its operands while recording a trace in the nodes of the new Proxy it returns.
We need to update our make_ast_traceable function too. This function is a bit like Proxy.op: if passed any Proxy objects, it extracts their values and performs the given function on them, while simultaneously recording the trace in the Proxy it returns.
```python
def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            # Compute on the unwrapped values and record the call in the trace
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args))
        else:
            return f(*args)
    return g
```
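And let us replace pow again, starting from the original built-in so that we don't wrap the AST-only version from before:
```python
import builtins

pow = builtins.pow               # start again from the real built-in
pow = make_ast_traceable(pow)
```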
Now we can extract different traces of the function for different input values.
```python
>>> f4(Input('x', 3))
Proxy(value=343, trace=(<built-in function pow>, Proxy(value=7, trace=(<built-in function add>, Input(name='x', value=3), 4)), 3))
>>> f4(Input('x', 11))
Proxy(value=12100, trace=(<built-in function pow>, Proxy(value=110, trace=(<built-in function mul>, Input(name='x', value=11), 10)), 2))
```
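As a small illustration of what currently happens with a loop (a sketch, not the extension mentioned below), a compact copy of the value-carrying Proxy simply unrolls the loop into the trace: each iteration adds another node, so different inputs give traces of different lengths. The function double_until_100 and the helper count_steps are hypothetical examples.

```python
import operator


class Proxy:
    # Compact copy of the value-carrying Proxy above: only the operators used here
    def __init__(self, value, nodes):
        self.value = value
        self.nodes = nodes

    @classmethod
    def op(cls, lhs, rhs, op):
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @staticmethod
    def value(obj):
        return obj.value if isinstance(obj, Proxy) else obj

    def __mul__(self, rhs): return Proxy.op(self, rhs, operator.mul)
    def __rmul__(self, lhs): return Proxy.op(lhs, self, operator.mul)
    def __lt__(self, rhs): return Proxy.op(self, rhs, operator.lt)
    def __bool__(self): return bool(self.value)


class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)
        self.name = name


def double_until_100(x):
    # Hypothetical example: a loop whose number of iterations depends on x
    while x < 100:
        x = 2 * x
    return x


def count_steps(p):
    """Count the multiplication nodes between the result and the input."""
    steps = 0
    while not isinstance(p, Input):
        p = p.nodes[2]  # 2 * x is recorded as (mul, 2, previous x)
        steps += 1
    return steps


for start in (3, 30):
    result = double_until_100(Input('x', start))
    print(start, result.value, count_steps(result))
```

Starting from 3 the loop runs six times (ending at 192) and the trace records six multiplications; starting from 30 it runs twice (ending at 120).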
As a side effect we also get all the intermediate values at the different stages of evaluation which is kind of cool. With some more simple changes this method can be extended to work with loops and all sorts of other structures too.
Even so, this method still has limitations. It can't properly record destructive operations, and it generally doesn't work for operations which don't flow through the whole program in some way. The fact that you have to execute the function to extract the trace might also be a problem in some situations.
It might seem like the trace is a bit less useful than the AST but actually it still has many purposes. For example if we wished to calculate the gradient of a function we usually only need it at one particular point - so calculating the derivative for a single trace is perfectly fine.
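As a sketch of that idea, the hypothetical trace_derivative below walks a single trace recursively and applies the chain rule, using the intermediate values stored in each node. It is a simple forward-mode illustration, not necessarily how Autograd does it; it handles only add, mul, and pow with a constant exponent, and it repeats a compact copy of the classes above so it runs on its own.

```python
import builtins
import operator


class Proxy:
    # Compact copy of the value-carrying Proxy above
    def __init__(self, value, nodes):
        self.value = value
        self.nodes = nodes

    @classmethod
    def op(cls, lhs, rhs, op):
        return Proxy(op(Proxy.value(lhs), Proxy.value(rhs)), (op, lhs, rhs))

    @staticmethod
    def value(obj):
        return obj.value if isinstance(obj, Proxy) else obj

    def __add__(self, rhs): return Proxy.op(self, rhs, operator.add)
    def __radd__(self, lhs): return Proxy.op(lhs, self, operator.add)
    def __mul__(self, rhs): return Proxy.op(self, rhs, operator.mul)
    def __rmul__(self, lhs): return Proxy.op(lhs, self, operator.mul)
    def __gt__(self, rhs): return Proxy.op(self, rhs, operator.gt)
    def __bool__(self): return bool(self.value)


class Input(Proxy):
    def __init__(self, name, value):
        super().__init__(value, None)
        self.name = name


def make_ast_traceable(f):
    def g(*args):
        if any(isinstance(a, Proxy) for a in args):
            return Proxy(f(*map(Proxy.value, args)), (f,) + tuple(args))
        return f(*args)
    return g


pow = make_ast_traceable(builtins.pow)


def f4(x):
    if x > 5:
        return pow(x*10, 2)
    else:
        return pow(x+4, 3)


def trace_derivative(node, name):
    """Derivative of a traced value with respect to the Input called name."""
    if isinstance(node, Input):
        return 1.0 if node.name == name else 0.0
    if not isinstance(node, Proxy):
        return 0.0  # constants do not depend on the input
    fn, *args = node.nodes
    vals = [Proxy.value(a) for a in args]  # values recorded during the trace
    ds = [trace_derivative(a, name) for a in args]
    if fn is operator.add:
        return ds[0] + ds[1]
    if fn is operator.mul:
        return ds[0] * vals[1] + vals[0] * ds[1]
    if fn is builtins.pow and ds[1] == 0:
        # Power rule, assuming a constant exponent as in f4
        return vals[1] * vals[0] ** (vals[1] - 1) * ds[0]
    raise NotImplementedError(fn)


print(trace_derivative(f4(Input('x', 3)), 'x'))
print(trace_derivative(f4(Input('x', 11)), 'x'))
```

For x = 3 the trace is (x+4)**3, whose derivative 3*(x+4)**2 is 147 at that point; for x = 11 the trace is (10*x)**2, whose derivative 200*x is 2200. Each trace only knows about the branch that was actually taken, which is all we need for a gradient at a single point.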
This use case is exactly how I first came across this idea. It is used in the very cool library NumPy Autograd by Dougal Maclaurin, David Duvenaud, and Matthew Johnson to calculate automatic derivatives of NumPy code for use in numerical optimisation.
And it really works - you can use it to do machine learning in NumPy without ever touching Theano or any other heavyweight tools. The fact that it works with conditionals and loops actually puts it ahead of most strait-laced alternatives for calculating automatic derivatives. In Theano and SymPy you have to input your expressions in special, limited forms.
If you're interested in this idea I think it probably has other interesting applications too. For example it is no coincidence that the traces this method produces resemble the traces used in tracing JIT compilers. Perhaps this method could be used to create a simple JIT compiler in Python for numerical code? Anyhow - if you do come up with any other uses of this method I'd be really interested to hear.
That's about it - Happy Hacking!