Abstract Mathematics

In this chapter, you’ll learn how to do abstract mathematics, like solving equations expressed in symbols, using code.

If you’re running code from this chapter, remember you may need to install the packages. As well as frequently used packages, in this chapter we’ll be relying on the sympy package for symbolic mathematics and, towards the end, on jax.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sympy

Symbolic mathematics

When using computers to do mathematics, we’re most often performing numerical computations such as \(\sqrt{8} = 2.8284271247461903\). Although we have the answer, it’s only useful for that one special case. Symbolic mathematics allows us to use code to solve equations in the general case, which can often be more illuminating. As an example, if we evaluate the same expression using symbolic mathematics, we get \(\sqrt{8} = 2\sqrt{2}\).
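Here is a minimal comparison of numpy’s floating-point square root with sympy’s exact one. The evalf method turns an exact result back into a decimal when you need a number:

```python
import numpy as np
import sympy

numeric = np.sqrt(8)     # floating-point approximation
exact = sympy.sqrt(8)    # exact result, automatically simplified

print(numeric)           # 2.8284271247461903
print(exact)             # 2*sqrt(2)
print(exact.evalf())     # decimal approximation of the exact result
```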

The Python package for symbolic mathematics is sympy, which provides some features of a computer algebra system.

To define symbolic variables, we use sympy’s symbols function. For ease, we’ll import the entire sympy library into the namespace by using from sympy import *.

from sympy import *
x, t, α, β = symbols(r'x t \alpha \beta')


The leading ‘r’ in some strings, as in r'x t \alpha \beta' above, makes them ‘raw’ strings. Python then treats backslashes literally rather than as the start of escape sequences. Without it, combinations like \n would be read as a newline.
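This matters for the symbol names above, because ‘\a’ and ‘\b’ are themselves valid escape sequences (the ‘bell’ and ‘backspace’ characters). Here is a quick check in plain Python:

```python
plain = '\alpha'   # '\a' is read as the ASCII bell character, chr(7)
raw = r'\alpha'    # raw string: the backslash is kept as-is

print(len(plain))  # 5 characters: bell + 'lpha'
print(len(raw))    # 6 characters: backslash + 'alpha'
print(plain == chr(7) + 'lpha')  # True
```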

Having created these symbolic variables, we can refer to and display them just like normal variables. On their own they’re not very interesting, because they are just symbols (for now):

α
\[\displaystyle \alpha\]

Things get much more interesting when we start doing maths with them. Let’s try some integration. Say we want to evaluate

Integral(log(x), x)
\[\displaystyle \int \log{\left(x \right)}\, dx\]

(note that Integral creates an unevaluated integral, which is printed as a latex equation). To evaluate it, we simply call

integrate(log(x), x)
\[\displaystyle x \log{\left(x \right)} - x\]
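A handy way to check an antiderivative is to differentiate it again. The same integrate function also computes definite integrals if you pass a tuple of (variable, lower limit, upper limit). Here is a minimal sketch that integrates \(\log(x)\) from 1 to \(e\):

```python
from sympy import symbols, integrate, diff, log, E

x = symbols('x')

antiderivative = integrate(log(x), x)      # x*log(x) - x
check = diff(antiderivative, x)            # differentiating back recovers log(x)
definite = integrate(log(x), (x, 1, E))    # (variable, lower, upper) gives 1

print(antiderivative, check, definite)
```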

We can differentiate too:

diff(sin(x)*exp(x), x)
\[\displaystyle e^{x} \sin{\left(x \right)} + e^{x} \cos{\left(x \right)}\]
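diff also handles higher-order and partial derivatives. You can pass a variable more than once, or give a count after it. In this short sketch, the second expression is just an illustrative example:

```python
from sympy import symbols, sin, exp, cos, diff, simplify

x, y = symbols('x y')

second = diff(sin(x)*exp(x), x, 2)    # second derivative with respect to x
cross = diff(x**2 * y**3, x, y)       # mixed partial: d/dx then d/dy

print(second)   # equals 2*exp(x)*cos(x)
print(cross)    # 6*x*y**2
```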

and even take limits!

limit(sin(x)/x, x, 0)
\[\displaystyle 1\]
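limit can also take limits at infinity (using sympy’s oo) and one-sided limits through its dir argument. This sketch shows the definition of \(e\) familiar from continuous compounding, together with the two one-sided limits of \(1/x\) at zero:

```python
from sympy import symbols, limit, oo, E

x = symbols('x')

compounding = limit((1 + 1/x)**x, x, oo)   # E
from_right = limit(1/x, x, 0, dir='+')     # oo
from_left = limit(1/x, x, 0, dir='-')      # -oo

print(compounding, from_right, from_left)
```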

It is also possible to solve equations using sympy. The solve function tries to find the roots of \(f(x)\), i.e. the values of \(x\) for which \(f(x)=0\), and has syntax solve(f(x), x). Here’s an example that solves \(5x - 2 = 0\):

solve(x*5 - 2, x)
[2/5]
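solve also accepts an explicit equation built with Eq. Because we’re working symbolically, it can also solve for one variable in terms of other symbols. A minimal sketch, where the symbols a and b are illustrative:

```python
from sympy import symbols, solve, Eq, Rational

x, a, b = symbols('x a b')

print(solve(x*5 - 2, x))          # [2/5]
print(solve(Eq(x**2, 4), x))      # both roots of x**2 = 4
print(solve(a*x + b, x))          # general solution: [-b/a]
```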

sympy also has a solver for differential equations (dsolve), as well as functions for continued fractions, simplification, and more.
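To give a flavour of dsolve, this sketch solves the equation for a quantity growing at a constant rate, \(dK/dt = gK\). The names K and g are illustrative. The result should take the form \(K(t) = C_1 e^{gt}\), where \(C_1\) is an arbitrary constant:

```python
from sympy import symbols, Function, Eq, dsolve, diff, simplify

t, g = symbols('t g')
K = Function('K')  # unknown function of time

# dK/dt = g*K: constant-rate growth
ode = Eq(K(t).diff(t), g*K(t))
solution = dsolve(ode, K(t))
print(solution)

# Check: the right-hand side satisfies the original equation
residual = simplify(diff(solution.rhs, t) - g*solution.rhs)
print(residual)  # 0
```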

Another really important thing to know about symbolic mathematics is that you can ‘cash in’ at any time by substituting in an actual value. For example,

expr = 1 - 2*sin(x)**2
expr.subs(x, np.pi/2)
\[\displaystyle -1.0\]
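Two related tricks are worth knowing. Substituting sympy’s own pi, rather than numpy’s floating-point np.pi, keeps the result exact. And lambdify turns a symbolic expression into an ordinary numerical function that works on numpy arrays. A short sketch:

```python
import numpy as np
from sympy import symbols, sin, pi, lambdify

x = symbols('x')
expr = 1 - 2*sin(x)**2

exact = expr.subs(x, pi/2)            # exact result: the integer -1, not -1.0
f_num = lambdify(x, expr, 'numpy')    # numerical function backed by numpy
values = f_num(np.array([0.0, np.pi/4, np.pi/2]))

print(exact)
print(values)  # approximately [1, 0, -1]
```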

But you don’t have to substitute in a numerical value; you can just as well substitute in a different symbolic variable:

expr = 1 - 2*sin(x)**2
simplify(expr.subs(x, t/2))
\[\displaystyle \cos{\left(t \right)}\]

I snuck in a simplify here too: it uses the identity \(1 - 2\sin^2(t/2) = \cos(t)\) to tidy up the result.

Symbolic mathematics for economics

The library does a lot, so let’s focus on a few features that are likely to be useful for economics in particular.

Series expansion

The first is performing Taylor series expansions. These come up all the time in macroeconomic modelling, where models are frequently log-linearised. Here’s an example that expands a composition of two functions, \(\log(\sin(\alpha))\):

expr = log(sin(α))

expr.series(α, 0, 4)
\[\displaystyle \log{\left(\alpha \right)} - \frac{\alpha^{2}}{6} + O\left(\alpha^{4}\right)\]

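This is an expansion around \(\alpha=0\) that keeps terms up to third order. The final argument, 4, sets the order of the remainder term \(O(\alpha^{4})\).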
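To connect this with log-linearisation, here is a minimal sketch for a hypothetical Cobb-Douglas output \(y = k^{a}\). Write \(k = \bar{k} e^{\hat{k}}\), where \(\hat{k}\) is the log-deviation from the steady state \(\bar{k}\). A first-order expansion in \(\hat{k}\) then gives the log-deviation of output. Because Cobb-Douglas is already log-linear, the answer should be exactly \(a \hat{k}\), which makes this a useful sanity check:

```python
from sympy import symbols, exp, expand

a, kbar = symbols('a kbar', positive=True)
khat = symbols('khat', real=True)

# Output y = k**a with k = kbar*exp(khat); for positive kbar this equals
# kbar**a * exp(a*khat)
y = kbar**a * exp(a*khat)
ybar = kbar**a  # steady-state output (khat = 0)

# First-order Taylor expansion around khat = 0, dropping the O(khat**2) term
y_lin = y.series(khat, 0, 2).removeO()

# Log-deviation of output, approximated by y/ybar - 1
yhat = expand(y_lin / ybar - 1)
print(yhat)  # a*khat
```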

Symbolic linear algebra

The support for matrices can also come in handy for economic applications. Here’s a matrix,

M = Matrix([[1, 0, x], [α, -t, 3], [4, β, 2]])
\[\begin{split}\displaystyle \left[\begin{matrix}1 & 0 & x\\\alpha & - t & 3\\4 & \beta & 2\end{matrix}\right]\end{split}\]

and its determinant:

M.det()
\[\displaystyle \alpha \beta x - 3 \beta + 4 t x - 2 t\]

I can hardly go to an economics talk that involves matrices without seeing those matrices diagonalised, and there’s a function for that too:

P, D = Matrix([[1, 0], [α, -t]]).diagonalize()
D
\[\begin{split}\displaystyle \left[\begin{matrix}1 & 0\\0 & - t\end{matrix}\right]\end{split}\]
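As a check on the diagonalisation, the returned matrices should satisfy \(A = P D P^{-1}\), where A is the original matrix (named A here for illustration). This minimal sketch also lists the eigenvalues directly:

```python
from sympy import symbols, Matrix, simplify, zeros

t, alpha = symbols('t alpha')
A = Matrix([[1, 0], [alpha, -t]])

P, D = A.diagonalize()       # columns of P are eigenvectors; D holds eigenvalues
print(A.eigenvals())         # {eigenvalue: multiplicity}, here 1 and -t

# Reconstruct A from its diagonalisation and simplify the difference
difference = (P * D * P.inv() - A).applyfunc(simplify)
print(difference == zeros(2, 2))  # True
```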


Function optimisation using Lagrangians is about as prevalent in economics as any bit of maths: let’s see how it’s done symbolically.

We’re going to find the constrained optimum over \(x, y\) of the function \(f(x,y)\), subject to \(g(x,y)=0\), where \(f(x,y) = 4xy - 2x^2 + y^2\) and \(g(x,y) = 3x+y-5\).

First, we need to specify the problem, and the Lagrangian for it, in code:

x, y, λ = symbols(r'x y \lambda', real=True)
f = 4*x*y - 2*x**2 + y**2
g = 3*x+y-5

L = f - λ*g

L
\[\displaystyle - \lambda \left(3 x + y - 5\right) - 2 x^{2} + 4 x y + y^{2}\]

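The Karush-Kuhn-Tucker (KKT) conditions are necessary conditions that an optimal solution must satisfy. With only an equality constraint, as here, they say that a solution must be a stationary point of the Lagrangian, \(\nabla \mathcal{L} = 0\), where the gradient is taken with respect to \(x\), \(y\), and \(\lambda\). The derivative with respect to \(\lambda\) simply recovers the constraint \(g=0\). Let’s solve this.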

gradL = [diff(L, c) for c in [x, y]]
KKT_eqns = gradL + [g]

This gives three equations from the KKT conditions, each of which must equal zero: \(\displaystyle - 3 \lambda - 4 x + 4 y\), \(\displaystyle - \lambda + 4 x + 2 y\), and \(\displaystyle 3 x + y - 5\). (The symbolic manipulation is now over. We have derived the conditions algebraically, and now we’re looking for specific values.) We find the stationary values of \(x\) and \(y\), given that \(g=0\), by solving these equations for \(x\), \(y\), and \(\lambda\).

stationary_pts = solve(KKT_eqns, [x, y, λ], dict=True)
[{x: -1, y: 8, \lambda: 12}]

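Now we can substitute these in to find the first (and, in this case, only) stationary point and the value of \(f\) there. Note that this point is a constrained maximum rather than a minimum. Substituting the constraint \(y = 5 - 3x\) into \(f\) gives \(-5x^2 - 10x + 25\), which is concave in \(x\) and peaks at \(x=-1\), so \(f\) has no minimum along the constraint.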

stationary_pts[0][x], stationary_pts[0][y], f.subs(stationary_pts[0])
(-1, 8, 30)
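To classify a stationary point as a maximum or a minimum, we need second-order conditions. This minimal sketch applies the bordered-Hessian test to this problem, which has two choice variables and one constraint, rebuilding the problem with plain ASCII names. A positive determinant indicates a local constrained maximum, and a negative one a local constrained minimum:

```python
from sympy import symbols, diff, Matrix, solve

x, y, lam = symbols('x y lambda', real=True)
f = 4*x*y - 2*x**2 + y**2
g = 3*x + y - 5
L = f - lam*g

# Stationary point from the first-order conditions
sol = solve([diff(L, x), diff(L, y), g], [x, y, lam], dict=True)[0]

# Bordered Hessian: the constraint gradient forms the border, and the
# second derivatives of the Lagrangian in x and y fill the rest
H_bordered = Matrix([
    [0,          diff(g, x),    diff(g, y)],
    [diff(g, x), diff(L, x, x), diff(L, x, y)],
    [diff(g, y), diff(L, y, x), diff(L, y, y)],
])
det_H = H_bordered.subs(sol).det()

print(sol, det_H)  # det_H > 0 here, so the point is a local maximum
```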

Exporting to latex

To turn any equation, for example diff(sin(x)*exp(x), x), into latex and export it to a file that can be included in a paper, use

eqn_to_export = latex(diff(sin(x)*exp(x), x), mode='equation')
with open('latex_equation.tex', 'w') as tex_file:
    tex_file.write(eqn_to_export)

which creates a file called ‘latex_equation.tex’ that has a single line in it: ‘\begin{equation}e^{x} \sin{\left(x \right)} + e^{x} \cos{\left(x \right)}\end{equation}’. There is a range of options for exporting to latex: mode='equation*' produces an unnumbered equation, mode='inline' produces an inline equation, and so on. To include the exported equation in your latex paper, use ‘\input{latex_equation.tex}’.
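To compare the export modes, this short sketch prints the same expression in each one. Only the wrapper around the equation changes:

```python
from sympy import symbols, sin, exp, diff, latex

x = symbols('x')
expr = diff(sin(x)*exp(x), x)

numbered = latex(expr, mode='equation')      # \begin{equation}...\end{equation}
unnumbered = latex(expr, mode='equation*')   # \begin{equation*}...\end{equation*}
inline = latex(expr, mode='inline')          # $...$

for tex in (numbered, unnumbered, inline):
    print(tex)
```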

Why coding symbolic mathematics is useful

  1. Accuracy: using a computer to solve the equations means you’re less likely to make a mistake. At the very least, it’s a useful check on your by-hand working.

  2. Consistency: by making your code export the equations you’re solving to your write-up, you can ensure that the equations are consistent across both, and you only have to type them once.


Set theory

Set theory is a surprisingly useful tool in research (and invaluable in spatial analysis). Sets are a first-class citizen in Python, just like lists are.

Below, we’ll see some really useful bits of set theory, inspired by examples in Sheppard [2012].

We can define and view a set like this:

x = set(['Ada Lovelace', 'Sadie Alexander',
         'Charles Babbage', 'Ada Lovelace',
         'Adam Smith', 'Sadie Alexander'])
{'Ada Lovelace', 'Adam Smith', 'Charles Babbage', 'Sadie Alexander'}

Notice that a couple of entries appeared twice in the list but only once in the set: that’s because a set contains only unique elements. Let’s define a second set in order to demonstrate some of the operations we can perform on sets.

y = set(['Grace Hopper', 'Jean Bartik',
         'Janet Yellen', 'Joan Robinson',
         'Adam Smith', 'Ada Lovelace'])
{'Ada Lovelace',
 'Adam Smith',
 'Grace Hopper',
 'Janet Yellen',
 'Jean Bartik',
 'Joan Robinson'}

Now we have two sets we can look at to demonstrate some of the basic functions you can call on the set object type. x.intersection(y) gives, in this example, {'Ada Lovelace', 'Adam Smith'}, x.difference(y) gives {'Charles Babbage', 'Sadie Alexander'}, and x.union(y) gives {'Ada Lovelace', 'Adam Smith', 'Charles Babbage', 'Grace Hopper', 'Janet Yellen', 'Jean Bartik', 'Joan Robinson', 'Sadie Alexander'}.
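Python also provides operator shorthands for these methods: & for intersection, - for difference, and | for union. It also has ^ for the symmetric difference (elements in exactly one set) and <= for a subset test. This short sketch rebuilds the two sets using set literal syntax:

```python
x = {'Ada Lovelace', 'Sadie Alexander', 'Charles Babbage', 'Adam Smith'}
y = {'Grace Hopper', 'Jean Bartik', 'Janet Yellen',
     'Joan Robinson', 'Adam Smith', 'Ada Lovelace'}

both = x & y            # same as x.intersection(y)
only_x = x - y          # same as x.difference(y)
either = x | y          # same as x.union(y)
exactly_one = x ^ y     # same as x.symmetric_difference(y)
is_subset = {'Adam Smith'} <= x  # True: every element is also in x

print(both, only_x, exactly_one, is_subset, sep='\n')
```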

numpy also has functions that use set theory. np.unique returns the sorted unique entries of an input array or list. Because this list mixes strings and integers, numpy first converts everything to strings:

np.unique(['Lovelace', 'Hopper', 'Alexander', 'Hopper', 45, 27, 45])
array(['27', '45', 'Alexander', 'Hopper', 'Lovelace'], dtype='<U21')

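We can also ask which elements of a first array also appear in a second. Older code uses np.in1d for this, but recent versions of numpy deprecate it in favour of np.isin: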

x = np.arange(10)
y = np.arange(5, 10)
np.isin(x, y)
array([False, False, False, False, False,  True,  True,  True,  True,
        True])

And we have the numpy equivalents of intersection, np.intersect1d(x, y); difference, np.setdiff1d(x, y); and union, np.union1d(x, y). Additionally, there is the exclusive-or, or ‘xor’ (np.setxor1d), which returns the elements that appear in exactly one of the two arrays. In other words, it returns their union with their intersection removed:

a = np.array([1, 2, 3, 2, 4])
b = np.array([2, 3, 5, 7, 5])
np.setxor1d(a, b)
array([1, 4, 5, 7])
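For completeness, this sketch applies all four numpy set functions to the same two arrays. They return sorted, de-duplicated results, and the xor equals the union with the intersection removed:

```python
import numpy as np

a = np.array([1, 2, 3, 2, 4])
b = np.array([2, 3, 5, 7, 5])

print(np.intersect1d(a, b))  # [2 3]
print(np.setdiff1d(a, b))    # [1 4]: in a but not in b
print(np.union1d(a, b))      # [1 2 3 4 5 7]
print(np.setxor1d(a, b))     # [1 4 5 7]

# xor = union minus intersection
same = np.array_equal(np.setxor1d(a, b),
                      np.setdiff1d(np.union1d(a, b), np.intersect1d(a, b)))
print(same)  # True
```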

Advanced: Composable Function Transformations

In recent years, there have been great developments in the ability of Python to easily carry out numerical ‘composable function transformations’. This means that if you can dream up an arbitrary numerical operation (including differentiation, linear algebra, and optimisation), you can write code that will execute it quickly and automatically on CPUs, GPUs, or TPUs as you like.

Here we’ll look at one library that does this, jax, developed at Google (Bradbury et al. [2018]). It can automatically differentiate native Python and numpy functions, including when they are in loops, branches, or subject to recursion, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
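To illustrate composing the two modes, this sketch uses jax’s jacrev (reverse mode) and jacfwd (forward mode) on a simple two-variable function chosen for illustration. Applying jacfwd to jacrev gives the matrix of second derivatives (the Hessian):

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(v):
    # f(v) = v0**2 * v1 + v1**3
    return v[0]**2 * v[1] + v[1]**3

v = jnp.array([1.0, 2.0])

gradient = jacrev(f)(v)          # reverse mode: [2*v0*v1, v0**2 + 3*v1**2]
hessian = jacfwd(jacrev(f))(v)   # forward-over-reverse: second derivatives

print(gradient)  # approximately [4., 13.]
print(hessian)   # approximately [[4., 2.], [2., 12.]]
```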

To do these at speed, it uses just-in-time compilation. If you don’t know what that is, don’t worry: the details aren’t important. It’s just a way of getting close to C++ or Fortran speeds while still being able to write code in much more user-friendly Python!


Let’s see an example of auto-differentiating an arbitrary function. We’ll write the definition of \(\tanh(x)\) as a Python function and then differentiate it. Because we already imported a (symbolic) tanh function from sympy above, we’ll call the function below tanh_num.

from jax import grad
import jax.numpy as jnp

def tanh_num(θ):  # Define a function
  y = jnp.exp(-2.0 * θ)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh_num)  # Obtain its gradient function
grad_tanh(1.0)   # Evaluate it at x = 1.0
DeviceArray(0.4199743, dtype=float32)

You can differentiate to any order by composing grad. For example, here is the third derivative at 1.0:

grad(grad(grad(tanh_num)))(1.0)
DeviceArray(0.6216266, dtype=float32)

Let’s check this using symbolic mathematics:

θ = Symbol(r'\theta')
triple_deriv = diff(diff(diff(tanh(θ), θ)))
\[\displaystyle \left(1 - \tanh^{2}{\left(\theta \right)}\right) \left(2 \tanh^{2}{\left(\theta \right)} - 2\right) + 2 \left(2 - 2 \tanh^{2}{\left(\theta \right)}\right) \tanh^{2}{\left(\theta \right)}\]

If we evaluate this at \(\theta=1\), we get 0.622, matching jax’s result. This was a simple example with a (relatively) simple mathematical expression. But imagine if we had lots of branches (e.g. if/else statements) and/or a really complicated function: jax’s grad would still work. It’s designed for the really complex derivatives of the kind encountered in machine learning.
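As a small illustration of the point about branches, this sketch differentiates a piecewise function (chosen for illustration) written with an ordinary Python if/else. When grad is used on its own, jax follows whichever branch the input actually takes. Under jit, Python if statements on traced values generally need jax’s structured control flow, such as jax.lax.cond, instead:

```python
from jax import grad

def piecewise(x):
    # Ordinary Python control flow: a different formula on each side of 3
    if x < 3:
        return 3.0 * x**2
    else:
        return -4.0 * x

d_piecewise = grad(piecewise)
print(d_piecewise(2.0))  # derivative of 3x**2 at 2: 12.0
print(d_piecewise(4.0))  # derivative of -4x: -4.0
```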

Just-in-time compilation

The other nice feature of jax is the ability to do just-in-time (JIT) compilation. Because they do not compile their code into machine code before running, high-level languages like Python and R are not as fast as the same code written in C++ or Fortran (the benefit is that it takes you less time to write the code in the first place). Much of the time, there are pre-built functions that call C++ under the hood to do these things, but only for those operations that people have already taken the time to code up in a lower-level language. JIT compilation offers a compromise: you can code more or less as you like in the high-level language, but it will be compiled just-in-time to give you a speed-up!

jax is certainly not the only Python package that does this, and if you’re not doing anything like differentiation or backpropagation, numba is a more mature alternative. But here we’ll see the time difference that JIT compilation makes on an otherwise slow operation: element-wise multiplication and addition.

from jax import jit

def slow_f(x):
  """Slow, element-wise function"""
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)

Now let’s see how fast the ‘slow’ version goes:

%timeit -n15 -r3 slow_f(x)
60.1 ms ± 3.67 ms per loop (mean ± std. dev. of 3 runs, 15 loops each)

What about with the JIT compilation?

%timeit -n15 -r3 fast_f(x)
17.7 ms ± 434 µs per loop (mean ± std. dev. of 3 runs, 15 loops each)
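One caveat when timing jax code: jax dispatches work asynchronously, so a timing loop can finish before the computation does. In addition, the first call to a jitted function includes compilation time. A fairer comparison, sketched here with a plain timing loop and a smaller array, warms up the compiled function first and calls block_until_ready() on each result. Timings will vary by machine:

```python
import time
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    """Element-wise function, as above."""
    return x * x + x * 2.0

fast_f = jit(slow_f)
x = jnp.ones((1000, 1000))

# Warm-up call: the first call to a jitted function triggers compilation
fast_f(x).block_until_ready()

def time_it(fn, n=15):
    start = time.perf_counter()
    for _ in range(n):
        # Wait for the asynchronous computation to actually finish
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n

print(f"slow_f: {time_it(slow_f) * 1e3:.2f} ms per call")
print(f"fast_f: {time_it(fast_f) * 1e3:.2f} ms per call")
```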

This short introduction has barely scratched the surface of jax and what you can do with it. For more, see the official documentation.


In this chapter, you should have:

  • ✅ seen how to use symbolic algebra with code, including Lagrangians and linear algebra; and

  • ✅ found out about using set theory via the set object type and set-oriented functions.