Powerful Data Visualisation with Matplotlib

Introduction

Here you’ll see how to make just about any plot you can possibly imagine using the ferociously powerful imperative graphing package, matplotlib. In this chapter, we’ll focus on explaining the basics. If you read on to the chapter on Narrative Data Visualisation, you’ll see just how flexible it is and how it can produce commercial-quality graphics. Additionally, matplotlib is the foundation stone of quite a few other data visualisation libraries (including plotnine and seaborn), which shows how much it can do. It was famously used to create the first ever image of a black hole by Akiyama et al. [2019].

Note

matplotlib is an incredibly powerful and customisable visualisation package.

It’s worth saying that matplotlib has a very different philosophy from, for example, lets-plot. Apart from being imperative rather than declarative, it also prefers wide (unstacked) data to tidy data. So, to create plots quickly, instead of storing the values for every line you want to plot in a single column, with another column called “country” saying which line each value belongs to (as in tidy data), it prefers each country to have its own column.

For a more in-depth introduction to matplotlib, head over to this tutorial. This chapter is indebted to that tutorial, to the excellent matplotlib documentation, and to the book Scientific Visualization: Python + Matplotlib [Rougier, 2021].

As ever, we’ll start by importing some key packages and initialising any settings:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Set seed for random numbers
seed_for_prng = 78557
prng = np.random.default_rng(
    seed_for_prng
)  # prng = pseudo-random number generator

Understanding Matplotlib

You’ll see we imported matplotlib.pyplot as plt above; this is the main part of matplotlib that we’ll use in practice.

The matplotlib API

matplotlib [Hunter, 2007] has its origin in an attempt to replicate the plotting functionality of the commercial programming language MATLAB, but it has since outgrown these roots. Because matplotlib has been around a long time and there are different ways to use it, much of the material online about matplotlib uses the old, MATLAB-inspired approach, which is not what this book recommends; instead, we strongly recommend the modern object-oriented API. If you see code that calls plotting functions directly on plt, such as plt.scatter, then it’s using the legacy API rather than the now recommended “object-oriented API”, which is focused around axes, eg ax.scatter and similar. You should follow examples that use the object-oriented approach.

Everything you see below will use the ‘object-oriented API’. This means that we create objects, like figures and axes, that have state (they remember what you did to them).
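To make the difference between the two approaches concrete, here is a minimal sketch (the data are arbitrary) that draws the same line chart twice. The first version uses the legacy pyplot functions, which act on an implicit “current” figure. The second uses the object-oriented API, where fig and ax are ordinary Python objects that remember what has been done to them.

```python
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [1, 4, 2, 3]

# Legacy pyplot style: plt.* functions act on an implicit "current" figure/axes
plt.figure()
plt.plot(x, y)
plt.title("pyplot interface")

# Object-oriented style: keep explicit references and call methods on them
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_title("Object-oriented interface")

# The Axes object has state: it remembers its title and the lines added to it
print(ax.get_title())       # Object-oriented interface
print(len(ax.get_lines()))  # 1
```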

Also worth saying: the matplotlib API (application programming interface) is huge, so there’s no way we’ll be able to cover everything it can do!

The object-oriented API is most often used by creating two fundamental objects that almost every chart needs: the figure and the axes. You should think of the figure object, fig, as the canvas on which you can put any number of charts. Each ax (short for ‘axes’) object is one chart within a figure. Of course, most of the time you’re likely to have only one set of axes per figure, but in the cases when you don’t, it’s a really useful setup. The plotting of elements such as lines, points, bars, and so on is controlled by the ax objects, while the overall settings are controlled by fig.

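As a sketch of the figure/axes split described above (the data and titles are arbitrary placeholders), the snippet below asks plt.subplots() for one figure containing two Axes side by side. Chart contents are added through each ax, and a figure-wide title through fig.

```python
import matplotlib.pyplot as plt

# One figure (the canvas) holding two Axes (two separate charts)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# Plotted elements are added via the individual Axes objects...
ax_left.plot([1, 2, 3], [3, 1, 2])
ax_left.set_title("Left chart")
ax_right.bar(["a", "b", "c"], [5, 3, 4])
ax_right.set_title("Right chart")

# ...while figure-wide settings are made on the Figure object
fig.suptitle("One figure, two Axes")

print(len(fig.axes))  # 2
```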
The simplest chart we can think of is a line plot. Let’s see how to do that:

fig, ax = plt.subplots()  # Get the figure and one axis as a subplot
ax.plot(
    [1, 2, 3, 4, 5, 6],  # Add some data on the x and y axes
    [1, 4, 2, 3, 1, 7],
)
[<matplotlib.lines.Line2D at 0x13f07a1d0>]

Notice that matplotlib will happily accept raw data such as lists (but it will work with data frames too).

Tip

In a notebook, the value returned by the last command in a cell is displayed, which is why you see output such as [<matplotlib.lines.Line2D at... above. To suppress this, end the last command with a semicolon, ;, or call plt.show() as the last command.

Let’s see an example of a scatter plot using this object-oriented approach. Note that we begin in the same way, with getting a figure and an axis. But now we’re going to use ax.scatter instead, plus throw in a couple of extra settings (can you guess what they do?).

fig, ax = plt.subplots()  # Create a figure containing a single axes.
ax.scatter(
    [1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7], s=150, c="b"  # Plot some data on the axes.
);

s=150 sets the area of the points (ie the size of each marker, measured in points squared), and c='b' sets the colour to blue. Many of these settings will accept an array of values (like s=[1, 2, 3, 4, 5, 6]) instead of a single value and will map them onto the points in the way you’d expect.

Let’s see that in practice. For the sizes, we’ll use six evenly spaced values between 300 and 2000 (via np.linspace). For the colours, we’ll just type out six distinct values.

fig, ax = plt.subplots()
ax.scatter(
    [1, 2, 3, 4, 5, 6],
    [1, 4, 2, 3, 1, 7],
    s=np.linspace(300, 2000, 6),
    c=["b", "r", "g", "k", "cyan", "yellow"],
    edgecolors="k",
    alpha=0.5,
);

We also asked for partly transparent points via the alpha setting (without it, the points are fully opaque), and for a black line-edge colour (edgecolors='k').

As ever, you can call help on a specific command in order to understand what options (keyword arguments) it accepts. And, in Visual Studio Code, you can just hover over the function or method name.

Here is what help returns when you run it on the scatter() method:

help(ax.scatter)
Help on method scatter in module matplotlib.axes._axes:

scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs) method of matplotlib.axes._axes.Axes instance
    A scatter plot of *y* vs. *x* with varying marker size and/or color.
    
    Parameters
    ----------
    x, y : float or array-like, shape (n, )
        The data positions.
    
    s : float or array-like, shape (n, ), optional
        The marker size in points**2 (typographic points are 1/72 in.).
        Default is ``rcParams['lines.markersize'] ** 2``.
    
        The linewidth and edgecolor can visually interact with the marker
        size, and can lead to artifacts if the marker size is smaller than
        the linewidth.
    
        If the linewidth is greater than 0 and the edgecolor is anything
        but *'none'*, then the effective size of the marker will be
        increased by half the linewidth because the stroke will be centered
        on the edge of the shape.
    
        To eliminate the marker edge either set *linewidth=0* or
        *edgecolor='none'*.
    
    c : array-like or list of colors or color, optional
        The marker colors. Possible values:
    
        - A scalar or sequence of n numbers to be mapped to colors using
          *cmap* and *norm*.
        - A 2D array in which the rows are RGB or RGBA.
        - A sequence of colors of length n.
        - A single color format string.
    
        Note that *c* should not be a single numeric RGB or RGBA sequence
        because that is indistinguishable from an array of values to be
        colormapped. If you want to specify the same RGB or RGBA value for
        all points, use a 2D array with a single row.  Otherwise,
        value-matching will have precedence in case of a size matching with
        *x* and *y*.
    
        If you wish to specify a single color for all points
        prefer the *color* keyword argument.
    
        Defaults to `None`. In that case the marker color is determined
        by the value of *color*, *facecolor* or *facecolors*. In case
        those are not specified or `None`, the marker color is determined
        by the next color of the ``Axes``' current "shape and fill" color
        cycle. This cycle defaults to :rc:`axes.prop_cycle`.
    
    marker : `~.markers.MarkerStyle`, default: :rc:`scatter.marker`
        The marker style. *marker* can be either an instance of the class
        or the text shorthand for a particular marker.
        See :mod:`matplotlib.markers` for more information about marker
        styles.
    
    cmap : str or `~matplotlib.colors.Colormap`, default: :rc:`image.cmap`
        The Colormap instance or registered colormap name used to map scalar data
        to colors.
    
        This parameter is ignored if *c* is RGB(A).
    
    norm : str or `~matplotlib.colors.Normalize`, optional
        The normalization method used to scale scalar data to the [0, 1] range
        before mapping to colors using *cmap*. By default, a linear scaling is
        used, mapping the lowest value to 0 and the highest to 1.
    
        If given, this can be one of the following:
    
        - An instance of `.Normalize` or one of its subclasses
          (see :ref:`colormapnorms`).
        - A scale name, i.e. one of "linear", "log", "symlog", "logit", etc.  For a
          list of available scales, call `matplotlib.scale.get_scale_names()`.
          In that case, a suitable `.Normalize` subclass is dynamically generated
          and instantiated.
    
        This parameter is ignored if *c* is RGB(A).
    
    vmin, vmax : float, optional
        When using scalar data and no explicit *norm*, *vmin* and *vmax* define
        the data range that the colormap covers. By default, the colormap covers
        the complete value range of the supplied data. It is an error to use
        *vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm*
        name together with *vmin*/*vmax* is acceptable).
    
        This parameter is ignored if *c* is RGB(A).
    
    alpha : float, default: None
        The alpha blending value, between 0 (transparent) and 1 (opaque).
    
    linewidths : float or array-like, default: :rc:`lines.linewidth`
        The linewidth of the marker edges. Note: The default *edgecolors*
        is 'face'. You may want to change this as well.
    
    edgecolors : {'face', 'none', *None*} or color or sequence of color, default: :rc:`scatter.edgecolors`
        The edge color of the marker. Possible values:
    
        - 'face': The edge color will always be the same as the face color.
        - 'none': No patch boundary will be drawn.
        - A color or sequence of colors.
    
        For non-filled markers, *edgecolors* is ignored. Instead, the color
        is determined like with 'face', i.e. from *c*, *colors*, or
        *facecolors*.
    
    plotnonfinite : bool, default: False
        Whether to plot points with nonfinite *c* (i.e. ``inf``, ``-inf``
        or ``nan``). If ``True`` the points are drawn with the *bad*
        colormap color (see `.Colormap.set_bad`).
    
    Returns
    -------
    `~matplotlib.collections.PathCollection`
    
    Other Parameters
    ----------------
    data : indexable object, optional
        If given, the following parameters also accept a string ``s``, which is
        interpreted as ``data[s]`` (unless this raises an exception):
    
        *x*, *y*, *s*, *linewidths*, *edgecolors*, *c*, *facecolor*, *facecolors*, *color*
    **kwargs : `~matplotlib.collections.Collection` properties
    
    See Also
    --------
    plot : To plot scatter plots when markers are identical in size and
        color.
    
    Notes
    -----
    * The `.plot` function will be faster for scatterplots where markers
      don't vary in size or color.
    
    * Any or all of *x*, *y*, *s*, and *c* may be masked arrays, in which
      case all masks will be combined and only unmasked points will be
      plotted.
    
    * Fundamentally, scatter works with 1D arrays; *x*, *y*, *s*, and *c*
      may be input as N-D arrays, but within scatter they will be
      flattened. The exception is *c*, which will be flattened only if its
      size matches the size of *x* and *y*.

pandas and matplotlib

matplotlib will accept any array-like data, but it’s fair to say that it much prefers data frames in wide format to those in tidy format (fortunately, pandas makes it easy to switch between the two using unstack() or pivot()). pandas also has built-in plotting methods that use matplotlib under the hood; these can be accessed via DataFrame.plot.*, for example df.plot.line(), or simply via df.plot().

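Because pandas’ plotting methods are built on matplotlib, they also fit into the object-oriented approach. DataFrame.plot() accepts an ax keyword argument saying which Axes to draw on, and it returns that Axes so you can carry on customising it. The sketch below uses a small, randomly generated wide data frame with hypothetical columns A, B and C.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

prng = np.random.default_rng(78557)  # same seed as used in this chapter

# Wide-format data: one column per series
wide = pd.DataFrame(
    prng.standard_normal(size=(100, 3)).cumsum(axis=0),
    columns=["A", "B", "C"],
)

# Create the figure and axes yourself, then ask pandas to draw on them
fig, ax = plt.subplots()
returned_ax = wide.plot(ax=ax)  # one line per column
ax.set_ylabel("Cumulative sum")

print(returned_ax is ax)    # True
print(len(ax.get_lines()))  # 3
```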
Let’s see an example of this. First we generate some data:

num_samples = 1000
df = pd.DataFrame(
    prng.standard_normal(size=(num_samples, 4)),
    index=pd.date_range("1/1/2000", periods=num_samples),
    columns=list("ABCD"),
)
df = df.cumsum()
df.head()
                   A         B         C         D
2000-01-01  0.927389 -0.059754  1.043736  1.505619
2000-01-02  1.035016 -1.920517  3.138396  3.206534
2000-01-03 -0.386094 -0.744485  3.508994  1.914596
2000-01-04  0.977921 -0.272311  3.999495  3.294564
2000-01-05  0.205909  1.476947  2.898448  4.170505

Now, what happens when we use df.plot():

df.plot();

This works pretty much as expected. But what if we had tidy data?

tidy_df = df.reset_index().melt(id_vars="index")
tidy_df.head()
       index variable     value
0 2000-01-01        A  0.927389
1 2000-01-02        A  1.035016
2 2000-01-03        A -0.386094
3 2000-01-04        A  0.977921
4 2000-01-05        A  0.205909

Let’s just throw this into .plot():

tidy_df.plot();

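Uh oh! This isn’t what we wanted at all! It’s combined all of the series into one: the values for A, B, C, and D have been plotted one after another, as if they were a single long series. If you have this issue with tidy data, you can switch to wide format like this: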

tidy_df.pivot(index="index", columns="variable", values="value")
variable           A         B          C          D
index
2000-01-01  0.927389 -0.059754   1.043736   1.505619
2000-01-02  1.035016 -1.920517   3.138396   3.206534
2000-01-03 -0.386094 -0.744485   3.508994   1.914596
...              ...       ...        ...        ...
2002-09-24 49.439815  7.261320 -14.740817 -47.441272
2002-09-25 49.765982  7.837380 -14.292353 -48.501030
2002-09-26 50.468400  8.311609 -13.115471 -49.356256

1000 rows × 4 columns

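Once the data are back in wide format, plotting behaves as it did for df. The sketch below rebuilds the same random data as above (same seed and shape), goes from wide to tidy and back again, and plots the result on an explicit Axes. The name wide_df is introduced here for illustration.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

prng = np.random.default_rng(78557)
num_samples = 1000
df = pd.DataFrame(
    prng.standard_normal(size=(num_samples, 4)),
    index=pd.date_range("1/1/2000", periods=num_samples),
    columns=list("ABCD"),
).cumsum()

# Tidy (long) format: one row per (date, series) observation
tidy_df = df.reset_index().melt(id_vars="index")

# Back to wide format (one column per series), then plot on an explicit Axes
wide_df = tidy_df.pivot(index="index", columns="variable", values="value")
fig, ax = plt.subplots()
wide_df.plot(ax=ax)  # four separate lines, one per column

print(wide_df.shape)  # (1000, 4)
```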
Summary

There are a huge number of options for what to put on axes; the table below gives a guide to the most essential ones for 2D plots.

Code                  What it does
Axes.plot             Plot y versus x as lines and/or markers.
Axes.errorbar         Plot y versus x as lines and/or markers with attached errorbars.
Axes.scatter          A scatter plot of y vs x.
Axes.step             Make a step plot.
Axes.loglog           Make a plot with log scaling on both the x and y axis.
Axes.semilogx         Make a plot with log scaling on the x axis.
Axes.semilogy         Make a plot with log scaling on the y axis.
Axes.fill_between     Fill the area between two horizontal curves.
Axes.fill_betweenx    Fill the area between two vertical curves.
Axes.bar              Make a bar plot.
Axes.barh             Make a horizontal bar plot.
Axes.bar_label        Label a bar plot.
Axes.stem             Create a stem plot.
Axes.eventplot        Plot identical parallel lines at the given positions.
Axes.pie              Plot a pie chart.
Axes.stackplot        Draw a stacked area plot.
Axes.broken_barh      Plot a horizontal sequence of rectangles.
Axes.vlines           Plot vertical lines at each x from ymin to ymax.
Axes.hlines           Plot horizontal lines at each y from xmin to xmax.
Axes.axhline          Add a horizontal line across the Axes.
Axes.axhspan          Add a horizontal span (rectangle) across the Axes.
Axes.axvline          Add a vertical line across the Axes.
Axes.axvspan          Add a vertical span (rectangle) across the Axes.
Axes.axline           Add an infinitely long straight line.
Axes.fill             Plot filled polygons.
Axes.boxplot          Draw a box and whisker plot.
Axes.violinplot       Make a violin plot.
Axes.violin           Drawing function for violin plots.
Axes.bxp              Drawing function for box and whisker plots.
Axes.hexbin           Make a 2D hexagonal binning plot of points x, y.
Axes.hist             Compute and plot a histogram.
Axes.hist2d           Make a 2D histogram plot.
Axes.stairs           A stepwise constant function as a line with bounding edges or a filled plot.
Axes.contour          Plot contour lines.
Axes.contourf         Plot filled contours.

Now, it might not seem like it, but you already have the makings of a wide range of matplotlib plots. You just need to follow this recipe:

  1. fig, ax = plt.subplots()

  2. Choose what you want to put on your axes, for example ax.scatter or ax.plot

  3. Put data into the method you chose in step 2 in the required format (remember that you can check the documentation to get the right format, either by using help, by hovering over the method name in Visual Studio Code, or by heading to the matplotlib documentation, where there are often also examples to look at, such as the examples for hex bins).

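Here is the recipe applied to a hexagonal binning plot. This is a minimal sketch using randomly generated, correlated x and y values (the variable names and gridsize=30 are illustrative choices). hexbin() wants two 1D arrays of the same length, and fig.colorbar() adds a key for the counts.

```python
import numpy as np
import matplotlib.pyplot as plt

prng = np.random.default_rng(78557)
x = prng.standard_normal(5_000)
y = x + prng.standard_normal(5_000)  # y is correlated with x

# Step 1: get a figure and an axes
fig, ax = plt.subplots()
# Step 2: choose a method on the axes (here, a hexagonal binning plot)
# Step 3: pass the data in the format it expects (two 1D arrays)
hb = ax.hexbin(x, y, gridsize=30, cmap="viridis")
fig.colorbar(hb, ax=ax, label="Count in bin")
ax.set_xlabel("x")
ax.set_ylabel("y")
```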
Anatomy of a matplotlib graph

So, you’ve made a nice plot and now you want to tweak it. Well, matplotlib certainly has a LOT of choice when it comes to tweaking! Rather than go through all of the many, many options for customisation, the figure below (from the documentation) gives an overview of the options:

Anatomy of a matplotlib figure

Let’s run through a few of the most important plot elements:

  • Artists (not explicitly labelled above): There are two types of Artists: primitives and containers. The primitives represent the standard graphical objects we want to paint onto our canvas: Line2D, Rectangle, Text, AxesImage, etc., and the containers are places to put them (Axis, Axes and Figure).

  • Figure, or ‘fig’: the figure keeps track of all the child Axes that are on it and a smattering of ‘special’ artists (titles, figure legends, etc). A figure can contain any number of Axes, but will typically have at least one if it is to be interesting!

  • Axes, or ‘ax’: this is the plot, the region of the image that traces out the data. A given Figure can contain many Axes, but a given Axes object can only be in one Figure. The Axes contains two (or three in the case of 3D) Axis objects (Axes and Axis are different things!) that record the data limits (you can override these via the axes.Axes.set_xlim() and axes.Axes.set_ylim() methods). Each Axes object has a title (set via set_title()), an x-label (set via set_xlabel()), and a y-label (set via set_ylabel()). When you add, say, a line chart to an Axes object, it is created by calling a method on that Axes object and appears as a Line2D object associated with that Axes.

  • Axis: these are the number-line objects that control the limits of what the viewer can see. They also provide the means to access the ticks (the marks on the axis) and ticklabels (strings labeling the ticks). The location of the ticks is determined by a Locator object and the ticklabel strings are formatted by a Formatter. The combination of Locator and Formatter gives very fine control over the tick locations and labels.

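  • Markers: the symbols drawn at individual data points, such as those produced by scatter plots (line plots can show them too, via the marker keyword argument).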

  • Labels and titles: text to help the viewers of the chart make sense of what they’re seeing.

  • Legend: if needed, contains the key to understand the shapes or colours used in lines or markers or bars (or …) that are on the chart.
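You can also explore this hierarchy in code. The sketch below (arbitrary data) creates a figure and checks how the pieces described above relate to one another: the Figure holds the Axes, the Axes holds two Axis objects and a title Text artist, and the Line2D primitive knows which Axes it belongs to.

```python
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
(line,) = ax.plot([0, 1, 2], [0, 1, 4])  # a primitive Artist (Line2D)
ax.set_title("A title")
ax.set_xlim(0, 2)  # override the automatic x-axis data limits

# Containers: the Figure holds Axes; each Axes holds two Axis objects
print(fig.axes[0] is ax)            # True
print(isinstance(ax.xaxis, XAxis))  # True
print(isinstance(ax.yaxis, YAxis))  # True
print(ax.title.get_text())          # A title

# Primitives: the line lives on the Axes that created it
print(isinstance(line, Line2D))     # True
print(line.axes is ax)              # True
```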

Customising Charts

While the matplotlib API is too extensive for us to cover every detail, it’s important to know about some standard customisations that you might need to get going.

Limits and Labels

So far, we’ve only customised the plot through the keyword arguments of scatter(); let’s now see an example that’s a bit more realistic (and useful!), in which we’ll add labels, a title, and more. Note that matplotlib supports LaTeX-style equations in text.

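As a brief illustration of equations in text: by default, matplotlib uses its built-in mathtext engine, which understands a subset of TeX written between dollar signs and doesn’t need a LaTeX installation. Full LaTeX rendering is available via the text.usetex setting, but it requires TeX to be installed. The function plotted here is just an example.

```python
import numpy as np
import matplotlib.pyplot as plt

theta = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots()
ax.plot(theta, np.sin(theta) ** 2)
# Raw strings (r"...") stop Python treating backslashes as escape characters;
# text between $ signs is rendered as maths by matplotlib's mathtext engine
ax.set_xlabel(r"$\theta$ (radians)")
ax.set_ylabel(r"$\sin^2(\theta)$")
ax.set_title(r"$y = \sin^2(\theta)$, for $0 \leq \theta \leq 2\pi$")
```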
We’ll plot some data from the US Midwest demographics dataset.

df = pd.read_csv(
    "https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/midwest.csv",
    index_col="PID",
)
df.head()
    rownames    county state  ... percelderlypoverty inmetro category
PID
561        1     ADAMS    IL  ...          12.443812       0      AAR
562        2 ALEXANDER    IL  ...          25.228976       0      LHR
563        3      BOND    IL  ...          12.697410       0      AAR
564        4     BOONE    IL  ...           6.217047       1      ALU
565        5     BROWN    IL  ...          19.200000       0      AAR

5 rows × 28 columns

Now, as well as a scatter, let’s add some context to the chart:

fig, ax = plt.subplots()
ax.scatter(
    df["area"], df["poptotal"], edgecolors="k", alpha=0.6
)  # Make a scatter plot on "ax"
ax.set_xlim(0, 0.1)  # Set the limits on the x-axis
ax.set_ylim(0, 1e6)  # Set the limits on the y-axis
ax.set_xlabel("Area")  # Set the x label
ax.set_ylabel("Population")  # Set the y label
ax.set_title(
    "Area vs. Population", loc="right"
);  # Add a title and say where it should go

We’re not quite done with titles. It’s often useful to have a y-axis title that is horizontal, and so easier to read. It’s also good practice to give the chart a title that tells the viewer what they should take away from it. To achieve both of these together, you can i) use fig.suptitle() for a figure-level title, whose position can be fine-tuned using the x and y keyword arguments; and ii) use ax.set_title() to provide the y-axis label. The chart below demonstrates this.

fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"] / 1e6, edgecolors="k", alpha=0.6, s=150)
ax.set_ylim(0, None)
ax.set_xlim(0, None)
ax.set_xlabel("Area")
ax.set_title("Population (millions)", loc="left", fontsize=14)
fig.suptitle("Little correlation between population and area", y=1.02, x=0.45);

Customising Axes

Each Axes has two (or three) Axis objects representing the x- and y-axis. These control the scale of the Axis, the tick locators and the tick formatters. Additional Axes can be attached to display further Axis objects.
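One common way to attach an additional Axes is Axes.twinx(), which creates a second Axes that shares the x-axis but has its own y Axis on the right-hand side. This is a minimal sketch with arbitrary functions of x; whether a dual-axis chart suits your data is a separate question.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 11)

fig, ax = plt.subplots()
ax.plot(x, x**2, color="tab:blue")
ax.set_xlabel("x")
ax.set_ylabel("x squared", color="tab:blue")

# twinx() adds a second Axes sharing the x-axis with ax, with its own y Axis
ax2 = ax.twinx()
ax2.plot(x, np.sqrt(x), color="tab:orange")
ax2.set_ylabel("square root of x", color="tab:orange")

print(len(fig.axes))  # 2
```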

First, we’re going to make some random data that will help explain some of the concepts in the rest of the chapter.

data1, data2 = prng.standard_normal(size=(2, 100))  # make 2 random data sets
xdata = np.arange(len(data1))  # make an ordinal for this

Tick Locators and Formatters

Each Axis has a tick locator, which chooses where along the Axis to put tick marks, and a tick formatter, which chooses how to label them. A simple interface to both is Axes.set_xticks():

fig, axs = plt.subplots(2, 1, layout="constrained")
axs[0].plot(xdata, data1)
axs[0].set_title("Automatic ticks")

axs[1].plot(xdata, data1)
axs[1].set_xticks(np.arange(0, 100, 30), ["zero", "30", "sixty", "90"])
axs[1].set_yticks([-1.5, 0, 1.5])  # note that we don't need to specify labels
axs[1].set_title("Manual ticks");

See the matplotlib documentation on Tick locators and Tick formatters for other formatters and locators and information for writing your own.
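As a small sketch of setting a locator and a formatter directly (the tick spacing and the “day” labels are arbitrary choices for illustration), matplotlib.ticker provides MultipleLocator, which places ticks at multiples of a base value, and FuncFormatter, which wraps any function of (value, position) that returns a label.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MultipleLocator

prng = np.random.default_rng(78557)
data1 = prng.standard_normal(100)
xdata = np.arange(len(data1))

fig, ax = plt.subplots()
ax.plot(xdata, data1)

# Locators: a major tick every 25 units and a minor tick every 5 units
ax.xaxis.set_major_locator(MultipleLocator(25))
ax.xaxis.set_minor_locator(MultipleLocator(5))

# Formatter: a function of (tick value, tick position) returning the label
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"day {x:.0f}"))

print(ax.xaxis.get_major_formatter()(50, 0))  # day 50
```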

Returning to our Midwest examples, we had the makings of a basic chart, with axes labels and even a title.

But what we know from the data is that “Area” is actually a percentage. We could represent this by doing

ax.set_xlabel("Area, %")

but as matplotlib is extremely customisable, there is another option too: changing the tick labels.

On the x-axis, we’ll add a percentage suffix on the numbers plus some minor tick marks.

from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="right")
ax.xaxis.set_minor_locator(
    AutoMinorLocator(4)
)  # Split each major interval into 4, ie add 3 minor ticks between major ones
ax.xaxis.set_major_formatter(
    "{x:.2f}%"
);  # Every x value has 2 decimal places and is followed by a '%' sign

AutoMinorLocator(4) divides the interval between each pair of major tick marks into 4, so it adds 3 minor tick marks between them. The major formatter is '{x:.2f}%', which says print 2 decimal places followed by a % sign.

Scales

In addition to the linear scale, matplotlib supplies non-linear scales, such as a log scale. Since log scales are used so much, there are also direct methods like Axes.loglog, Axes.semilogx, and Axes.semilogy. There are a number of other scales too (see the matplotlib scales docs for examples). Here we set the scale manually:

fig, axs = plt.subplots(1, 2, layout="constrained")
xdata = np.arange(len(data1))  # make an ordinal for this
data = 10**data1
axs[0].plot(xdata, data, lw=1)
axs[0].set_ylim(0, None)
axs[1].set_yscale("log")
axs[1].plot(xdata, data, lw=1);

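The shortcut methods and the explicit setter produce the same scale. In this minimal sketch (y = x³ is just an example of fast-growing data), one panel uses set_yscale("log") and the other uses Axes.semilogy().

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.arange(1, 101)
y = x**3

fig, axs = plt.subplots(1, 2, layout="constrained")

# Setting the scale explicitly on the y Axis...
axs[0].plot(x, y)
axs[0].set_yscale("log")

# ...gives the same scale as the semilogy() shortcut
axs[1].semilogy(x, y)

print(axs[0].get_yscale(), axs[1].get_yscale())  # log log
```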
Using Loops to Achieve Customisation Effects

Let’s say we now want to colour these points according to which state they belong to, and add a legend that says which state has which colour. The easiest way to do this is with a for loop.

We’re going to loop over states using for state in df["state"].unique():, which runs over each state once. The colours come from the qualitative colourmap "Dark2": before the loop, ax.set_prop_cycle(color=plt.get_cmap("Dark2").colors) tells the axes to cycle through that colourmap’s colours, so each call to ax.scatter() in the loop picks up the next colour. (Passing cmap="Dark2" to ax.scatter() would not do this: matplotlib only uses a colourmap when numbers are passed via c, and otherwise ignores it with a warning.)

Now that colour captures a new dimension of the data (state), we also need a legend that shows which colour corresponds to each state. By passing label=state within the for loop, we build up state-to-colour correspondences that appear on the chart when we call ax.legend().

fig, ax = plt.subplots()
ax.set_prop_cycle(color=plt.get_cmap("Dark2").colors)  # Qualitative data, so use a qualitative colour map
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter("{x:.2f}%")
ax.legend(title="State", loc="upper right");

We used a colourmap to get 5 qualitatively different colours; there are also sequential colormaps for continuous (as opposed to discrete) variables. You can find out more about the colormaps available in base matplotlib here.
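For a continuous variable, a colourmap does its job only when you pass numbers via c: matplotlib then maps each value to a colour, and fig.colorbar() adds a key. In this sketch, the points are random and distance (from the centre of the unit square) is simply a derived quantity to colour by; viridis is one of matplotlib’s sequential colourmaps.

```python
import numpy as np
import matplotlib.pyplot as plt

prng = np.random.default_rng(78557)
x = prng.uniform(0, 1, 200)
y = prng.uniform(0, 1, 200)
distance = np.hypot(x - 0.5, y - 0.5)  # a continuous variable to colour by

fig, ax = plt.subplots()
# Numeric values passed via c are mapped to colours using cmap
sc = ax.scatter(x, y, c=distance, cmap="viridis", edgecolors="k")
fig.colorbar(sc, ax=ax, label="Distance from centre")
```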

Adding Text Annotations

It’s possible to point out specific values with a text label. Adding extra information like this can be useful in all kinds of circumstances; for example, showing the biggest or smallest, drawing attention to a particular story, or simply flagging a special value.

Let’s add a couple of text annotations. Let’s say we want to annotate the county with the biggest area, and the county with the highest population. For the biggest area, we’ll just pop the label next to the point. First we need to find the position of the point:

max_area_row = df.iloc[df["area"].argmax()]
max_area_row
rownames                    246
county                MARQUETTE
state                        MI
                        ...    
percelderlypoverty    12.523891
inmetro                       0
category                    HAR
Name: 1248, Length: 28, dtype: object

Now we use ax.annotate() to add this information.

fig, ax = plt.subplots()
ax.set_prop_cycle(color=plt.get_cmap("Dark2").colors)  # Qualitative data, so use a qualitative colour map
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
);

What if we want to put text somewhere other than right next to the datapoint? We can do that and have an arrow to connect the label to the data point.

max_pop_row = df.iloc[df["poptotal"].argmax()]

fig, ax = plt.subplots()
ax.set_prop_cycle(color=plt.get_cmap("Dark2").colors)  # Qualitative data, so use a qualitative colour map
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
    xytext=(-100, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.annotate(
    text=f'Max. pop: {max_pop_row["county"].title()}',
    xy=tuple(max_pop_row[["area", "poptotal"]]),
    xytext=(-100, -50),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
);

You can get really creative with text annotations, as the image below, taken from Scientific Visualization: Python + Matplotlib [Rougier, 2021], shows. You can learn more about text annotations here.

[Figure: examples of text annotations, from Rougier (2021)]

Special Lines

Next, we’re going to add some special lines to our chart. These are surprisingly useful in practice, and there are a few commands for drawing them. If we just want a horizontal or vertical line, our best bets are ax.axhline() and ax.axvline(), respectively.

Let’s add a special line to our example to show where the mean area of counties appears. First we need the mean area:

mean_county_area = df["area"].mean()
mean_county_area
0.03316933638443936

Now we’re going to add the line, using ax.axvline, but also a corresponding annotation that tells viewers what this line is showing.

fig, ax = plt.subplots()
for state in df["state"].unique():
    xf = df.loc[df["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        # No cmap here: without a 'c' argument a cmap is ignored (with a
        # warning); instead each call takes the next colour from the default
        # colour cycle, so every state gets its own colour
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
    text=f'Max. area: {max_area_row.loc["county"].title()}',
    xy=tuple(max_area_row[["area", "poptotal"]]),
    xytext=(-100, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.annotate(
    text=f'Max. pop: {max_pop_row["county"].title()}',
    xy=tuple(max_pop_row[["area", "poptotal"]]),
    xytext=(20, -50),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.axvline(x=mean_county_area, linewidth=0.5, linestyle="-.")
ax.annotate(
    "Mean county area",
    xy=(mean_county_area, 0.5),
    xycoords=("data", "axes fraction"),  # x in data units, y as fraction of axes
    rotation=-90,
    fontsize=11,
);

Note that the second co-ordinate of the text annotation is given as a fraction of the axes’ height rather than in the co-ordinates of the data. We did this by passing xycoords=("data", "axes fraction"), which uses data co-ordinates for x and axes-fraction co-ordinates for y. This is useful because now, whatever else changes in the chart, we know the text will appear half-way up the axis. There are several different co-ordinate systems that matplotlib accepts depending on what you’re trying to achieve:

  • the co-ordinates of the data, eg area and population in the chart above

  • fractions of the axes, aka the “axes fraction”, where (0, 0) is the bottom-left corner of the axes and (1, 1) the top-right

  • offset points, relative to another point (used for text relative to a data point)

  • various options using pixels

  • and more, including polar co-ordinates

You can find out a bit more about the different co-ordinate systems in the matplotlib documentation.
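Here is a sketch that places one annotation in each of the main co-ordinate systems listed above, on a toy scatter plot. The points and labels are invented. It includes the blended ("data", "axes fraction") form used for the mean-area label.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.scatter([1, 2, 3], [10, 20, 15])
ax.set_xlim(0, 4)
ax.set_ylim(0, 25)

# Data co-ordinates: text anchored at x=2, y=20 in data units
ann_data = ax.annotate("data", xy=(2, 20), xycoords="data")
# Axes fraction: near the top-left corner, whatever the axis limits are
ann_axes = ax.annotate(
    "axes fraction", xy=(0.05, 0.95), xycoords="axes fraction", va="top"
)
# Offset points: arrow tip on a data point, text 30 points right and 20 up
ann_offset = ax.annotate(
    "offset points",
    xy=(3, 15),
    xycoords="data",
    xytext=(30, 20),
    textcoords="offset points",
    arrowprops=dict(arrowstyle="->"),
)
# Blended: x in data units, y as a fraction of the axes height
ann_blended = ax.annotate(
    "blended", xy=(1, 0.5), xycoords=("data", "axes fraction"), rotation=90
)
```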

Styling Artists#

Most plotting methods have styling options for the Artists, accessible either when a plotting method is called, or from a “setter” on the Artist. In the plot below we manually set the color, linewidth, and linestyle of the Artists created by Axes.plot(), and we set the linestyle of the second line after the fact with Line2D.set_linestyle().

fig, ax = plt.subplots()
x = np.arange(len(data1))
ax.plot(x, np.cumsum(data1), color="blue", linewidth=3, linestyle="--")
(l,) = ax.plot(x, np.cumsum(data2), color="orange", linewidth=2)
l.set_linestyle(":");
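Here is a minimal sketch of both routes on a toy line, using only numpy and matplotlib. It styles the line at call time, then changes it afterwards with setters. That includes the generic Artist.set(), which takes several properties at once.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(10)
fig, ax = plt.subplots()
# ax.plot() returns a list of Line2D artists (one per line drawn), so
# unpacking with (line,) pulls out the single line we drew
(line,) = ax.plot(x, x**2, color="orange", linewidth=2)

# Restyle after the fact with the artist's individual setters...
line.set_linestyle(":")
line.set_linewidth(3)
# ...or change several properties in one go with Artist.set()
line.set(color="C0", alpha=0.7)
```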

Multiple Charts on one Figure#

One basic bit of functionality that you might need is to put more than one type of information on a single overall figure. This could mean having multiple subplots or it could be about putting more information on a single set of axes.

Multiple features on one set of axes#

This is the kind of task where matplotlib’s build-what-you-want philosophy starts to win out. To add another feature on an ax that you’ve already created is as simple as calling ax.<method> again. You can add as many features as you like.

In the example below, we’ll call ax.hist() followed by ax.plot() to overlay the theoretical curve for a normal distribution (aka Gaussian) on a histogram of many draws from that distribution made using numpy. The histogram is normalised to be a density.

While we’re at it, let’s add an equation that describes the theoretical curve.

rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)

fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True, label="Data")
ax.plot(
    grid_x,
    1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x**2) / 2),
    linewidth=4,
    label=r"$\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}$",
)
ax.legend(fontsize=14);
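Passing density=True to ax.hist() is what makes the bars comparable with the theoretical curve. It rescales the bar heights so that the histogram’s total area is one, just like a probability density. The sketch below checks this numerically, along with the area under the standard normal pdf used for the overlay. The seed and grid spacing are arbitrary choices.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(78557)
draws = rng.standard_normal(5000)

fig, ax = plt.subplots()
# ax.hist() returns the bar heights, the bin edges and the drawn patches
heights, edges, _ = ax.hist(draws, bins=50, density=True)
# With density=True, sum(height * bin width) is one: an empirical density
area = np.sum(heights * np.diff(edges))

# Standard normal pdf, 1/sqrt(2*pi) * exp(-x^2/2), on a fine grid
grid = np.linspace(-5, 5, 1001)
pdf = np.exp(-(grid**2) / 2) / np.sqrt(2 * np.pi)
# Simple Riemann sum: should also be very close to one
pdf_area = np.sum(pdf) * (grid[1] - grid[0])
```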

Facets#

Another way to put more information on a single figure is to have different facets. To illustrate some of the ideas we’re about to see, it’s going to be useful to have some data. Let’s pull down GDP per capita for a selection of countries.

from pandas_datareader import wb

start_year = 2000
end_year = 2022

df_cc_gdp = wb.download(
    indicator=["NY.GDP.PCAP.KD"],  # GDP per capita in 2015 USD
    country=["US", "GB", "FR", "DE", "IT", "JP"],
    start=start_year,
    end=end_year,
)
df_cc_gdp = df_cc_gdp.reset_index()  # drop country as index (for plots)
df_cc_gdp["year"] = df_cc_gdp["year"].astype("int")  # ensure year is a number
# sort by year, then country (important for next step)
df_cc_gdp = df_cc_gdp.sort_values(["year", "country"])
# index to that country's value in first entry by country
df_cc_gdp["GDP per capita"] = (
    100
    * df_cc_gdp["NY.GDP.PCAP.KD"]
    / df_cc_gdp.groupby("country")["NY.GDP.PCAP.KD"].transform(lambda row: row.iloc[0])
)
df_cc_gdp.head()
country year NY.GDP.PCAP.KD GDP per capita
45 France 2000 33592.466830 100.0
22 Germany 2000 34490.075769 100.0
91 Italy 2000 32350.904367 100.0
114 Japan 2000 31430.631130 100.0
68 United Kingdom 2000 38918.455057 100.0

Subplots#

The first way to put data on multiple charts that are part of the same overall figure is to use plt.subplots(). It accepts nrows= and ncols= arguments that set how many subplots there are. We have six countries, so let’s do two rows of three columns.

We can use a for loop to go over these. However, the axes object that comes back from plt.subplots() is a 2x3 NumPy array. Looping over it directly gives you one row of axes at a time rather than individual axes. Instead, we loop over a “flattened” version of it, which you can create with axes.flatten().
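You can see the shape issue for yourself with the sketch below. It uses an empty 2x3 grid and made-up panel titles.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3)
# axes is a 2-D NumPy array of Axes with shape (rows, columns)
shape = axes.shape
# Iterating over it directly yields rows: 1-D arrays of three Axes each
first_item = next(iter(axes))
# .flatten() returns the six Axes in row-major order (row 0 first)
flat_axes = axes.flatten()
for i, ax in enumerate(flat_axes):
    ax.set_title(f"Panel {i}")
```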

Once we’re in the loop, we subset the data by country for each loop and add it to the chart using ax.plot(). Finally, we’ll use the same limits for every subplot to make the chart easier to read.

Let’s see this with the data we just pulled down.

fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df_cc_gdp["country"].unique()[i]
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();

A nice extra you can do with this sort of chart is to add the other countries in faint grey. That way it’s clear which country is featured, and cross-country comparison is easier.

fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df_cc_gdp["country"].unique()[i]
    # grab the other countries
    other_countries = [x for x in df_cc_gdp["country"].unique() if x != country]
    for other in other_countries:
        o_country_df = df_cc_gdp.loc[df_cc_gdp["country"] == other, :]
        ax.plot(
            o_country_df["year"],
            o_country_df["GDP per capita"],
            lw=1,
            color="k",
            alpha=0.1,
        )
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();

Gridspec#

This is a more complicated way of creating multiple subplots, but it is more flexible too: although it still relies on an underlying grid structure, the units of that grid don’t have to all equate to a single subplot. It’s easier to see than to describe!

Here’s an example that shows what’s possible:

[Figure: an example GridSpec layout in which some subplots span more than one grid cell]

Let’s say we wanted to really feature a single country more than the others: we might want to give it a full height spot. We can do this with gridspec.

Taking our example, let’s “focus” on the UK. We’re going to use the layout above, with the UK occupying the bottom-right corner and covering grid units 1 to 2 in both dimensions (rows 1 to 2 and columns 1 to 2). Meanwhile, positions (0, 0), (0, 1), (1, 0), (2, 0), and (0, 2) will hold the other five countries.

We can produce this pattern using code: first a double list comprehension to pick out those specific values, then reducing it to a single list of co-ordinates (because double list comprehensions lead to nested lists). Here’s the set of positions for the other five countries in code:

import itertools

nested_list = [
    [(i, j) for i in [0, 1, 2] if (i != j or i == 0) and (i + j <= 2)]
    for j in [0, 1, 2]
]
other_country_indices = list(itertools.chain(*nested_list))
other_country_indices
[(0, 0), (1, 0), (2, 0), (0, 1), (0, 2)]

It’s a bit laborious to do this in code, but it’s often good practice to think about how to solve the general problem rather than manually typing out the solution, because the former scales better to other problems. As you code more, you’ll find yourself writing these little helper lines again and again in different contexts. GitHub even has a service called gists to help you keep track of code snippets that may be useful across lots of projects.
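Here is one alternative way to express the same set of cells, sketched with itertools.product(). The names feature_row and feature_col are illustrative; they mark where the featured block starts. Note that this version lists the cells in a different order from the one above, which would change which country lands in which panel.

```python
import itertools

nrows, ncols = 3, 3
# The featured plot starts at this cell and extends to the bottom-right
# corner, i.e. gspec[feature_row:, feature_col:]
feature_row, feature_col = 1, 1

# Every cell outside the featured block, in row-major order
other_cells = [
    (i, j)
    for i, j in itertools.product(range(nrows), range(ncols))
    if i < feature_row or j < feature_col
]
```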

Now let’s convert our grid into some plots!

import matplotlib.gridspec as gridspec

fig = plt.figure(constrained_layout=True)
nrows, ncols = 3, 3
gspec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
# create a list of the five spots we don't want to feature UK on
non_uk_axes = [fig.add_subplot(gspec[i, j]) for i, j in other_country_indices]
uk_ax = fig.add_subplot(gspec[1:, 1:])
all_axes = non_uk_axes + [uk_ax]
# Plot the other countries
o_countries_df = df_cc_gdp.loc[df_cc_gdp["country"] != "United Kingdom", :]
o_country_names = o_countries_df["country"].unique()
for i, ax in enumerate(non_uk_axes):
    # grab the other country's data
    o_country = o_country_names[i]
    o_country_df = o_countries_df.loc[o_countries_df["country"] == o_country, :]
    ax.plot(
        o_country_df["year"],
        o_country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(o_country, loc="center", fontsize=13)
# Now do UK
uk_df = df_cc_gdp.loc[df_cc_gdp["country"] == "United Kingdom", :]
uk_ax.set_title("United Kingdom", loc="center", fontsize=15)
uk_ax.plot(
    uk_df["year"],
    uk_df["GDP per capita"],
    lw=3,
    color=plt.rcParams["axes.prop_cycle"].by_key()["color"][len(o_country_names)],
)
# global settings
for ax in all_axes:
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(100, 150)

plt.show()

You should note, though, that this kind of chart can be quite confusing and—generally—it’s better to have consistent axes and subplot sizes if you’re trying to make comparisons easy.

Using Subplot Mosaic to Specify a Layout#

We told you that the matplotlib API was ferociously powerful. Well, subplot grids and GridSpec aren’t the only ways to build up multiple subplots: you can build a layout of Axes based on ASCII art or nested lists! The method that does this, subplot_mosaic(), is a helper function to build complex GridSpec layouts visually.

Here’s an example of using text to specify a layout:

axd = plt.figure(constrained_layout=True).subplot_mosaic(
    """
    ABD
    CCD
    CC.
    """
)
kw = dict(ha="center", va="center", fontsize=60, color="darkgrey")
for k, ax in axd.items():
    ax.text(0.5, 0.5, k, transform=ax.transAxes, **kw)

Note how we did a little trick with the font settings too. We put them in a keyword argument dictionary called kw and passed them into ax.text() with the ** (“splatty-splat”) operator, which unpacks the dictionary into keyword arguments.

Let’s now see another way to use subplot_mosaic(), this time passing a list of lists to achieve a particular layout (here, a single row in which C and D each span two columns):

ax_dict = plt.figure(constrained_layout=True).subplot_mosaic(
    [["A", "B", "C", "C", "D", "D"]]
)
kw = dict(ha="center", va="center", fontsize=30, color="darkgrey")
for k, ax in ax_dict.items():
    ax.text(0.5, 0.5, k, transform=ax.transAxes, **kw)
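Both forms of subplot_mosaic() follow the same rules. A label repeated across neighbouring cells becomes one Axes spanning them (the repeated cells must form a rectangle), and "." leaves a cell empty. Here’s a small sketch with an arbitrary layout that also checks a span using each Axes’ SubplotSpec.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig = plt.figure(constrained_layout=True)
# "B" spans two columns in the top row; "C" spans two columns in the bottom
# row; "." leaves the bottom-right cell empty
axd = fig.subplot_mosaic(
    [
        ["A", "B", "B"],
        ["C", "C", "."],
    ]
)
kw = dict(ha="center", va="center", fontsize=30, color="darkgrey")
for label, ax in axd.items():
    ax.text(0.5, 0.5, label, transform=ax.transAxes, **kw)

# The columns occupied by "B", as a range object
b_colspan = axd["B"].get_subplotspec().colspan
```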

An Advanced Example: Zoomed Plots#

This is taken from the online book Scientific Visualization in Matplotlib [Rougier, 2021]. It’s quite a complex example, but does show what can be achieved.

from matplotlib.gridspec import GridSpec
from matplotlib.patches import Rectangle
from matplotlib.patches import ConnectionPatch


fig = plt.figure(figsize=(5, 4))

n = 5
gs = GridSpec(n, n + 1)

ax = plt.subplot(
    gs[:n, :n], xlim=[-1, +1], xticks=[], ylim=[-1, +1], yticks=[], aspect=1
)

X = prng.normal(0, 0.35, 1000)
Y = prng.normal(0, 0.35, 1000)
ax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)

I = prng.choice(len(X), size=n, replace=False)
Px, Py = X[I], Y[I]
I = np.argsort(Y[I])[::-1]
Px, Py = Px[I], Py[I]

ax.scatter(Px, Py, edgecolor="black", facecolor="None", linewidth=0.75)

dx, dy = 0.075, 0.075
for i, (x, y) in enumerate(zip(Px, Py)):
    sax = plt.subplot(
        gs[i, n],
        xlim=[x - dx, x + dx],
        xticks=[],
        ylim=[y - dy, y + dy],
        yticks=[],
        aspect=1,
    )
    sax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
    sax.scatter(Px, Py, edgecolor="black", facecolor="None", linewidth=0.75)

    sax.text(
        1.1,
        0.5,
        "Point " + chr(ord("A") + i),
        rotation=90,
        size=8,
        ha="left",
        va="center",
        transform=sax.transAxes,
    )

    rect = Rectangle(
        (x - dx, y - dy),
        2 * dx,
        2 * dy,
        edgecolor="black",
        facecolor="None",
        linestyle="--",
        linewidth=0.75,
    )
    ax.add_patch(rect)

    con = ConnectionPatch(
        xyA=(x, y),
        coordsA=ax.transData,
        xyB=(0, 0.5),
        coordsB=sax.transAxes,
        linestyle="--",
        linewidth=0.75,
        patchA=rect,
        arrowstyle="->",
    )
    fig.add_artist(con)
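The key ingredient in that example is ConnectionPatch. It draws a line between points expressed in two different co-ordinate systems on two different Axes. Here is a minimal sketch of just that mechanism, using two toy panels:

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch

fig, (ax_main, ax_zoom) = plt.subplots(1, 2)
ax_main.plot([0, 1], [0, 1])
ax_zoom.plot([0, 1], [0, 1])
ax_zoom.set_xlim(0.4, 0.6)
ax_zoom.set_ylim(0.4, 0.6)

# Point A is in the data co-ordinates of the left Axes; point B is the
# middle of the left edge of the right Axes, in axes-fraction co-ordinates
con = ConnectionPatch(
    xyA=(0.5, 0.5),
    coordsA=ax_main.transData,
    xyB=(0, 0.5),
    coordsB=ax_zoom.transAxes,
    arrowstyle="->",
)
# The patch crosses Axes boundaries, so it is added to the Figure itself
fig.add_artist(con)
```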

Working with Times and Dates#

Python has really amazing support for “datetimes”, and you can find out more about the general capabilities in Introduction to Time.

It’s also very important to know how to make charts with times and dates in economics, so we’re going to give this one some special attention.

To demonstrate some of the principles, we’re going to work with real data on UK and US GDP growth. First, let’s pull in that data and take a look at it:

from pandas_datareader import wb
from datetime import datetime

ts_start_date = pd.to_datetime("1999-01-01")
ts_end_date = datetime.now()
countries = ["GBR", "USA"]
gdp_const_2015_usd_code = "NY.GDP.MKTP.KD"  # GDP in constant 2015 US$
df = wb.download(
    indicator=gdp_const_2015_usd_code,
    country=countries,
    start=ts_start_date,
    end=ts_end_date,
).reset_index()
df = df.sort_values(by="year")  # this is very important!
# note that we are computing the percent change on a by-country basis!
# Otherwise, you'd mix up the UK and US values
df["growth, %"] = df.groupby("country")[gdp_const_2015_usd_code].transform(
    lambda x: 100 * x.pct_change(1)
)
df = df.reset_index(drop=True)
df.head()
country year NY.GDP.MKTP.KD growth, %
0 United Kingdom 1999 2.197118e+12 NaN
1 United States 1999 1.321548e+13 NaN
2 United States 2000 1.375430e+13 4.077159
3 United Kingdom 2000 2.292006e+12 4.318727
4 United Kingdom 2001 2.351111e+12 2.578741

Now, you can see that this is annual data, and what’s been returned by the World Bank has given us a “year” field that is of type “object”, basically the lowest common denominator data type:

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48 entries, 0 to 47
Data columns (total 4 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   country         48 non-null     object 
 1   year            48 non-null     object 
 2   NY.GDP.MKTP.KD  48 non-null     float64
 3   growth, %       46 non-null     float64
dtypes: float64(2), object(2)
memory usage: 1.6+ KB

We’re going to turn it into a datetime, a special type that helps with dates and times. In the code snippet below, pd.to_datetime() sets each date to the 1st of January of that year. But these values really represent the year that has just ended. So we add an offset, pd.offsets.BYearEnd(), which moves each date forward to the last business day of that year. That’s why 2000 shows up as 29 December: 31 December 2000 was a Sunday.

df["year"] = pd.to_datetime(df["year"], format="%Y") + pd.offsets.BYearEnd()
df.head()
country year NY.GDP.MKTP.KD growth, %
0 United Kingdom 1999-12-31 2.197118e+12 NaN
1 United States 1999-12-31 1.321548e+13 NaN
2 United States 2000-12-29 1.375430e+13 4.077159
3 United Kingdom 2000-12-29 2.292006e+12 4.318727
4 United Kingdom 2001-12-31 2.351111e+12 2.578741
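Here’s a small sketch of what the offset is doing. It compares pd.offsets.YearEnd(), which moves to 31 December, with pd.offsets.BYearEnd(), which moves to the last weekday of the year. These three years were chosen only because 2000 ends on a weekend.

```python
import pandas as pd

years = pd.Series(["1999", "2000", "2001"])
jan_first = pd.to_datetime(years, format="%Y")  # 1 January of each year
# YearEnd rolls each date forward to 31 December of the same year...
calendar_end = jan_first + pd.offsets.YearEnd()
# ...whereas BYearEnd rolls forward to the last business day of the year
business_end = jan_first + pd.offsets.BYearEnd()
```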

Okay, we’re now going to do a couple of other things that will help matplotlib. Unlike, say, lets-plot, matplotlib expects each line of interest in your data to appear across columns, ie in wide format. So we’re going to transform our data to be wide, and make the datetime column, "year", into the index.

df_wide = df.pivot(index="year", values="growth, %", columns="country")
df_wide = df_wide.sort_index()
df_wide.head()
country United Kingdom United States
year
1999-12-31 NaN NaN
2000-12-29 4.318727 4.077159
2001-12-31 2.578741 0.954339
2002-12-31 1.791986 1.695943
2003-12-31 3.146573 2.796209
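To see the long-to-wide reshape on its own, this sketch runs pivot() on a tiny dataset with invented growth numbers. It then uses melt() to go back to long (tidy) format.

```python
import pandas as pd

# Invented numbers, in long (tidy) format: one row per country-year
long_df = pd.DataFrame(
    {
        "year": [2000, 2000, 2001, 2001],
        "country": ["UK", "US", "UK", "US"],
        "growth": [4.3, 4.1, 2.6, 1.0],
    }
)
# Wide format: one row per year, one column per country
wide_df = long_df.pivot(index="year", columns="country", values="growth")
# And back again: melt() takes the column name "country" from columns.name
back_to_long = wide_df.reset_index().melt(id_vars="year", value_name="growth")
```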

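We also set the index to be the year because pandas’ plotting methods use the index as the x-axis, so a DatetimeIndex gives us a date-aware x-axis.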

By running df_wide.info(), we can learn what data types we now have.

df_wide.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 24 entries, 1999-12-31 to 2022-12-30
Data columns (total 2 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   United Kingdom  23 non-null     float64
 1   United States   23 non-null     float64
dtypes: float64(2)
memory usage: 576.0 bytes

Now, note that i) the index is composed of dates, and ii) the index is of type DatetimeIndex.

Alright, let’s just make a plot here without worrying too much about what might happen. We’ll do a line chart showing the growth rates over time of these two large economies.

fig, ax = plt.subplots()
df_wide.plot(ax=ax, lw=3)
ax.set_title("GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side

You’ll notice straight away that because we began with a datetime index, calling the convenience function df.plot(ax=...) automatically gave us an x-axis that understands and renders datetimes. More than that, the scale is sensible: the major tick labels show years, which suits our annual data.

Also note that the data are in wide format, which matplotlib prefers.

We could achieve the same outcome without using pandas’ built-in plotting methods, though it’s a tad more verbose: we need to call ax.plot() once per country and add the legend explicitly.

fig, ax = plt.subplots()
for name in ["United Kingdom", "United States"]:
    ax.plot(df_wide.index, df_wide[name], lw=3, label=name)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.legend()
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side

Now let’s see what happens if we change the axis limits. Because the x-axis variable is a datetime, we need to pass datetimes to ax.set_xlim() too.

fig, ax = plt.subplots()
df_wide.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right()  # Put tick marks and tick labels on right-hand side
ax.set_xlim(pd.to_datetime("2018-01-01"), pd.to_datetime("2021-01-01"));

As you can see, the plot dynamically responded to the shorter time period by putting more details in, here of quarters (as the minor ticks). You can specify exactly what you want with the tick label formatters that cater to datetimes, but the defaults are pretty well-behaved.
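If you do want to take control, the sketch below sets locators and formatters from matplotlib.dates explicitly on a placeholder annual series. Major ticks fall every five years and are labelled as four-digit years, with unlabelled minor ticks every year. The five-year spacing is an arbitrary choice. The sketch uses ax.plot() directly rather than pandas’ df.plot(), which can install its own date tickers.

```python
import datetime

import matplotlib

matplotlib.use("Agg")
import matplotlib.dates as mdates
import matplotlib.pyplot as plt

dates = [datetime.date(year, 12, 31) for year in range(2000, 2023)]
values = list(range(len(dates)))  # placeholder series

fig, ax = plt.subplots()
ax.plot(dates, values)
# Major ticks on 1 January of every fifth year, labelled as e.g. "2005"
ax.xaxis.set_major_locator(mdates.YearLocator(base=5))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
# Minor ticks on 1 January of every year, with no labels
ax.xaxis.set_minor_locator(mdates.YearLocator())
fig.canvas.draw()  # tick label text is filled in when the figure is drawn
labels = [t.get_text() for t in ax.get_xticklabels()]
```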

Advanced Topics#

Styles#

It can feel a bit tedious to apply different styles to every single ax, subplot, and even figure! Fortunately, there are ways to make applying styles more efficient. In this section, we’ll see how to do this concisely.

There are essentially four ways to apply styles:

  1. modify each ax object one-by-one (works at the axis level)

  2. use a dictionary once to pass keyword arguments to all objects of a similar type (works at the axis level)

  3. temporarily change the global style (works at the figure level)

  4. change the global style (works on every figure!)

We’ve seen plenty of 1. already, so we won’t do more on that. Let’s home in on 2. first.

Consistent Styling Across Axes and Subplots using Dictionaries#

Let’s say we have multiple text annotations that we want to make consistent in most details. We can use the cross-country GDP per capita example from earlier. Essentially, we define a dictionary of all the properties (here including nested dictionaries of further keyword arguments for the arrow and the text box) and pass it using the splatty-splat (**) operator.

kw_settings = dict(
    xycoords="data",
    ha="center",
    va="center",
    fontweight="bold",
    fontsize=12,
    xytext=(-20, +50),
    textcoords="offset points",
    arrowprops=dict(
        arrowstyle="->",
        connectionstyle="arc3,rad=-0.3",
        alpha=0.7,
    ),
    bbox=dict(facecolor="white", edgecolor="None", alpha=0.85),
)

# spread out x text annotations for each country
years_array = np.arange(start_year + 1, end_year, len(df_cc_gdp["country"].unique()) - 2)
# get countries in order of lowest to highest end value
countries = (
    df_cc_gdp.sort_values("GDP per capita")
    .loc[df_cc_gdp["year"] == df_cc_gdp["year"].max(), "country"]
    .values
)


fig, ax = plt.subplots()
for i, country in enumerate(countries):
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.annotate(
        country,
        (
            years_array[i],
            # .iloc[0] turns the one-element Series into a single y value
            country_df.loc[country_df["year"] == years_array[i], "GDP per capita"].iloc[0],
        ),
        **kw_settings,
        color=plt.gca().lines[-1].get_color(),
    )


ax.yaxis.tick_right()
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_xlim(start_year - 1, end_year)
ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max() * 1.05)
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();
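One subtlety with the ** operator: you can add extra keywords alongside the unpacked dictionary, as we did with color=. But repeating a key that is already in the dictionary raises a TypeError. To override a shared setting, merge the dictionaries first. The sketch below shows both, using a hypothetical stand-in function, describe(), in place of ax.annotate(); it simply returns the keyword arguments it receives.

```python
# Shared settings, defined once
base_kw = dict(ha="center", va="center", fontsize=12, fontweight="bold")


def describe(text, **kwargs):
    """Hypothetical stand-in for ax.annotate(): returns its keyword args."""
    return kwargs


# ** unpacks the dictionary; extra keywords can sit alongside it
call_1 = describe("UK", **base_kw, color="C0")

# To override a shared key, build a merged dictionary: later keys win
override_kw = {**base_kw, "fontsize": 9}
call_2 = describe("US", **override_kw)

# Repeating a key that is already in base_kw fails at call time
try:
    describe("FR", **base_kw, fontsize=9)
    duplicate_error = False
except TypeError:
    duplicate_error = True
```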

Temporarily Change the Global Style#

Python has a powerful notion called context managers, which can temporarily change the apparent global settings for any code run within them. You can identify them by the with statement. Now, we’ll show how to use these to completely change the style of a plot.

First, here’s a simple example with the default settings:

data = prng.standard_normal(50)
fig, ax = plt.subplots()
ax.plot(data)
plt.show()

Now let’s say, just for this plot, we want to make a change. We can do this by setting the “rc context”, which temporarily changes the global settings (the rcParams) that control how figures look. You pass the settings to it as a dictionary. In this case, we’ll change the linewidth and the linestyle. Any code in the indented block following the with statement uses the changed settings, and once the block is over, everything reverts to the usual settings.

You can find a full list of the settings you can pass in the matplotlib documentation on customising matplotlib with rcParams.

import matplotlib as mpl

with mpl.rc_context({'lines.linewidth': 4, 'lines.linestyle': ':'}):
    fig, ax = plt.subplots()
    ax.plot(data)
    plt.show()
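You can check that the change really is temporary by reading matplotlib.rcParams inside and outside the block. Lines created inside the block keep the settings they were made with. A minimal sketch:

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib as mpl
import matplotlib.pyplot as plt

default_lw = mpl.rcParams["lines.linewidth"]
with mpl.rc_context({"lines.linewidth": 4, "lines.linestyle": ":"}):
    inside_lw = mpl.rcParams["lines.linewidth"]
    fig, ax = plt.subplots()
    # This line picks up the temporary defaults when it is created
    (line,) = ax.plot([0, 1, 2], [0, 1, 0])
# Outside the block the global setting is back to its previous value
after_lw = mpl.rcParams["lines.linewidth"]
```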

There are some built-in style sheets that you can take advantage of for this too. You apply them via plt.style.context(<style name>). You can find a list in the matplotlib style sheets reference, or by looking at plt.style.available.

with plt.style.context('grayscale'):
    fig, ax = plt.subplots()
    for i, country in enumerate(countries):
        country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
        ax.plot(
            country_df["year"],
            country_df["GDP per capita"],
            lw=3,
        )
        ax.annotate(
            country,
            (
                years_array[i],
                # .iloc[0] turns the one-element Series into a single y value
                country_df.loc[country_df["year"] == years_array[i], "GDP per capita"].iloc[0],
            ),
            **kw_settings,
            color=plt.gca().lines[-1].get_color(),
        )
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year - 1, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max() * 1.05)
    fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
    plt.tight_layout();
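plt.style.available lists the built-in style names. Like rc_context(), a style context only changes rcParams for the duration of the with block. The sketch below checks this using the colour cycle of the grayscale style.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib as mpl
import matplotlib.pyplot as plt

styles = plt.style.available  # names accepted by plt.style.context/use
with plt.style.context("grayscale"):
    # The grayscale style swaps the default colour cycle for grey levels
    cycle_inside = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
cycle_after = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
```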

Setting the Style for an Entire Script or Notebook#

This book uses a special style file to set its rcParams and override the values found in the default matplotlibrc file. Once it has been loaded using

plt.style.use(
    "plot_style.txt"
)

all charts use whatever settings are in “plot_style.txt” unless those settings are explicitly overridden by other commands.

You can see the first few lines of the “plot_style.txt” file used for this book below, or you can browse it yourself here.

!head -10 plot_style.txt
xtick.color: 323034
ytick.color: 323034
text.color: 323034
lines.markeredgecolor: black
patch.facecolor        : bc80bd
patch.force_edgecolor  : True
patch.linewidth: 0.8
scatter.edgecolors: black
grid.color: b1afb5
axes.titlesize: 19
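To try the same mechanism without the book’s own file, the sketch below writes a tiny style file to a temporary directory, using the same key: value format. It then applies the file with plt.style.context(). The file name and the three settings are arbitrary. As in plot_style.txt, a hex colour can be written without its leading #.

```python
import tempfile
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib as mpl
import matplotlib.pyplot as plt

# Same "key: value" format as plot_style.txt
style_text = """
axes.titlesize: 19
xtick.color: 323034
lines.linewidth: 2.5
"""

with tempfile.TemporaryDirectory() as tmp:
    style_path = Path(tmp) / "my_style.txt"  # hypothetical file name
    style_path.write_text(style_text)
    # A path to a style file works wherever a style name does
    with plt.style.context(str(style_path)):
        title_size = mpl.rcParams["axes.titlesize"]
        xtick_colour = mpl.rcParams["xtick.color"]  # stored with a leading #
        line_width = mpl.rcParams["lines.linewidth"]
```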

Review#

That concludes our tour of the basics of matplotlib. You should now feel comfortable making basic charts using data.

You can find much more advanced uses of it over at the matplotlib documentation and in the online book Scientific Visualization in Matplotlib [Rougier, 2021].