# Abstract Mathematics

In this chapter, you'll learn how to do abstract mathematics, such as solving equations expressed in symbols, using code.

If you're running code from this chapter, remember you may need to install the packages. As well as frequently used packages, in this chapter we'll be relying on the **sympy** package for symbolic mathematics.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sympy

In [None]:
import matplotlib_inline.backend_inline

# Plot settings
plt.style.use(
    "https://github.com/aeturrell/coding-for-economists/raw/main/plot_style.txt"
)
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

# Set max rows displayed for readability
pd.set_option("display.max_rows", 6)


## Symbolic mathematics

In [None]:
from myst_nb import glue

a = 8
glue("sqrt", np.sqrt(a))
glue("symsqrt", sympy.sqrt(a))

When using computers to do mathematics, we're most often performing numerical computations such as $\sqrt{8} = ${glue:}`sqrt`. Although we have the answer, it's only useful for the one special case. Symbolic mathematics allows us to use coding to solve equations in the general case, which can often be more illuminating. As an example, if we evaluate this in symbolic mathematics we get $\sqrt{8} = ${glue:}`symsqrt`.

The Python package for symbolic mathematics is [**sympy**](https://www.sympy.org/en/index.html), which provides some features of a computer algebra system.

To define *symbolic* variables, we use **sympy**'s `symbols()` function. For ease, we'll import the entire **sympy** library into the namespace by using `from sympy import *`.

In [None]:
from sympy import *

x, t, α, β = symbols(r"x t \alpha \beta")

```{note}
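The leading `r` in some strings makes them *raw* strings, which tells Python to treat backslashes literally rather than as the start of escape sequences. Otherwise, combinations like `\n` would begin a new line, and the `\b` at the start of `\beta` would be read as a backspace character.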
```
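
To see why the `r` matters for the symbol names used above, here's a small illustrative check comparing a plain and a raw version of one of those strings (the variable names are just for this example):

```python
# Without the leading r, "\b" is an escape sequence for the backspace character
plain = "\beta"
# With the leading r, the backslash is kept as an ordinary character
raw = r"\beta"

print(len(plain))  # 4: one backspace character followed by "eta"
print(len(raw))  # 5: a backslash followed by "beta"
```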

Having created these symbolic variables, we can refer to and see them just like normal variables--though they're not very interesting *because* they are just symbols (for now):

In [None]:
α

Things get much more interesting when we start to do maths on them. Let's try some integration: say we want to evaluate

In [None]:
Integral(log(x), x)

(note that the symbols are printed as LaTeX equations). To do this, we simply call

In [None]:
integrate(log(x), x)

We can differentiate too:

In [None]:
diff(sin(x) * exp(x), x)

and even take limits!

In [None]:
limit(sin(x) / x, x, 0)

It is also possible to solve equations using **sympy**. The `solve()` function tries to find the roots of $f(x)$, that is, the values of $x$ for which $f(x) = 0$, and has syntax `solve(f(x), x)`. Here's an example, solving $5x - 2 = 0$:

In [None]:
solve(x * 5 - 2, x)

There are also solvers for differential equations (`dsolve()`), continued fractions, simplifications, and more.
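
As a small sketch of `dsolve()`, here's a hypothetical example of capital $K(t)$ depreciating at a constant rate $\delta$, so that $\dot{K} = -\delta K$. The names `K` and `δ` are just illustrative choices:

```python
from sympy import Eq, Function, dsolve, simplify, symbols

# Hypothetical example: capital K(t) depreciating at a constant rate δ
t = symbols("t")
δ = symbols(r"\delta", positive=True)
K = Function("K")

# Solve dK/dt = -δ K(t) for K(t); C1 is the constant of integration
solution = dsolve(Eq(K(t).diff(t), -δ * K(t)), K(t))
print(solution)

# Check: substituting the solution back into the ODE gives zero
print(simplify(solution.rhs.diff(t) + δ * solution.rhs))  # 0
```

If you know the starting capital stock, `dsolve()` also accepts initial conditions through its `ics` argument, which pins down `C1`.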

Another really important thing to know about symbolic mathematics is that you can 'cash in' at any time by substituting in an actual value. For example,

In [None]:
expr = 1 - 2 * sin(x) ** 2
expr.subs(x, np.pi / 2)
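
Because `np.pi` is a floating-point number, the result above is a float. If you'd rather stay exact, one option is to substitute **sympy**'s own `pi` and only convert to a decimal with `.evalf()` when you need one. A minimal sketch:

```python
from sympy import pi, sin, symbols

x = symbols("x")
expr = 1 - 2 * sin(x) ** 2

at_half_pi = expr.subs(x, pi / 2)  # exact: sin(pi/2) is exactly 1
at_third_pi = expr.subs(x, pi / 3)  # exact: sin(pi/3)**2 is exactly 3/4
at_float = expr.subs(x, 0.3)  # a float input gives a floating-point result

print(at_half_pi)  # -1
print(at_third_pi)  # -1/2
print(at_third_pi.evalf())  # the same value as a decimal
print(at_float)
```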

But you don't have to substitute in a real value; you can just as well substitute in a different symbolic variable:

In [None]:
expr = 1 - 2 * sin(x) ** 2
simplify(expr.subs(x, t / 2))

I snuck in a simplify here too!

### Symbolic mathematics for economics

The library does a lot, so let's focus on a few features that are likely to be useful for economics in particular.

#### Series expansion

The first is performing **Taylor series expansions**. These come up all the time in macroeconomic modelling, where models are frequently log-linearised. Here's an example that combines the expansions of a logarithm and of a sine function:

In [None]:
expr = log(sin(α))

expr.series(α, 0, 4)

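This is an expansion around $\alpha=0$ up to third order: the `4` tells **sympy** to drop terms of order $\alpha^4$ and higher, which are collected in the $O(\alpha^4)$ term.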
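
For log-linearisation, the workhorse is the first-order approximation $\log(1+g) \approx g$ for a small growth rate $g$. Here's a minimal sketch of getting it from **sympy** (the symbol name `g` is just illustrative); `removeO()` drops the order term so that you can keep working with the approximation as an ordinary expression:

```python
from sympy import log, series, symbols

g = symbols("g")  # illustrative: a small growth rate

# Expand log(1 + g) around g = 0, keeping terms below order g**3
expansion = series(log(1 + g), g, 0, 3)
print(expansion)

# The first-order version, with the O(g**2) term removed
first_order = series(log(1 + g), g, 0, 2).removeO()
print(first_order)  # g
```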

#### Symbolic linear algebra

The support for **matrices** can also come in handy for economic applications. Here's a matrix,

In [None]:
M = Matrix([[1, 0, x], [α, -t, 3], [4, β, 2]])
M

and its determinant:

In [None]:
M.det()
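
A useful habit is to cross-check a symbolic result numerically. The sketch below substitutes some arbitrary, purely illustrative values into the symbolic determinant and compares the result with **numpy**'s numerical determinant of the same matrix:

```python
import numpy as np
from sympy import Matrix, symbols

x, t, α, β = symbols(r"x t \alpha \beta")
M = Matrix([[1, 0, x], [α, -t, 3], [4, β, 2]])
det_sym = M.det()

# Illustrative values only
values = {x: 1.0, t: 2.0, α: 0.5, β: -1.0}

# Symbolic determinant evaluated at the values
det_at_values = float(det_sym.subs(values))

# Numerical determinant of the matrix with the same values plugged in
M_num = np.array(M.subs(values).tolist(), dtype=float)
det_num = np.linalg.det(M_num)

print(det_at_values, det_num)  # both should be (approximately) 6.5
```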

I can hardly go to an economics talk involving matrices without seeing those matrices get diagonalised: there's a function for that too.

In [None]:
P, D = Matrix([[1, 0], [α, -t]]).diagonalize()
D
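
`diagonalize()` returns a pair $(P, D)$ with $M = P D P^{-1}$, where the columns of $P$ are eigenvectors and $D$ holds the eigenvalues on its diagonal. A quick sketch to confirm that identity for the matrix above:

```python
from sympy import Matrix, simplify, symbols, zeros

t, α = symbols(r"t \alpha")
A = Matrix([[1, 0], [α, -t]])

P, D = A.diagonalize()

# Columns of P are eigenvectors; D has the eigenvalues on its diagonal
check = (P * D * P.inv() - A).applyfunc(simplify)
print(check == zeros(2, 2))  # True
```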

#### Lagrangians

Function optimisation using Lagrangians is about as prevalent in economics as any bit of maths: let's see how it's done symbolically.

We're going to find the constrained optimum over $x$ and $y$ of the function $f(x,y)$, subject to $g(x,y)=0$, where $f(x,y) = 4xy - 2x^2 + y^2$ and $g(x,y) = 3x+y-5$.

First, we need to specify the problem, and the Lagrangian for it, in code:

In [None]:
x, y, λ = symbols(r"x y \lambda", real=True)
f = 4 * x * y - 2 * x**2 + y**2
g = 3 * x + y - 5

ℒ = f - λ * g
ℒ

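The Karush-Kuhn-Tucker (KKT) conditions give necessary conditions for a solution to be optimal. With only an equality constraint, they reduce to the familiar first-order conditions: a candidate solution must be a stationary point of the Lagrangian, $\nabla \mathcal{L} = 0$. Differentiating with respect to $\lambda$ simply recovers the constraint $g = 0$, so we take the derivatives with respect to $x$ and $y$ and append $g$ itself. Let's solve this.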

In [None]:
gradL = [diff(ℒ, c) for c in [x, y]]
KKT_eqns = gradL + [g]

In [None]:
glue("kkt_0", KKT_eqns[0])
glue("kkt_1", KKT_eqns[1])
glue("kkt_2", KKT_eqns[2])

This gives 3 equations from the KKT conditions: {glue:}`kkt_0`, {glue:}`kkt_1`, and {glue:}`kkt_2`. (The symbolic manipulation is now over: we have derived the conditions algebraically, and now we're looking for actual values.) We find the values of $x$ and $y$ that make $f$ stationary given that $g=0$ by solving these equations over $x$, $y$, and $\lambda$.

In [None]:
stationary_pts = solve(KKT_eqns, [x, y, λ], dict=True)
stationary_pts

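Now, we can substitute these in to find the (first and, in this case, only) stationary point and the value of $f$ there. Note that this point is a constrained *maximum* rather than a minimum: substituting the constraint $y = 5 - 3x$ into $f$ gives $f = -5x^2 - 10x + 25$, which is concave in $x$ and so has no minimum along the constraint.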

In [None]:
stationary_pts[0][x], stationary_pts[0][y], f.subs(stationary_pts[0])
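
The first-order conditions alone don't say whether a stationary point is a maximum or a minimum. One standard second-order check is the bordered Hessian: with two choice variables and one equality constraint, a positive determinant at the stationary point indicates a local constrained maximum, and a negative one a local constrained minimum. A minimal sketch for the problem above, redefined so the block stands alone:

```python
from sympy import Matrix, diff, solve, symbols

# Same problem as above
x, y, λ = symbols(r"x y \lambda", real=True)
f = 4 * x * y - 2 * x**2 + y**2
g = 3 * x + y - 5
ℒ = f - λ * g

# Stationary point from the first-order conditions
point = solve([diff(ℒ, x), diff(ℒ, y), g], [x, y, λ], dict=True)[0]

# Bordered Hessian: the constraint's gradient borders the Lagrangian's Hessian
H_B = Matrix(
    [
        [0, diff(g, x), diff(g, y)],
        [diff(g, x), diff(ℒ, x, x), diff(ℒ, x, y)],
        [diff(g, y), diff(ℒ, y, x), diff(ℒ, y, y)],
    ]
)
det_HB = H_B.subs(point).det()
print(point, det_HB)
```

Here the determinant comes out positive, which points to a local constrained maximum.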

#### Exporting to latex

To turn any equation, for example `Integral(log(x), x)`, into LaTeX and export it to a file that can be included in a paper, use

```python
eqn_to_export = latex(Integral(log(x), x), mode="equation")
with open("latex_equation.tex", "w") as tex_file:
    tex_file.write(eqn_to_export)
```

which creates a file called `latex_equation.tex` that has a single line in it: `\begin{equation}\int \log{\left(x \right)}\, dx\end{equation}`. There is a range of options for exporting to LaTeX: `mode="equation*"` produces an unnumbered equation, `mode="inline"` produces an inline equation, and so on. To include these in your LaTeX paper, use `\input{latex_equation.tex}`.

### Why coding symbolic mathematics is useful

1. Accuracy--using a computer to solve the equations means you're less likely to make a mistake. At the very least, it's a useful check on your by-hand working.

2. Consistency--by making your code export the equations you're solving to your write-up, you can ensure that the equations are consistent across both *and* you only have to type them once.

## Sets

Set theory is a surprisingly useful tool in research (and invaluable in spatial analysis). Sets are a first-class citizen in Python, just like lists are.

Below, we'll see some really useful bits of set theory, inspired by examples in {cite:t}`sheppard2012introduction`.

We can define and view a set like this:

In [None]:
x = set(
    [
        "Ada Lovelace",
        "Sadie Alexander",
        "Charles Babbage",
        "Ada Lovelace",
        "Adam Smith",
        "Sadie Alexander",
    ]
)
x

Notice that a couple of entries appeared twice in the list but only once in the set: that's because a set contains only unique elements. Let's define a second set in order to demonstrate some of the operations we can perform on sets.

In [None]:
y = set(
    [
        "Grace Hopper",
        "Jean Bartik",
        "Janet Yellen",
        "Joan Robinson",
        "Adam Smith",
        "Ada Lovelace",
    ]
)
y

In [None]:
from myst_nb import glue

inters = x.intersection(y)
differ = x.difference(y)
union = x.union(y)
glue("inters", inters)
glue("differ", differ)
glue("union", union)

Now we have two sets we can look at to demonstrate some of the basic functions you can call on the set object type. `x.intersection(y)` gives, in this example, {glue:}`inters`, `x.difference(y)` gives {glue:}`differ`, and `x.union(y)` gives {glue:}`union`.
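
Python's `set` type also supports operator shorthands for these methods, plus the symmetric difference (the elements in exactly one of the two sets). A small sketch with toy sets of numbers:

```python
a = {1, 2, 3, 4}
b = {3, 4, 5}

print(a & b)  # intersection: {3, 4}
print(a - b)  # difference: {1, 2}
print(a | b)  # union: {1, 2, 3, 4, 5}
print(a ^ b)  # symmetric difference: {1, 2, 5}
print({3, 4} <= a)  # subset test: True
```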

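**numpy** also has functions that use set theory. `np.unique()` returns the sorted unique entries of an input array or list. (In the mixed list below, **numpy** first converts everything to strings so that all entries share one type, which is why the numbers come back as text.)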

In [None]:
np.unique(["Lovelace", "Hopper", "Alexander", "Hopper", 45, 27, 45])
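
`np.unique()` can also count how often each unique entry appears via its `return_counts` argument, which is handy for quick tabulations. A small sketch with made-up data:

```python
import numpy as np

votes = np.array(["yes", "no", "yes", "abstain", "yes", "no"])

# Sorted unique values, plus how many times each one occurs
values, counts = np.unique(votes, return_counts=True)
print(values)  # ['abstain' 'no' 'yes']
print(counts)  # [1 2 3]
```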

We can also ask which elements of a first array also appear in a second:

In [None]:
x = np.arange(10)
y = np.arange(5, 10)
np.isin(x, y)

And we have the **numpy** equivalents of intersection, `np.intersect1d(x, y)`, difference, `np.setdiff1d(x, y)`, and union, `np.union1d(x, y)`. Additionally, there is the exclusive-or ('xor'), `np.setxor1d()`, also known as the symmetric difference. This returns the sorted unique values that are in only one of the two arrays, that is, their union with their intersection removed:

In [None]:
a = np.array([1, 2, 3, 2, 4])
b = np.array([2, 3, 5, 7, 5])
np.setxor1d(a, b)
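
To see that the exclusive-or really is the union with the intersection removed, here's a quick check built from the other **numpy** set functions:

```python
import numpy as np

a = np.array([1, 2, 3, 2, 4])
b = np.array([2, 3, 5, 7, 5])

xor = np.setxor1d(a, b)

# Union with the intersection removed
union_minus_intersection = np.setdiff1d(np.union1d(a, b), np.intersect1d(a, b))

print(xor)  # [1 4 5 7]
print(np.array_equal(xor, union_minus_intersection))  # True
```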

## Advanced: Composable Function Transformations

In recent years, there have been great developments in the ability of Python to easily carry out numerical 'composable function transformations'. What this means is that, if you can dream up an arbitrary numerical operation (including differentiation, linear algebra, and optimisation), you can write code that will execute it quickly and automatically on CPUs, GPUs, or TPUs as you like.

Here we'll look at one library that does this, **jax**, developed by Google {cite:ps}`jax2018github`. It can automatically differentiate native Python and **numpy** functions, including when they are in loops, branches, or subject to recursion, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

To do these at speed, it uses just-in-time compilation. If you don't know what that is, don't worry: the details aren't important. It's just a way of getting close to C++ or Fortran speeds while still being able to write code in *much* more user-friendly Python!

### Auto-differentiation

Let's see an example of auto-differentiating an arbitrary function. We'll write the definition of $\tanh(x)$ as a function and evaluate its derivative. Because we already imported a (symbolic) `tanh` function from **sympy** above, we'll call the function below `tanh_num()`.

```python
from jax import grad
import jax.numpy as jnp


def tanh_num(θ):  # Define a function
    y = jnp.exp(-2.0 * θ)
    return (1.0 - y) / (1.0 + y)


grad_tanh = grad(tanh_num)  # Obtain its gradient function
grad_tanh(1.0)  # Evaluate it at θ = 1.0
```

```bash
DeviceArray(0.4199743, dtype=float32)
```

You can differentiate to any order using grad:

```python
grad(grad(grad(tanh_num)))(1.0)
```

```bash
DeviceArray(0.6216266, dtype=float32)
```

Let's check this using symbolic mathematics:

In [None]:
θ = Symbol(r"\theta")
triple_deriv = diff(tanh(θ), θ, 3)
triple_deriv

In [None]:
symp_est = triple_deriv.subs(θ, 1.0)
glue("symp_est", f"{float(symp_est):.3f}")

If we evaluate this at $\theta=1$, we get {glue:}`symp_est`, matching **jax**'s result. This was a simple example that had a (relatively) simple mathematical expression. But imagine if we had lots of branches (e.g. `if`, `else` statements) and/or a really complicated function: **jax**'s `grad` would still work. It's designed for really complex derivatives of the kind encountered in machine learning.
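
Here's a minimal sketch of `grad` working through ordinary Python branching, using a made-up piecewise function for illustration:

```python
import jax


def piecewise(x):
    # Ordinary Python branching on the input value
    if x < 3.0:
        return 3.0 * x**2
    else:
        return -4.0 * x


d_piecewise = jax.grad(piecewise)
print(d_piecewise(2.0))  # derivative of 3x^2 at x = 2, i.e. 12
print(d_piecewise(4.0))  # derivative of -4x, i.e. -4
```

This works because `grad` traces the function with concrete input values. Under `jit`, a Python `if` on a traced value raises an error, and you'd typically reach for `jnp.where` or `jax.lax.cond` instead.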

### Just-in-time compilation

The other nice feature of **jax** is the ability to do just-in-time (JIT) compilation. Because they do not compile their code into machine code before running, high-level languages like Python and R are not as fast as the same code written in C++ or Fortran (the benefit is that it takes you less time to write the code in the first place). Much of the time, there are pre-composed functions that call C++ under the hood to do these things, but only for those operations that people have already taken the time to code up in a lower-level language. JIT compilation offers a compromise: you can code more or less as you like in the high-level language, but it will be compiled just-in-time to give you a speed-up!

**jax** is certainly not the only Python package that does this, and if you're not doing anything like differentiating or propagating, **numba** is a more mature alternative. But here we'll see the time difference for JIT compilation on an otherwise slow operation: element-wise multiplication and addition.

```python
from jax import jit

def slow_f(x):
  """Slow, element-wise function"""
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
```

Now let's see how fast the 'slow' version goes:

```python
%timeit -n15 -r3 slow_f(x)
```

```bash
60.1 ms ± 3.67 ms per loop (mean ± std. dev. of 3 runs, 15 loops each)
```

What about with JIT compilation?

```python
%timeit -n15 -r3 fast_f(x)
```

```bash
17.7 ms ± 434 µs per loop (mean ± std. dev. of 3 runs, 15 loops each)
```
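
One caveat when timing **jax**: it dispatches work asynchronously, so a call can return before the computation has finished, and the first call to a jitted function also includes compilation time. The **jax** documentation recommends calling `.block_until_ready()` when benchmarking. Here's a minimal sketch of a fairer timing, with `slow_f` redefined so the block stands alone; the figure you get will depend on your hardware:

```python
import time

import jax
import jax.numpy as jnp


def slow_f(x):
    """Slow, element-wise function"""
    return x * x + x * 2.0


fast_f = jax.jit(slow_f)
x = jnp.ones((5000, 5000))

# The first call triggers compilation, so run it once before timing
fast_f(x).block_until_ready()

start = time.perf_counter()
result = fast_f(x).block_until_ready()  # wait for the computation to finish
elapsed = time.perf_counter() - start
print(f"{elapsed * 1000:.1f} ms")
```

In a notebook, the equivalent would be `%timeit fast_f(x).block_until_ready()`.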

This short introduction has barely scratched the surface of **jax** and what you can do with it. For more, see the [official documentation](https://jax.readthedocs.io/en/latest/).

## Review

In this chapter, you should have:

- ✅ seen how to use symbolic algebra with code, including Lagrangians and linear algebra; and
- ✅ found out about using set theory via the `set` object type and set-oriented functions.