Powerful Data Visualisation with Matplotlib#
Introduction#
This chapter introduces matplotlib, a ferociously powerful imperative graphing package that can produce just about any plot you can imagine. The focus here is on the basics. If you read on to the chapter on Narrative Data Visualisation, you’ll see just how flexible it is and how it can produce commercial-quality graphics. matplotlib is also the foundation of quite a few other data visualisation libraries (including plotnine and seaborn), which shows how much it can do. It was famously used to create the first ever image of a black hole by Akiyama et al. [2019].
Note
matplotlib is an incredibly powerful and customisable visualisation package.
It’s worth saying that matplotlib has a very different philosophy than, for example, lets-plot. Apart from being imperative rather than declarative, it also prefers unstacked data to tidy data: so, to create plots quickly, instead of every line you want to plot being stored in a single column called “country” as it would be using tidy data, it prefers each column to have a different country in.
For a more in-depth introduction to matplotlib, head over to this tutorial. This chapter is indebted to that tutorial, to the excellent matplotlib documentation, and to the book Scientific Visualization in Matplotlib [Rougier, 2021].
As ever, we’ll start by importing some key packages and initialising any settings:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Set seed for random numbers
seed_for_prng = 78557
prng = np.random.default_rng(
seed_for_prng
) # prng = pseudo-random number generator
Understanding Matplotlib#
You’ll see we imported matplotlib.pyplot as plt above; this is the main part of matplotlib that we’ll use in practice.
The matplotlib API#
matplotlib [Hunter, 2007] has its origin in an attempt to replicate the plotting functionality of the paid programming language Matlab but has since outgrown these roots. Because matplotlib has been around a long time and there are different ways to use it, much of the material online about how to use matplotlib uses the old, Matlab-inspired approach, which is not what this book recommends; instead, we strongly recommend the modern object-oriented API. If you see a plot command that begins plt.scatter or similar, then it’s using the legacy API rather than the now recommended object-oriented API, which is focused around axes, eg ax.scatter and similar. You want to follow examples that use the object-oriented approach.
Everything you see below will use the ‘object-oriented API’. This means that we create objects, like figures and axes, that have state (they remember what you did to them).
Also worth saying: the matplotlib API (application programming interface) is huge so there’s no way we’ll be able to cover everything it can do!
The object-oriented API is most often used by creating two fundamental objects that are used by almost every chart: the figure and the axes. You should think of the figure object, fig, as the canvas on which you can put any number of charts. Each ax (short for ‘axes’) object is one chart within a figure. Of course, most of the time you’re likely to have only one set of axes per figure, but in the cases when you don’t it’s a really useful setup. The plotting of elements such as lines, points, bars, and so on is controlled by the ax objects, while the overall settings are controlled by fig.
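To make the relationship between the figure and its axes concrete, here is a minimal sketch (the data and title are made up for illustration) that puts two charts on one canvas and checks how the objects refer to each other.

```python
import matplotlib.pyplot as plt

# One figure (the canvas) holding two axes (the individual charts)
fig, axs = plt.subplots(1, 2)

# Plotting is done through each axes object...
axs[0].plot([1, 2, 3], [1, 4, 2])
axs[1].scatter([1, 2, 3], [3, 1, 2])

# ...while figure-wide settings go through the figure
fig.suptitle("Two charts on one canvas")

print(len(fig.axes))         # number of axes attached to the figure
print(axs[0].figure is fig)  # each axes knows which figure it lives on
```

With a single chart, plt.subplots() hands back one axes object directly rather than an array, which is the pattern used in most of this chapter.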
The most simple chart we can think of would be a line plot. Let’s see how to do that:
fig, ax = plt.subplots() # Get the figure and one axis as a subplot
ax.plot(
[1, 2, 3, 4, 5, 6], # Add some data on the x and y axes
[1, 4, 2, 3, 1, 7],
)
[<matplotlib.lines.Line2D at 0x13f07a1d0>]
Notice that, unlike seaborn, matplotlib will happily accept raw data (but will work with dataframes too).
Tip
Plotting methods return the objects they create, and a notebook displays the value of the last line of a cell, which is why [<matplotlib.lines.Line2D at... appears above. To suppress this, end the command with a semi-colon, ;, or call plt.show() as the last command.
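The returned object is not just noise: it is a handle on what was drawn. As a small sketch (the variable names lines and line are illustrative), you can keep the return value of ax.plot and use it to inspect or restyle the line afterwards.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ax.plot returns a list of Line2D objects, one per line drawn
lines = ax.plot([1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7])
line = lines[0]

# The returned object can be inspected or restyled after the fact
line.set_linewidth(3)
print(type(line).__name__)
print(line.get_ydata())
```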
Let’s see an example of a scatter plot using this object-oriented approach. Note that we begin in the same way, with getting a figure and an axis. But now we’re going to use ax.scatter
instead, plus throw in a couple of extra settings (can you guess what they do?).
fig, ax = plt.subplots() # Create a figure containing a single axes.
ax.scatter(
[1, 2, 3, 4, 5, 6], [1, 4, 2, 3, 1, 7], s=150, c="b" # Plot some data on the axes.
);
s=150 sets the area of the points (ie the size of each marker), and c='b' sets the colour. Many of these keyword arguments will accept an array of values (like s = [1, 2, 3, 4, 5, 6]) instead of a single value and will map them into the plot in the way you’d expect.
Let’s see that in practice. For the sizes, we’ll linearly increase the marker sizes from 300 to 2000 in 6 steps. For the colours, we’ll just type out six distinct values.
fig, ax = plt.subplots()
ax.scatter(
[1, 2, 3, 4, 5, 6],
[1, 4, 2, 3, 1, 7],
s=np.linspace(300, 2000, 6),
c=["b", "r", "g", "k", "cyan", "yellow"],
edgecolors="k",
alpha=0.5,
);
We also asked for partly transparent points via the alpha setting (the default is alpha=1, which is a solid colour), and a black edge colour for each marker (edgecolors='k').
As ever, you can call help on a specific command in order to understand what options (keyword arguments) it accepts. And, in Visual Studio Code, you can just hover over the function or method name.
Here is what help
returns when you run it on the scatter()
method:
help(ax.scatter)
Help on method scatter in module matplotlib.axes._axes:
scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs) method of matplotlib.axes._axes.Axes instance
A scatter plot of *y* vs. *x* with varying marker size and/or color.
Parameters
----------
x, y : float or array-like, shape (n, )
The data positions.
s : float or array-like, shape (n, ), optional
The marker size in points**2 (typographic points are 1/72 in.).
Default is ``rcParams['lines.markersize'] ** 2``.
The linewidth and edgecolor can visually interact with the marker
size, and can lead to artifacts if the marker size is smaller than
the linewidth.
If the linewidth is greater than 0 and the edgecolor is anything
but *'none'*, then the effective size of the marker will be
increased by half the linewidth because the stroke will be centered
on the edge of the shape.
To eliminate the marker edge either set *linewidth=0* or
*edgecolor='none'*.
c : array-like or list of colors or color, optional
The marker colors. Possible values:
- A scalar or sequence of n numbers to be mapped to colors using
*cmap* and *norm*.
- A 2D array in which the rows are RGB or RGBA.
- A sequence of colors of length n.
- A single color format string.
Note that *c* should not be a single numeric RGB or RGBA sequence
because that is indistinguishable from an array of values to be
colormapped. If you want to specify the same RGB or RGBA value for
all points, use a 2D array with a single row. Otherwise,
value-matching will have precedence in case of a size matching with
*x* and *y*.
If you wish to specify a single color for all points
prefer the *color* keyword argument.
Defaults to `None`. In that case the marker color is determined
by the value of *color*, *facecolor* or *facecolors*. In case
those are not specified or `None`, the marker color is determined
by the next color of the ``Axes``' current "shape and fill" color
cycle. This cycle defaults to :rc:`axes.prop_cycle`.
marker : `~.markers.MarkerStyle`, default: :rc:`scatter.marker`
The marker style. *marker* can be either an instance of the class
or the text shorthand for a particular marker.
See :mod:`matplotlib.markers` for more information about marker
styles.
cmap : str or `~matplotlib.colors.Colormap`, default: :rc:`image.cmap`
The Colormap instance or registered colormap name used to map scalar data
to colors.
This parameter is ignored if *c* is RGB(A).
norm : str or `~matplotlib.colors.Normalize`, optional
The normalization method used to scale scalar data to the [0, 1] range
before mapping to colors using *cmap*. By default, a linear scaling is
used, mapping the lowest value to 0 and the highest to 1.
If given, this can be one of the following:
- An instance of `.Normalize` or one of its subclasses
(see :ref:`colormapnorms`).
- A scale name, i.e. one of "linear", "log", "symlog", "logit", etc. For a
list of available scales, call `matplotlib.scale.get_scale_names()`.
In that case, a suitable `.Normalize` subclass is dynamically generated
and instantiated.
This parameter is ignored if *c* is RGB(A).
vmin, vmax : float, optional
When using scalar data and no explicit *norm*, *vmin* and *vmax* define
the data range that the colormap covers. By default, the colormap covers
the complete value range of the supplied data. It is an error to use
*vmin*/*vmax* when a *norm* instance is given (but using a `str` *norm*
name together with *vmin*/*vmax* is acceptable).
This parameter is ignored if *c* is RGB(A).
alpha : float, default: None
The alpha blending value, between 0 (transparent) and 1 (opaque).
linewidths : float or array-like, default: :rc:`lines.linewidth`
The linewidth of the marker edges. Note: The default *edgecolors*
is 'face'. You may want to change this as well.
edgecolors : {'face', 'none', *None*} or color or sequence of color, default: :rc:`scatter.edgecolors`
The edge color of the marker. Possible values:
- 'face': The edge color will always be the same as the face color.
- 'none': No patch boundary will be drawn.
- A color or sequence of colors.
For non-filled markers, *edgecolors* is ignored. Instead, the color
is determined like with 'face', i.e. from *c*, *colors*, or
*facecolors*.
plotnonfinite : bool, default: False
Whether to plot points with nonfinite *c* (i.e. ``inf``, ``-inf``
or ``nan``). If ``True`` the points are drawn with the *bad*
colormap color (see `.Colormap.set_bad`).
Returns
-------
`~matplotlib.collections.PathCollection`
Other Parameters
----------------
data : indexable object, optional
If given, the following parameters also accept a string ``s``, which is
interpreted as ``data[s]`` (unless this raises an exception):
*x*, *y*, *s*, *linewidths*, *edgecolors*, *c*, *facecolor*, *facecolors*, *color*
**kwargs : `~matplotlib.collections.Collection` properties
See Also
--------
plot : To plot scatter plots when markers are identical in size and
color.
Notes
-----
* The `.plot` function will be faster for scatterplots where markers
don't vary in size or color.
* Any or all of *x*, *y*, *s*, and *c* may be masked arrays, in which
case all masks will be combined and only unmasked points will be
plotted.
* Fundamentally, scatter works with 1D arrays; *x*, *y*, *s*, and *c*
may be input as N-D arrays, but within scatter they will be
flattened. The exception is *c*, which will be flattened only if its
size matches the size of *x* and *y*.
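Two of the options described in that help text are worth trying out: the data keyword, which lets you refer to columns by name, and numeric c values, which are mapped to colours through cmap. The sketch below uses a small made-up data frame (the column names and numbers are purely illustrative).

```python
import matplotlib.pyplot as plt
import pandas as pd

# Small made-up dataset, purely for illustration
points = pd.DataFrame(
    {
        "x": [1, 2, 3, 4, 5, 6],
        "y": [1, 4, 2, 3, 1, 7],
        "size": [50, 100, 150, 200, 250, 300],
        "value": [0.1, 0.5, 0.2, 0.9, 0.4, 0.7],
    }
)

fig, ax = plt.subplots()
# With data=, each string is looked up as points["x"], points["y"], and so on.
# The numeric values passed via c are mapped to colours through cmap.
collection = ax.scatter("x", "y", s="size", c="value", data=points, cmap="viridis")
fig.colorbar(collection, ax=ax, label="value")
```

The object that comes back is the PathCollection mentioned under Returns in the help text, and it is what the colour bar uses to work out its scale.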
pandas and matplotlib#
matplotlib will accept any array-like data, but it’s fair to say it much prefers data frames to be in wide format rather than tidy format (fortunately pandas makes it easy to switch between the two using unstack() or pivot()). pandas also has built-in plotting methods that make use of matplotlib; these can be accessed via DataFrame.plot.*, for example df.plot().
Let’s see an example of this. First we generate some data:
num_samples = 1000
df = pd.DataFrame(
prng.standard_normal(size=(num_samples, 4)),
index=pd.date_range("1/1/2000", periods=num_samples),
columns=list("ABCD"),
)
df = df.cumsum()
df.head()
A | B | C | D | |
---|---|---|---|---|
2000-01-01 | 0.927389 | -0.059754 | 1.043736 | 1.505619 |
2000-01-02 | 1.035016 | -1.920517 | 3.138396 | 3.206534 |
2000-01-03 | -0.386094 | -0.744485 | 3.508994 | 1.914596 |
2000-01-04 | 0.977921 | -0.272311 | 3.999495 | 3.294564 |
2000-01-05 | 0.205909 | 1.476947 | 2.898448 | 4.170505 |
Now, what happens when we use df.plot()
:
df.plot();
This works pretty much as expected. But what if we had tidy data?
tidy_df = df.reset_index().melt(id_vars="index")
tidy_df.head()
index | variable | value | |
---|---|---|---|
0 | 2000-01-01 | A | 0.927389 |
1 | 2000-01-02 | A | 1.035016 |
2 | 2000-01-03 | A | -0.386094 |
3 | 2000-01-04 | A | 0.977921 |
4 | 2000-01-05 | A | 0.205909 |
Let’s just throw this into .plot()
tidy_df.plot();
Uh oh! This isn’t what we wanted at all! It’s combined all of the series. If you have this issue with tidy data, you can switch to wide format like this:
tidy_df.pivot(index="index", columns="variable", values="value")
variable | A | B | C | D |
---|---|---|---|---|
index | ||||
2000-01-01 | 0.927389 | -0.059754 | 1.043736 | 1.505619 |
2000-01-02 | 1.035016 | -1.920517 | 3.138396 | 3.206534 |
2000-01-03 | -0.386094 | -0.744485 | 3.508994 | 1.914596 |
... | ... | ... | ... | ... |
2002-09-24 | 49.439815 | 7.261320 | -14.740817 | -47.441272 |
2002-09-25 | 49.765982 | 7.837380 | -14.292353 | -48.501030 |
2002-09-26 | 50.468400 | 8.311609 | -13.115471 | -49.356256 |
1000 rows × 4 columns
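Once the data are wide again, plotting works as before. As a sketch on a small made-up tidy data frame (not the data above), here are two ways to get one line per variable: pivot and use pandas’ plot method, or keep the data tidy and loop over groups with the object-oriented API.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Small made-up tidy dataset: one row per (date, variable) pair
dates = pd.date_range("2000-01-01", periods=5)
tidy = pd.DataFrame(
    {
        "index": np.tile(dates, 2),
        "variable": np.repeat(["A", "B"], len(dates)),
        "value": np.arange(10, dtype=float),
    }
)

# Option 1: pivot to wide format, then let pandas draw one line per column
wide = tidy.pivot(index="index", columns="variable", values="value")
ax_wide = wide.plot()

# Option 2: keep the tidy data and draw one line per group
fig, ax = plt.subplots()
for name, group in tidy.groupby("variable"):
    ax.plot(group["index"], group["value"], label=name)
ax.legend()
```

The pandas plot method returns the matplotlib axes it drew on, so anything in the rest of this chapter can be applied to ax_wide too.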
Summary#
There are a huge number of options for what to put on axes; the table below gives a guide to the most essential ones for 2D plots.
Code | What it does |
---|---|
ax.plot | Plot y versus x as lines and/or markers. |
ax.errorbar | Plot y versus x as lines and/or markers with attached errorbars. |
ax.scatter | A scatter plot of y vs x |
ax.step | Make a step plot. |
ax.loglog | Make a plot with log scaling on both the x and y axis. |
ax.semilogx | Make a plot with log scaling on the x axis. |
ax.semilogy | Make a plot with log scaling on the y axis. |
ax.fill_between | Fill the area between two horizontal curves. |
ax.fill_betweenx | Fill the area between two vertical curves. |
ax.bar | Make a bar plot. |
ax.barh | Make a horizontal bar plot. |
ax.bar_label | Label a bar plot. |
ax.stem | Create a stem plot. |
ax.eventplot | Plot identical parallel lines at the given positions. |
ax.pie | Plot a pie chart. |
ax.stackplot | Draw a stacked area plot. |
ax.broken_barh | Plot a horizontal sequence of rectangles. |
ax.vlines | Plot vertical lines at each x from ymin to ymax. |
ax.hlines | Plot horizontal lines at each y from xmin to xmax. |
ax.axhline | Add a horizontal line across the Axes. |
ax.axhspan | Add a horizontal span (rectangle) across the Axes. |
ax.axvline | Add a vertical line across the Axes. |
ax.axvspan | Add a vertical span (rectangle) across the Axes. |
ax.axline | Add an infinitely long straight line. |
ax.fill | Plot filled polygons. |
ax.boxplot | Draw a box and whisker plot. |
ax.violinplot | Make a violin plot. |
ax.violin | Drawing function for violin plots. |
ax.bxp | Drawing function for box and whisker plots. |
ax.hexbin | Make a 2D hexagonal binning plot of points x, y. |
ax.hist | Compute and plot a histogram. |
ax.hist2d | Make a 2D histogram plot. |
ax.stairs | A stepwise constant function as a line with bounding edges or a filled plot. |
ax.contour | Plot contour lines. |
ax.contourf | Plot filled contours. |
Now, it might not seem like it, but you already have the makings of a wide range of matplotlib plots. You just need to follow this recipe:
1. Create a figure and axes with fig, ax = plt.subplots()
2. Choose what you want to put on your axes, for example ax.scatter or ax.plot
3. Put data into the method you chose in step 2 in the required format (and remember you can check the documentation to get the right format, either by using help, hovering over the method name in Visual Studio Code, or heading to the matplotlib documentation, where there are often also examples to look at; here are the examples for hex bins).
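The recipe above can be sketched with a hex bin plot. The data here are random draws made up for the purpose, and gridsize=20 is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily
x = rng.standard_normal(2000)
y = x + rng.standard_normal(2000)

# Step 1: create the figure and axes
fig, ax = plt.subplots()
# Step 2: choose a plotting method, here hexbin
# Step 3: pass the data in the format the method expects (two 1D arrays)
hexes = ax.hexbin(x, y, gridsize=20, cmap="viridis")
fig.colorbar(hexes, ax=ax, label="Count")
```

Swapping ax.hexbin for another method from the table only changes steps 2 and 3; step 1 stays the same.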
Anatomy of a matplotlib graph#
So, you’ve made a nice plot and now you want to tweak it. Well, matplotlib certainly has a LOT of choice when it comes to tweaking! Rather than go through all of the many, many options for customisation, the figure below (from the documentation) gives an overview of the options:
Let’s run through a few of the most important plot elements:
Artists (not explicitly labelled above): There are two types of Artists: primitives and containers. The primitives represent the standard graphical objects we want to paint onto our canvas (Line2D, Rectangle, Text, AxesImage, etc.), and the containers are places to put them (Axis, Axes and Figure).
Figure, or ‘fig’: the figure keeps track of all the child Axes that are on it and a smattering of ‘special’ artists (titles, figure legends, etc). A figure can contain any number of Axes, but will typically have at least one if it is to be interesting!
Axes, or ‘ax’: this is the plot, the region of the image that traces out the data. A given Figure can contain many Axes, but a given Axes object can only be in one Figure. The Axes contains two (or three in the case of 3D) Axis objects (Axes and Axis are different things!) that record the data limits (you can override these via the axes.Axes.set_xlim() and axes.Axes.set_ylim() methods). Each Axes object has a title (set via set_title()), an x-label (set via set_xlabel()), and a y-label (set via set_ylabel()). When you add, say, a line chart to an Axes object it appears as a Line2D object associated with that Axes, and it is created by calling a method on the Axes object.
Axis: these are the number-line objects that control the limits of what the viewer can see. They also provide the means to access the ticks (the marks on the axis) and ticklabels (strings labeling the ticks). The location of the ticks is determined by a Locator object and the ticklabel strings are formatted by a Formatter. The combination of Locator and Formatter gives very fine control over the tick locations and labels.
Markers: these are what are produced by scatter plots.
Labels and titles: text to help the viewers of the chart make sense of what they’re seeing.
Legend: if needed, contains the key to understand the shapes or colours used in lines or markers or bars (or …) that are on the chart.
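These elements are ordinary Python objects that you can reach from fig and ax. The following sketch (with placeholder labels and data) builds a tiny chart and then looks up a few of the pieces named in the list above.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 1, 3], label="A line")
ax.set_title("Title")
ax.set_xlabel("x label")
ax.set_ylabel("y label")
legend = ax.legend()

# Containers: the Figure holds the Axes, and the Axes holds two Axis objects
print(type(fig).__name__, type(ax.xaxis).__name__, type(ax.yaxis).__name__)
# Primitives: the line drawn by ax.plot is a Line2D stored on the Axes
print([type(artist).__name__ for artist in ax.get_lines()])
# Labels and titles are Text artists whose content can be read back
print(ax.get_title(), ax.get_xlabel(), ax.get_ylabel())
```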
Customising Charts#
While the matplotlib API is too extensive for us to cover every detail, it’s important to know about some standard customisations that you might need to get going.
Limits and Labels#
We’ve only seen aspects of the plot that are customisable through the keyword arguments of scatter() so far; let’s now see an example that’s a bit more realistic (and useful!) in which we’ll want to add labels, a title, and more. Note that matplotlib supports LaTeX equations in text.
We’ll plot some data from the US Midwest demographics dataset.
df = pd.read_csv(
"https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2/midwest.csv",
index_col="PID",
)
df.head()
rownames | county | state | ... | percelderlypoverty | inmetro | category | |
---|---|---|---|---|---|---|---|
PID | |||||||
561 | 1 | ADAMS | IL | ... | 12.443812 | 0 | AAR |
562 | 2 | ALEXANDER | IL | ... | 25.228976 | 0 | LHR |
563 | 3 | BOND | IL | ... | 12.697410 | 0 | AAR |
564 | 4 | BOONE | IL | ... | 6.217047 | 1 | ALU |
565 | 5 | BROWN | IL | ... | 19.200000 | 0 | AAR |
5 rows × 28 columns
Now, as well as a scatter, let’s add some context to the chart:
fig, ax = plt.subplots()
ax.scatter(
df["area"], df["poptotal"], edgecolors="k", alpha=0.6
) # Make a scatter plot on "ax"
ax.set_xlim(0, 0.1) # Set the limits on the x-axis
ax.set_ylim(0, 1e6) # Set the limits on the y-axis
ax.set_xlabel("Area") # Set the x label
ax.set_ylabel("Population") # Set the y label
ax.set_title(
"Area vs. Population", loc="right"
); # Add a title and say where it should go
We’re not quite done with titles. Often it’s useful to have a y-axis title that is horizontal, and so easier to read. And, additionally, it’s good practice to have a title that tells the viewer what they should take away from the graph. To achieve both of these together, you can i) use plt.suptitle()
for a figure-level title whose position can be fine-tuned using x
and y
keyword arguments; and ii) use ax.set_title()
to provide the y-axis label. The chart below demonstrates this.
fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"] / 1e6, edgecolors="k", alpha=0.6, s=150)
ax.set_ylim(0, None)
ax.set_xlim(0, None)
ax.set_xlabel("Area")
ax.set_title("Population (millions)", loc="left", fontsize=14)
plt.suptitle("Little correlation between population and area", y=1.02, x=0.45);
Customising Axes#
Each Axes has two (or three) Axis objects representing the x- and y-axis. These control the scale of the Axis, the tick locators and the tick formatters. Additional Axes can be attached to display further Axis objects.
First, we’re going to make some random data that will help explain some of the concepts in the rest of the chapter.
data1, data2 = prng.standard_normal(size=(2, 100)) # make 2 random data sets
xdata = np.arange(len(data1)) # make an ordinal for this
Ticks Locators and Formatters#
Each Axis has a tick locator and formatter that choose where along the Axis objects to put tick marks. A simple interface to this is Axes.set_xticks()
:
fig, axs = plt.subplots(2, 1, layout="constrained")
axs[0].plot(xdata, data1)
axs[0].set_title("Automatic ticks")
axs[1].plot(xdata, data1)
axs[1].set_xticks(np.arange(0, 100, 30), ["zero", "30", "sixty", "90"])
axs[1].set_yticks([-1.5, 0, 1.5]) # note that we don't need to specify labels
axs[1].set_title("Manual ticks");
See the matplotlib documentation on Tick locators and Tick formatters for other formatters and locators and information for writing your own.
Returning to our Midwest examples, we had the makings of a basic chart, with axes labels and even a title.
But what we know from the data is that “Area” is actually a percentage. We could represent this by doing
ax.set_xlabel("Area, %")
but as matplotlib is infinitely customisable, there is another option too—changing the tick labels.
On the x-axis, we’ll add a percentage suffix on the numbers plus some minor tick marks.
from matplotlib.ticker import AutoMinorLocator
fig, ax = plt.subplots()
ax.scatter(df["area"], df["poptotal"], edgecolors="k", alpha=0.6)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="right")
ax.xaxis.set_minor_locator(
AutoMinorLocator(4)
) # Add minor tick marks by splitting each major interval into four
ax.xaxis.set_major_formatter(
"{x:.2f}%"
); # Every x value has 2 decimal places and is followed by a '%' sign
The AutoMinorLocator(4) call divides the gap between neighbouring major ticks into 4 equal parts, which places 3 minor tick marks between each pair of major ticks. The major formatter is '{x:.2f}%', which says print each value with 2 decimal places followed by a % sign.
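You can check what a locator and formatter will do without looking at the picture. The sketch below uses a made-up axis running from 0 to 100 with major ticks every 20 units (these numbers are assumptions chosen to keep the arithmetic easy) and then asks the axis for its minor tick positions and for the label of one major tick.

```python
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator

fig, ax = plt.subplots()
ax.set_xlim(0, 100)
ax.set_xticks([0, 20, 40, 60, 80, 100])  # major ticks every 20 units
ax.xaxis.set_minor_locator(AutoMinorLocator(4))  # 4 subdivisions per major interval
ax.xaxis.set_major_formatter("{x:.2f}%")

minor_ticks = ax.xaxis.get_minorticklocs()
formatter = ax.xaxis.get_major_formatter()

print(minor_ticks[:3])  # minor ticks between the major ticks at 0 and 20
print(formatter(20))    # how the major tick at 20 is labelled
```

Passing a format string to set_major_formatter is a shortcut: matplotlib wraps it in a formatter object, which is why the result can be called like a function.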
Scales#
In addition to the linear scale, matplotlib supplies non-linear scales, such as a log-scale. Since log-scales are used so much there are also direct methods like Axes.loglog
, Axes.semilogx
, and
Axes.semilogy
. There are a number of scales (see the matplotlib scales docs for other examples). Here we set the scale
manually:
fig, axs = plt.subplots(1, 2, layout="constrained")
xdata = np.arange(len(data1)) # make an ordinal for this
data = 10**data1
axs[0].plot(xdata, data, lw=1)
axs[0].set_ylim(0, None)
axs[1].set_yscale("log")
axs[1].plot(xdata, data, lw=1);
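The direct methods mentioned above are shortcuts for the same thing. As a small sketch on made-up exponential data, the two routes below both end up with a log-scaled y-axis.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 50)
y = 10 ** (x / 10)

fig, axs = plt.subplots(1, 2)
# Route 1: plot as usual, then switch the scale of the y-axis
axs[0].plot(x, y)
axs[0].set_yscale("log")
# Route 2: use the convenience method, which sets the log scale for you
axs[1].semilogy(x, y)

print(axs[0].get_yscale(), axs[1].get_yscale())
```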
Using Loops to Achieve Customisation Effects#
Let’s say we now want to differentiate these points with colour according to which state they belong to, and add a legend that says which states have which colour. The easiest way to do this is with a for loop.
We’re going to loop over state. We do it using for state in df["state"].unique():, which runs over the distinct states. Each call to ax.scatter() inside the loop takes the next colour from the axes’ default colour cycle, so every state ends up with a different colour. The code also passes cmap="Dark2", but a colour map is only applied when numeric values are supplied through c, so matplotlib ignores it here and warns about it, as the output under the code shows.
Now we have colour capturing a new dimension of the data (state), we also need a legend that shows which colour each state corresponds to. By passing label=state within the for loop, we build up state<==>colour equivalences that appear on the chart when we call ax.legend.
fig, ax = plt.subplots()
for state in df["state"].unique():
xf = df.loc[df["state"] == state]
ax.scatter(
xf["area"],
xf["poptotal"],
cmap="Dark2", # Ignored (see the warning below): a colour map needs numeric values passed via c
label=state,
s=100,
edgecolor="k",
alpha=0.8,
)
ax.set_xlim(0, 0.1)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.xaxis.set_minor_locator(AutoMinorLocator(4))
ax.xaxis.set_major_formatter("{x:.2f}%")
ax.legend(title="State", loc="upper right");
/var/folders/x6/ffnr59f116l96_y0q0bjfz7c0000gn/T/ipykernel_25576/4007070724.py:4: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
ax.scatter(
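The 5 colours above come from matplotlib’s default colour cycle rather than from Dark2, because cmap was ignored. Qualitative colour maps such as Dark2 suit discrete categories like these; there are also sequential colormaps for continuous (as opposed to discrete) variables. You can find out more about the colormaps available in base matplotlib here.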
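If you do want the colours to come from a qualitative colour map, one option is to pick a colour for each category yourself and pass it via color. This is a sketch on a tiny made-up stand-in for the Midwest data (the toy data frame and its values are invented for illustration).

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for the Midwest data, purely for illustration
toy = pd.DataFrame(
    {
        "state": ["IL", "IL", "IN", "IN", "MI", "MI"],
        "area": [0.02, 0.03, 0.04, 0.01, 0.05, 0.06],
        "poptotal": [1e5, 2e5, 1.5e5, 3e5, 2.5e5, 0.5e5],
    }
)

cmap = plt.get_cmap("Dark2")  # qualitative colour map with 8 distinct colours

fig, ax = plt.subplots()
for i, state in enumerate(toy["state"].unique()):
    xf = toy.loc[toy["state"] == state]
    ax.scatter(
        xf["area"],
        xf["poptotal"],
        color=cmap(i),  # i-th colour of the map, passed explicitly
        label=state,
        s=100,
        edgecolor="k",
        alpha=0.8,
    )
ax.legend(title="State")
```

Because the colour is given explicitly, no cmap argument is needed in the call to ax.scatter and no warning is raised.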
Adding Text Annotations#
It’s possible to point out specific values with a text label. Adding extra information like this can be useful in all kinds of circumstances; for example, showing the biggest or smallest, drawing attention to a particular story, or simply flagging a special value.
Let’s add a couple of text annotations. Let’s say we want to annotate the county with the biggest area, and the county with the highest population. For the biggest area, we’ll just pop the label next to the point. First we need to find the position of the point:
max_area_row = df.iloc[df["area"].argmax()]
max_area_row
rownames 246
county MARQUETTE
state MI
...
percelderlypoverty 12.523891
inmetro 0
category HAR
Name: 1248, Length: 28, dtype: object
Now we use ax.annotate()
to add this information.
fig, ax = plt.subplots()
for state in df["state"].unique():
xf = df.loc[df["state"] == state]
ax.scatter(
xf["area"],
xf["poptotal"],
cmap="Dark2", # Ignored (see the warning below): a colour map needs numeric values passed via c
label=state,
s=100,
edgecolor="k",
alpha=0.8,
)
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 1e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
text=f'Max. area: {max_area_row.loc["county"].title()}',
xy=tuple(max_area_row[["area", "poptotal"]]),
);
/var/folders/x6/ffnr59f116l96_y0q0bjfz7c0000gn/T/ipykernel_25576/124428108.py:4: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
ax.scatter(
What if we want to put text somewhere other than right next to the datapoint? We can do that and have an arrow to connect the label to the data point.
max_pop_row = df.iloc[df["poptotal"].argmax()]
fig, ax = plt.subplots()
for state in df["state"].unique():
xf = df.loc[df["state"] == state]
ax.scatter(
xf["area"],
xf["poptotal"],
cmap="Dark2", # Ignored (see the warning below): a colour map needs numeric values passed via c
label=state,
s=100,
edgecolor="k",
alpha=0.8,
)
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
text=f'Max. area: {max_area_row.loc["county"].title()}',
xy=tuple(max_area_row[["area", "poptotal"]]),
xytext=(-100, 20),
textcoords="offset points",
arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.annotate(
text=f'Max. pop: {max_pop_row["county"].title()}',
xy=tuple(max_pop_row[["area", "poptotal"]]),
xytext=(-100, -50),
textcoords="offset points",
arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
);
/var/folders/x6/ffnr59f116l96_y0q0bjfz7c0000gn/T/ipykernel_25576/2614635337.py:6: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
ax.scatter(
You can get really creative with text annotations—as the image below, taken from Scientific Visualisation: Python + Matplotlib [Rougier, 2021] shows. You can learn more about text annotations here.
Special Lines#
Next, we’re going to add some special lines to our chart. This is surprisingly useful in practice, and there are a few commands. If we just want a horizontal or vertical line, our best bet is ax.axhline()
and ax.axvline()
respectively.
Let’s add a special line to our example to show where the mean area of counties appears. First we need the mean area:
mean_county_area = df["area"].mean()
mean_county_area
0.03316933638443936
Now we’re going to add the line, using ax.axvline
, but also a corresponding annotation that tells viewers what this line is showing.
fig, ax = plt.subplots()
for state in df["state"].unique():
xf = df.loc[df["state"] == state]
ax.scatter(
xf["area"],
xf["poptotal"],
cmap="Dark2", # Ignored (see the warning below): a colour map needs numeric values passed via c
label=state,
s=100,
edgecolor="k",
alpha=0.8,
)
ax.set_xlim(0, 0.12)
ax.set_ylim(0, 6e6)
ax.set_xlabel("Area")
ax.set_ylabel("Population")
ax.set_title("Area vs. Population", loc="center")
ax.legend()
ax.annotate(
text=f'Max. area: {max_area_row.loc["county"].title()}',
xy=tuple(max_area_row[["area", "poptotal"]]),
xytext=(-100, 20),
textcoords="offset points",
arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.annotate(
text=f'Max. pop: {max_pop_row["county"].title()}',
xy=tuple(max_pop_row[["area", "poptotal"]]),
xytext=(20, -50),
textcoords="offset points",
arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
ax.axvline(x=mean_county_area, linewidth=0.5, linestyle="-.")
ax.annotate(
"Mean county area",
xy=(mean_county_area, 0.5),
xycoords=("data", "axes fraction"),
rotation=-90,
fontsize=11,
);
/var/folders/x6/ffnr59f116l96_y0q0bjfz7c0000gn/T/ipykernel_25576/3286688189.py:4: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
ax.scatter(
Note that the second co-ordinate of the text annotation is given as a fraction of the height of the axes rather than in the co-ordinates of the data. This is useful because now, whatever else changes in the chart, we know the text will appear half-way up the axis. There are several different co-ordinate systems that matplotlib accepts depending on what you’re trying to achieve:
the co-ordinates of the data, eg area and population in the chart above
a fraction of the width or height of the axes, aka the “axes fraction”
offset points, relative to another point (used for text relative to a data point)
various options using pixels
and more, including polar co-ordinates
You can find out a bit more about the different co-ordinate systems in the matplotlib documentation.
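To see a few of these co-ordinate systems side by side, here is a minimal sketch on a made-up line chart; the positions and label texts are arbitrary.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 10], [0, 1000])

# Data co-ordinates (the default): position given in x and y data values
in_data = ax.annotate("data (5, 500)", xy=(5, 500))
# Axes fraction: (0, 0) is the bottom left of the axes, (1, 1) the top right
in_axes = ax.annotate("axes centre", xy=(0.5, 0.5), xycoords="axes fraction")
# Mixed: x in data units, y as a fraction of the axes height
mixed = ax.annotate("x=2, top", xy=(2, 0.95), xycoords=("data", "axes fraction"))
# Offset points: text placed 20 points right and 10 points up from xy
offset = ax.annotate(
    "offset label", xy=(8, 800), xytext=(20, 10), textcoords="offset points"
)
```

If you change the axis limits afterwards, the first label moves with the data while the second stays in the middle of the axes.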
Styling Artists#
Most plotting methods have styling options for the Artists, accessible either
when a plotting method is called, or from a “setter” on the Artist. In the
plot below we manually set the color, linewidth, and linestyle of the
Artists created by Axes.plot()
, and we set the linestyle of the second line
after the fact with Line2D.set_linestyle()
.
fig, ax = plt.subplots()
x = np.arange(len(data1))
ax.plot(x, np.cumsum(data1), color="blue", linewidth=3, linestyle="--")
(l,) = ax.plot(x, np.cumsum(data2), color="orange", linewidth=2)
l.set_linestyle(":");
Multiple Charts on one Figure#
One basic bit of functionality that you might need is to put more than one type of information on a single overall figure. This could mean having multiple subplots or it could be about putting more information on a single set of axes.
Multiple features on one set of axes#
This is the kind of task where matplotlib’s build-what-you-want philosophy starts to win out. To add another feature on an ax
that you’ve already created is as simple as calling ax.<method>
again. You can add as many features as you like.
In the example below, we’ll call ax.hist() followed by ax.plot() to get the theoretical curve for a normal distribution (aka Gaussian) overlaid on a histogram of many draws from that distribution, generated using numpy.
While we’re at it, let’s add an equation that describes the theoretical curve.
rand_draws = prng.standard_normal(5000)
grid_x = np.linspace(-5, 5, 1000)
fig, ax = plt.subplots()
ax.hist(rand_draws, bins=50, density=True, label="Data")
ax.plot(
grid_x,
1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x**2) / 2),
linewidth=4,
label=r"$\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}}$",
)
ax.legend(fontsize=14);
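The curve and the bars line up because both are densities, so each covers a total area of one. A quick numerical sketch of that claim, using a simple Riemann sum and an arbitrary seed, is below.

```python
import numpy as np

grid_x = np.linspace(-5, 5, 1000)
dx = grid_x[1] - grid_x[0]

# Standard normal density: 1 / sqrt(2 * pi) * exp(-x^2 / 2)
pdf = 1 / np.sqrt(2 * np.pi) * np.exp(-(grid_x**2) / 2)
area_under_curve = np.sum(pdf) * dx  # simple Riemann-sum approximation

# density=True rescales bar heights so that the bars also cover an area of one
rng = np.random.default_rng(0)  # seed chosen arbitrarily
heights, edges = np.histogram(rng.standard_normal(5000), bins=50, density=True)
area_under_bars = np.sum(heights * np.diff(edges))

print(round(area_under_curve, 3), round(area_under_bars, 3))
```

Without density=True the histogram would show raw counts, and the curve would look flat along the bottom of the chart by comparison.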
Facets#
Another way to put more information on a single figure is to have different facets. To illustrate some of the ideas we’re about to see, it’s going to be useful to have some data. Let’s pull down GDP per capita for a selection of countries.
from pandas_datareader import wb
start_year = 2000
end_year = 2022
df_cc_gdp = wb.download(
indicator=["NY.GDP.PCAP.KD"], # GDP per capita in 2015 USD
country=["US", "GB", "FR", "DE", "IT", "JP"],
start=start_year,
end=end_year,
)
df_cc_gdp = df_cc_gdp.reset_index() # drop country as index (for plots)
df_cc_gdp["year"] = df_cc_gdp["year"].astype("int") # ensure year is a number
# sort by year, then country (important for next step)
df_cc_gdp = df_cc_gdp.sort_values(["year", "country"])
# index to that country's value in first entry by country
df_cc_gdp["GDP per capita"] = (
100
* df_cc_gdp["NY.GDP.PCAP.KD"]
/ df_cc_gdp.groupby("country")["NY.GDP.PCAP.KD"].transform(lambda row: row.iloc[0])
)
df_cc_gdp.head()
| country | year | NY.GDP.PCAP.KD | GDP per capita |
---|---|---|---|---|
45 | France | 2000 | 33592.466830 | 100.0 |
22 | Germany | 2000 | 34490.075769 | 100.0 |
91 | Italy | 2000 | 32350.904367 | 100.0 |
114 | Japan | 2000 | 31430.631130 | 100.0 |
68 | United Kingdom | 2000 | 38918.455057 | 100.0 |
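The indexing step above (dividing each country's series by its own first value) can be tried on a toy frame. This is a sketch with invented numbers and column names; it uses the built-in "first" aggregation with transform, which should give the same result as the lambda used above, provided the rows are sorted by year first.

```python
import pandas as pd

# Invented data: two "countries", two years each
toy = pd.DataFrame(
    {
        "country": ["A", "B", "A", "B"],
        "year": [2000, 2000, 2001, 2001],
        "value": [50.0, 200.0, 55.0, 190.0],
    }
).sort_values(["year", "country"])  # sorting makes "first" mean "earliest year"

# transform() returns one value per original row, so the division lines up
base = toy.groupby("country")["value"].transform("first")
toy["indexed"] = 100 * toy["value"] / base
print(toy)
```

Country A rises from 100 to 110 and country B falls from 100 to 95, even though their raw levels are very different.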
Subplots#
The first way to put data on multiple charts that are part of the same overall figure is to use the built-in subplots function. plt.subplots() accepts arguments for nrows= and ncols= that specify how the subplots are laid out. We have six countries so let’s do two rows of three columns.
We can use a for loop to go over these. The axes object that comes back from plt.subplots() is a 2x3 array, though, and looping over it directly would give us one row of three Axes at a time rather than the individual Axes. Instead, we loop over a “flattened” version of it, which you can create with axes.flatten().
Once we’re in the loop, we subset the data by country on each iteration and add it to the chart using ax.plot(). Finally, we’ll use the same limits for every subplot to make the chart easier to read.
Let’s see this with the data we just pulled down.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df_cc_gdp["country"].unique()[i]
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();
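To see the shape of what plt.subplots() hands back, here is a small sketch with empty panels. It also uses the sharex= and sharey= arguments, which are one alternative to setting identical limits by hand on every subplot; the panel titles are placeholders.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True)

# Iterating over the 2D array directly yields rows, each holding three Axes
rows = [row for row in axes]

# Flattening gives a 1D array of the six individual Axes
flat = axes.flatten()
for i, ax in enumerate(flat):
    ax.set_title(f"Panel {i}")

print(axes.shape, rows[0].shape, len(flat))
```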
A nice extra you can do with this sort of chart is to add the other countries but greyed out so that it’s clear which country is featured but the cross-country comparison is a bit easier.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for i, ax in enumerate(axes.flatten()):
    country = df_cc_gdp["country"].unique()[i]
    # grab the other countries
    other_countries = [x for x in df_cc_gdp["country"].unique() if x != country]
    for other in other_countries:
        o_country_df = df_cc_gdp.loc[df_cc_gdp["country"] == other, :]
        ax.plot(
            o_country_df["year"],
            o_country_df["GDP per capita"],
            lw=1,
            color="k",
            alpha=0.1,
        )
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(country, loc="center", fontsize=13)
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max())
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();
Gridspec#
This is a more complicated way of creating multiple subplots, but it is more flexible too: although it still relies on an underlying grid structure, the units of that grid don’t have to all equate to a single subplot. It’s easier to see than to describe!
Here’s an example that shows what’s possible:
Let’s say we wanted to really feature a single country more than the others: we might want to give it a bigger spot. We can do this with gridspec.
Taking our example, let’s “focus” on the UK. We’re going to use a three-by-three grid, with the UK in one corner, covering grid units 1 to 2 on both dimensions. Meanwhile, positions (0, 0), (0, 1), (1, 0), (2, 0), and (0, 2) will have the other five countries on.
We can produce this pattern using code: first a double list comprehension to pick out those specific values, then reducing it to a single list of co-ordinates (because double list comprehensions lead to nested lists). Here’s the set of positions for the other five countries in code:
import itertools
nested_list = [
[(i, j) for i in [0, 1, 2] if (i != j or i == 0) and (i + j <= 2)]
for j in [0, 1, 2]
]
other_country_indices = list(itertools.chain(*nested_list))
other_country_indices
[(0, 0), (1, 0), (2, 0), (0, 1), (0, 2)]
It’s a bit laborious to do this in code, but it’s often good practice to think about how to solve the general problem rather than manually typing out the solution because the former scales better to other problems. As you code more, you’ll find yourself writing these little helper lines again and again in different contexts. GitHub even has a service called gists to help you keep track of code snippets that may be useful across lots of projects.
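As an aside, the nest-then-flatten step can also be written as a single comprehension with two for clauses, which never creates the nested list in the first place. The sketch below compares the two routes; the variable names are illustrative.

```python
import itertools

# Route 1: nested comprehension, then flatten with itertools
nested = [
    [(i, j) for i in range(3) if (i != j or i == 0) and (i + j <= 2)]
    for j in range(3)
]
flat_via_chain = list(itertools.chain.from_iterable(nested))

# Route 2: one comprehension; the outer loop (over j) is written first
flat_direct = [
    (i, j)
    for j in range(3)
    for i in range(3)
    if (i != j or i == 0) and (i + j <= 2)
]

print(flat_direct)
```

Both routes keep the top row and left-hand column of the grid, leaving the bottom-right two-by-two block free.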
Now let’s convert our grid into some plots!
import matplotlib.gridspec as gridspec
fig = plt.figure(constrained_layout=True)
nrows, ncols = 3, 3
gspec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)
# create a list of the five spots we don't want to feature UK on
non_uk_axes = [fig.add_subplot(gspec[i, j]) for i, j in other_country_indices]
uk_ax = fig.add_subplot(gspec[1:, 1:])
all_axes = non_uk_axes + [uk_ax]
# Plot the other countries
o_countries_df = df_cc_gdp.loc[df_cc_gdp["country"] != "United Kingdom", :]
o_country_names = o_countries_df["country"].unique()
for i, ax in enumerate(non_uk_axes):
    # grab the other country's data
    o_country = o_country_names[i]
    o_country_df = o_countries_df.loc[o_countries_df["country"] == o_country, :]
    ax.plot(
        o_country_df["year"],
        o_country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.set_title(o_country, loc="center", fontsize=13)
# Now do UK
uk_df = df_cc_gdp.loc[df_cc_gdp["country"] == "United Kingdom", :]
uk_ax.set_title("United Kingdom", loc="center", fontsize=15)
uk_ax.plot(
uk_df["year"],
uk_df["GDP per capita"],
lw=3,
color=plt.rcParams["axes.prop_cycle"].by_key()["color"][len(o_country_names)],
)
# global settings
for ax in all_axes:
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year, end_year)
    ax.set_ylim(100, 150)
plt.show()
You should note, though, that this kind of chart can be quite confusing and—generally—it’s better to have consistent axes and subplot sizes if you’re trying to make comparisons easy.
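Stripped of the data, the key gridspec idea is that slicing the grid gives a subplot that spans several cells. The sketch below builds a three-by-three grid, places one large and one small subplot, and inspects which rows and columns the large one covers; the names big_ax and small_ax are illustrative.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure(constrained_layout=True)
gs = GridSpec(nrows=3, ncols=3, figure=fig)

small_ax = fig.add_subplot(gs[0, 0])  # a single cell in the top-left corner
big_ax = fig.add_subplot(gs[1:, 1:])  # rows 1-2 and columns 1-2, i.e. a 2x2 block

# The SubplotSpec records which grid cells an Axes occupies
spec = big_ax.get_subplotspec()
print(list(spec.rowspan), list(spec.colspan))
```

The slicing follows the usual Python rules, so `gs[1:, 1:]` means "from unit 1 to the end" on both dimensions.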
Using Subplot Mosaic to Specify a Layout#
We told you that the matplotlib API was ferociously powerful. Well, subplot grids and GridSpec aren’t the only ways to build up multiple subplots: you can build a layout of Axes based on ASCII art or nested lists! The method that does this, subplot_mosaic(), is a helper function to build complex GridSpec layouts visually.
Here’s an example of using text to specify a layout:
axd = plt.figure(constrained_layout=True).subplot_mosaic(
    """
    ABD
    CCD
    CC.
    """
)
kw = dict(ha="center", va="center", fontsize=60, color="darkgrey")
for k, ax in axd.items():
    ax.text(0.5, 0.5, k, transform=ax.transAxes, **kw)
Note how we did a little trick with the font settings too: we put them in a keyword argument dictionary called kw and passed these into the ax.text() function with the “splatty-splat” operator (**), which unpacks them.
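If the unpacking operator is new to you, this plain-Python sketch shows what it does, using a hypothetical function called describe() that is not part of matplotlib: passing `**kw` is equivalent to writing each dictionary entry out as a keyword argument.

```python
def describe(text, ha="left", va="bottom", fontsize=10):
    """Hypothetical stand-in for a function that takes styling keywords."""
    return f"{text}: ha={ha}, va={va}, fontsize={fontsize}"


kw = dict(ha="center", va="center", fontsize=60)

unpacked = describe("A", **kw)  # dictionary unpacked into keyword arguments
spelled_out = describe("A", ha="center", va="center", fontsize=60)

print(unpacked)
```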
Let’s now see another way to use subplot_mosaic(), this time using lists to achieve a particular layout:
ax_dict = plt.figure(constrained_layout=True).subplot_mosaic(
[["A", "B", "C", "C", "D", "D"]]
)
kw = dict(ha="center", va="center", fontsize=30, color="darkgrey")
for k, ax in ax_dict.items():
    ax.text(0.5, 0.5, k, transform=ax.transAxes, **kw)
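Because subplot_mosaic() returns a dictionary keyed by the labels you chose, you can pick out a particular panel by name when plotting. The sketch below rebuilds the text layout from earlier as a nested list (with "." again marking an empty cell) and plots into panel "C" only; the plotted numbers are placeholders.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

fig = plt.figure(constrained_layout=True)
axd = fig.subplot_mosaic(
    [
        ["A", "B", "D"],
        ["C", "C", "D"],
        ["C", "C", "."],
    ]
)

# Address a panel by its label rather than by a grid position
axd["C"].plot([0, 1, 2], [0, 1, 4])
axd["C"].set_title("The big panel")

c_spec = axd["C"].get_subplotspec()
print(sorted(axd), list(c_spec.rowspan), list(c_spec.colspan))
```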
An Advanced Example: Zoomed Plots#
This is taken from the online book Scientific Visualization in Matplotlib [Rougier, 2021]. It’s quite a complex example, but does show what can be achieved.
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Rectangle
from matplotlib.patches import ConnectionPatch
fig = plt.figure(figsize=(5, 4))
n = 5
gs = GridSpec(n, n + 1)
ax = plt.subplot(
gs[:n, :n], xlim=[-1, +1], xticks=[], ylim=[-1, +1], yticks=[], aspect=1
)
X = prng.normal(0, 0.35, 1000)
Y = prng.normal(0, 0.35, 1000)
ax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
I = prng.choice(len(X), size=n, replace=False)
Px, Py = X[I], Y[I]
I = np.argsort(Y[I])[::-1]
Px, Py = Px[I], Py[I]
ax.scatter(Px, Py, edgecolor="black", facecolor="None", linewidth=0.75)
dx, dy = 0.075, 0.075
for i, (x, y) in enumerate(zip(Px, Py)):
    sax = plt.subplot(
        gs[i, n],
        xlim=[x - dx, x + dx],
        xticks=[],
        ylim=[y - dy, y + dy],
        yticks=[],
        aspect=1,
    )
    sax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
    sax.scatter(Px, Py, edgecolor="black", facecolor="None", linewidth=0.75)
    sax.text(
        1.1,
        0.5,
        "Point " + chr(ord("A") + i),
        rotation=90,
        size=8,
        ha="left",
        va="center",
        transform=sax.transAxes,
    )
    rect = Rectangle(
        (x - dx, y - dy),
        2 * dx,
        2 * dy,
        edgecolor="black",
        facecolor="None",
        linestyle="--",
        linewidth=0.75,
    )
    ax.add_patch(rect)
    con = ConnectionPatch(
        xyA=(x, y),
        coordsA=ax.transData,
        xyB=(0, 0.5),
        coordsB=sax.transAxes,
        linestyle="--",
        linewidth=0.75,
        patchA=rect,
        arrowstyle="->",
    )
    fig.add_artist(con)
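If you only need a single zoomed region, a lighter-weight route (a different technique from the GridSpec and ConnectionPatch approach above) is an inset Axes combined with Axes.indicate_inset_zoom(), which draws the rectangle and connecting lines for you. This is a minimal sketch with made-up data and arbitrarily chosen zoom limits.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)  # illustrative seed and data
X = rng.normal(0, 0.35, 1000)
Y = rng.normal(0, 0.35, 1000)

fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
ax.set(xlim=(-1, 1), ylim=(-1, 1), aspect=1)

# Inset Axes in the top-right corner; bounds are in Axes-fraction co-ordinates
axins = ax.inset_axes([0.62, 0.62, 0.35, 0.35])
axins.scatter(X, Y, edgecolor="None", facecolor="C1", alpha=0.5)
axins.set(xlim=(-0.1, 0.1), ylim=(-0.1, 0.1), xticks=[], yticks=[])

# Mark the zoomed region on the parent Axes and link it to the inset
ax.indicate_inset_zoom(axins, edgecolor="black")
```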
Working with Times and Dates#
Python has really amazing support for “datetimes”, and you can find out more about the general capabilities in Introduction to Time.
It’s also very important to know how to make charts with times and dates in economics, so we’re going to give this one some special attention.
To demonstrate some of the principles, we’re going to work with real data on UK and US GDP growth. First let’s pull in that data and take a look at it:
from pandas_datareader import wb
from datetime import datetime
ts_start_date = pd.to_datetime("1999-01-01")
ts_end_date = datetime.now()
countries = ["GBR", "USA"]
gdf_const_2015_usd_code = 'NY.GDP.MKTP.KD'
df = wb.download(indicator=gdf_const_2015_usd_code, country=countries, start=ts_start_date, end=ts_end_date).reset_index()
df = df.sort_values(by="year") # this is very important!
# note that we are computing the percent change on a by-country basis!
# Otherwise, you'd mix up the UK and US values
df["growth, %"] = df.groupby("country")[gdf_const_2015_usd_code].transform(lambda x: 100*x.pct_change(1))
df = df.reset_index(drop=True)
df.head()
| country | year | NY.GDP.MKTP.KD | growth, % |
---|---|---|---|---|
0 | United Kingdom | 1999 | 2.197118e+12 | NaN |
1 | United States | 1999 | 1.321548e+13 | NaN |
2 | United States | 2000 | 1.375430e+13 | 4.077159 |
3 | United Kingdom | 2000 | 2.292006e+12 | 4.318727 |
4 | United Kingdom | 2001 | 2.351111e+12 | 2.578741 |
Now, you can see that this is annual data, and what’s been returned by the World Bank has given us a “year” field that is of type “object”, basically the lowest common denominator data type:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48 entries, 0 to 47
Data columns (total 4 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 country 48 non-null object
1 year 48 non-null object
2 NY.GDP.MKTP.KD 48 non-null float64
3 growth, % 46 non-null float64
dtypes: float64(2), object(2)
memory usage: 1.6+ KB
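We’re going to turn it into a datetime. This is a special type that helps with dates and times. In the code snippet below, the date we get from running pd.to_datetime() defaults to the 1st of January. But really, these figures describe the year that has already passed, so we’ll add a time offset that moves each date to the end of the year. The offset used, pd.offsets.BYearEnd(), picks the last business day of the year, which is why some of the dates below are not the 31st of December.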
df["year"] = pd.to_datetime(df["year"], format="%Y") + pd.offsets.BYearEnd()
df.head()
| country | year | NY.GDP.MKTP.KD | growth, % |
---|---|---|---|---|
0 | United Kingdom | 1999-12-31 | 2.197118e+12 | NaN |
1 | United States | 1999-12-31 | 1.321548e+13 | NaN |
2 | United States | 2000-12-29 | 1.375430e+13 | 4.077159 |
3 | United Kingdom | 2000-12-29 | 2.292006e+12 | 4.318727 |
4 | United Kingdom | 2001-12-31 | 2.351111e+12 | 2.578741 |
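To see what the offset is doing, the sketch below applies both the calendar year-end offset (pd.offsets.YearEnd()) and the business year-end offset used above to three year strings. The variable names are illustrative.

```python
import pandas as pd

years = pd.Series(["1999", "2000", "2001"])

# Parsing a bare year gives the 1st of January of that year
starts = pd.to_datetime(years, format="%Y")

calendar_end = starts + pd.offsets.YearEnd()  # always 31st December
business_end = starts + pd.offsets.BYearEnd()  # last weekday of the year

print(pd.DataFrame({"start": starts, "calendar": calendar_end, "business": business_end}))
```

The 31st of December 2000 fell on a Sunday, so the business year-end for 2000 is Friday the 29th, matching the table above. If you would rather have every date on the 31st, the calendar offset is the one to use.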
Okay, we’re now going to do a couple of other things that will help matplotlib. Unlike, say, lets-plot, matplotlib expects each line of interest in your data to appear across columns, i.e. in wide format. So we’re going to transform our data to be wide, and make the datetime column, "year", into the index.
df_wide = df.pivot(index="year", values="growth, %", columns="country")
df_wide = df_wide.sort_index()
df_wide.head()
year | United Kingdom | United States |
---|---|---|
1999-12-31 | NaN | NaN |
2000-12-29 | 4.318727 | 4.077159 |
2001-12-31 | 2.578741 | 0.954339 |
2002-12-31 | 1.791986 | 1.695943 |
2003-12-31 | 3.146573 | 2.796209 |
We also set the index to be the year, because pandas’ plotting functions use the index for the x-axis. By running df_wide.info(), we can learn what data types we now have.
df_wide.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 24 entries, 1999-12-31 to 2022-12-30
Data columns (total 2 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 United Kingdom 23 non-null float64
1 United States 23 non-null float64
dtypes: float64(2)
memory usage: 576.0 bytes
Now, note that i) the index is composed of dates, and ii) those dates have datatype “DatetimeIndex”.
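The long-to-wide reshaping used above can be tried on a toy frame. This sketch uses invented growth numbers; it pivots from long to wide, and then uses melt() to go back again in case you need tidy data for another library.

```python
import pandas as pd

# Invented numbers in long ("tidy") format: one row per country-year
long_df = pd.DataFrame(
    {
        "year": pd.to_datetime(["2000-12-29", "2000-12-29", "2001-12-31", "2001-12-31"]),
        "country": ["United Kingdom", "United States", "United Kingdom", "United States"],
        "growth, %": [1.0, 2.0, 3.0, 4.0],
    }
)

# Wide format: one column per country, dates as the index
wide_df = long_df.pivot(index="year", columns="country", values="growth, %")

# And back to long format again
back_to_long = wide_df.reset_index().melt(id_vars="year", value_name="growth, %")

print(wide_df)
```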
Alright, let’s just make a plot here without worrying too much about what might happen. We’ll do a line chart showing the growth rates over time of these two large economies.
fig, ax = plt.subplots()
df_wide.plot(ax=ax, lw=3)
ax.set_title("GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right() # Put tick marks and tick labels on right-hand side
You’ll notice straight away that because we began with a datetime index, when we called the convenience function df.plot(ax=...), we automatically got an x-axis that understands and renders datetimes. More than that, the scale is sensible: it’s only showing years on the major tick labels, even though each observation carries a full date.
Also note that the data are in wide format, which matplotlib prefers.
We could achieve the same outcome without using pandas’ built-in plotting methods, though it’s a tad more verbose (we need to plot each country separately and to explicitly add a legend):
fig, ax = plt.subplots()
for name in ["United Kingdom", "United States"]:
    ax.plot(df_wide.index, df_wide[name], lw=3, label=name)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.legend()
ax.yaxis.tick_right() # Put tick marks and tick labels on right-hand side
Now let’s see what happens if we change the axis limits. Because we’re using datetime data type for the x-axis variable, we’ll need to pass a datetime in to the limits too.
fig, ax = plt.subplots()
df_wide.plot(ax=ax, lw=3)
ax.set_title("Real GDP growth, %", loc="right")
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.tick_right() # Put tick marks and tick labels on right-hand side
ax.set_xlim(pd.to_datetime("2018-01-01"), pd.to_datetime("2021-01-01"));
As you can see, the plot dynamically responded to the shorter time period by adding more detail, here in the form of quarters (as the minor ticks). You can specify exactly what you want with the tick label formatters that cater to datetimes, but the defaults are pretty well-behaved.
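As a sketch of taking manual control, the snippet below uses the locators and formatters in matplotlib.dates to put years on the major ticks and selected months on the minor ticks. It plots with ax.plot() on made-up quarterly data rather than with df.plot(), on the assumption that pandas' own date handling for df.plot() could otherwise interfere with these settings.

```python
import datetime as dt

import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np

# Made-up quarterly series: first day of each quarter, 2018 to 2020
dates = [dt.datetime(2018 + i // 4, 3 * (i % 4) + 1, 1) for i in range(12)]
values = np.linspace(0, 1, 12)

fig, ax = plt.subplots()
ax.plot(dates, values, lw=3)

# Major ticks: one per year, labelled with the four-digit year
ax.xaxis.set_major_locator(mdates.YearLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))

# Minor ticks: April, July, and October, labelled with the month abbreviation
ax.xaxis.set_minor_locator(mdates.MonthLocator(bymonth=(4, 7, 10)))
ax.xaxis.set_minor_formatter(mdates.DateFormatter("%b"))

fig.canvas.draw()  # forces the tick labels to be computed
major_labels = [tick.get_text() for tick in ax.get_xticklabels()]
```

The format strings are the standard strftime codes, so anything you can express with those can appear on the axis.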
Advanced Topics#
Styles#
It can feel a bit tedious to apply different styles to every single ax, subplot, and even figure! Fortunately, there are ways to make applying styles more efficient. In this section, we’ll see how to do this concisely.
There are essentially four ways to apply styles:
1. modify each ax object one-by-one (works at the axis level)
2. use a dictionary once to pass keyword arguments to all objects of a similar type (works at the axis level)
3. temporarily change the global style (works at the figure level)
4. change the global style (works on every figure!)
We’ve seen plenty of 1. already, so we won’t do more on that. Let’s home in on 2. first.
Consistent Styling Across Axes and Subplots using Dictionaries#
Let’s say we have multiple text annotations that we want to make consistent in most details. We can use the cross-country GDP per capita example from earlier. Essentially, we define a dictionary of all the properties (and here, even a nested dictionary with further keyword arguments), and pass this using the splatty-splat (**) operator.
kw_settings = dict(
xycoords="data",
ha="center",
va="center",
fontweight="bold",
fontsize=12,
xytext=(-20, +50),
textcoords="offset points",
arrowprops=dict(
arrowstyle="->",
connectionstyle="arc3,rad=-0.3",
alpha=0.7,
),
bbox=dict(facecolor="white", edgecolor="None", alpha=0.85),
)
# spread out x text annotations for each country
years_array = np.arange(start_year + 1, end_year, len(df_cc_gdp["country"].unique()) - 2)
# get countries in order of lowest to highest end value
countries = (
df_cc_gdp.sort_values("GDP per capita")
.loc[df_cc_gdp["year"] == df_cc_gdp["year"].max(), "country"]
.values
)
fig, ax = plt.subplots()
for i, country in enumerate(countries):
    country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
    ax.plot(
        country_df["year"],
        country_df["GDP per capita"],
        lw=3,
        color=plt.rcParams["axes.prop_cycle"].by_key()["color"][i],
    )
    ax.annotate(
        country,
        (
            years_array[i],
            # .iloc[0] extracts a scalar from the one-row selection
            country_df.loc[
                country_df["year"] == years_array[i], "GDP per capita"
            ].iloc[0],
        ),
        **kw_settings,
        color=ax.lines[-1].get_color(),
    )
ax.yaxis.tick_right()
ax.spines["right"].set_visible(True)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.set_xlim(start_year - 1, end_year)
ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max() * 1.05)
fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
plt.tight_layout();
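The same dictionary trick works for any group of Artists, not just annotations. Here is a minimal sketch with two hypothetical settings dictionaries, line_kw and text_kw, shared across two subplots. It also shows one way to override a single entry for one call without altering the shared dictionary.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

# Shared settings, defined once (names and values are illustrative)
line_kw = dict(lw=3, alpha=0.8)
text_kw = dict(ha="center", va="center", fontsize=12, fontweight="bold")

fig, axes = plt.subplots(1, 2)
for ax, label in zip(axes, ["left", "right"]):
    ax.plot([0, 1], [0, 1], **line_kw)
    ax.text(0.5, 0.5, label, transform=ax.transAxes, **text_kw)

# Merge into a new dict to change one setting for a single call
axes[1].plot([0, 1], [1, 0], **{**line_kw, "lw": 1})

print([line.get_linewidth() for line in axes[1].lines])
```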
Temporarily Change the Global Style#
Python has a powerful notion called context managers, which can temporarily change settings for any code run within them. You can identify them by the with statement. Now, we’ll show how to use these to completely change the style of a plot.
The first, simple example is:
data = prng.standard_normal(50)
fig, ax = plt.subplots()
ax.plot(data)
plt.show()
Now let’s say, just for this plot, we want to make a change. We can do this by setting the “rc context”. This changes global settings that relate to figures. You can pass the settings to it as dictionary items. In this case, we’ll change the linewidth and the linestyle. The with statement means the new settings apply to any indented code following the with. But, after the indented block is over, everything reverts back to the usual settings.
You can find a full list of the settings you can pass here.
import matplotlib as mpl
with mpl.rc_context({'lines.linewidth': 4, 'lines.linestyle': ':'}):
    fig, ax = plt.subplots()
    ax.plot(data)
    plt.show()
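You can convince yourself that the settings really do revert by reading the relevant entry of mpl.rcParams before, inside, and after the context. This is a small sketch; the variable names are illustrative.

```python
import matplotlib as mpl

mpl.use("Agg")  # assumption: headless backend so this runs outside a notebook

lw_before = mpl.rcParams["lines.linewidth"]

with mpl.rc_context({"lines.linewidth": 4, "lines.linestyle": ":"}):
    # Inside the block, the temporary values are in force
    lw_inside = mpl.rcParams["lines.linewidth"]
    ls_inside = mpl.rcParams["lines.linestyle"]

# Outside the block, the previous value is back
lw_after = mpl.rcParams["lines.linewidth"]

print(lw_before, lw_inside, lw_after)
```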
There are some built-in styles that you can take advantage of for this too. These are called via plt.style.context(<style name>), and you can find a list here.
with plt.style.context('grayscale'):
    fig, ax = plt.subplots()
    for i, country in enumerate(countries):
        country_df = df_cc_gdp.loc[df_cc_gdp["country"] == country, :]
        ax.plot(
            country_df["year"],
            country_df["GDP per capita"],
            lw=3,
        )
        ax.annotate(
            country,
            (
                years_array[i],
                # .iloc[0] extracts a scalar from the one-row selection
                country_df.loc[
                    country_df["year"] == years_array[i], "GDP per capita"
                ].iloc[0],
            ),
            **kw_settings,
            color=ax.lines[-1].get_color(),
        )
    ax.yaxis.tick_right()
    ax.spines["right"].set_visible(True)
    ax.spines["left"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.set_xlim(start_year - 1, end_year)
    ax.set_ylim(df_cc_gdp["GDP per capita"].min(), df_cc_gdp["GDP per capita"].max() * 1.05)
    fig.suptitle(f"GDP per capita (indexed to 100 in {start_year})", fontsize=15)
    plt.tight_layout();
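The names of the built-in styles are also available programmatically in plt.style.available. The sketch below lists them and, as a further illustration, passes a list to plt.style.context() so that a named style and a dictionary of extra settings are applied together, in order. The extra linewidth setting is an arbitrary choice.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

style_names = sorted(plt.style.available)
print(style_names[:5])

lw_default = plt.rcParams["lines.linewidth"]

# A named style first, then a dict of individual settings layered on top
with plt.style.context(["grayscale", {"lines.linewidth": 4}]):
    lw_inside = plt.rcParams["lines.linewidth"]
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 0])

lw_after = plt.rcParams["lines.linewidth"]
```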
Setting the Style for an Entire Script or Notebook#
This book uses a special style file to set its RC params and over-ride the default values found in the default rc file. Once loaded, using
plt.style.use("plot_style.txt")
all charts use whatever settings are in “plot_style.txt” unless those settings are explicitly over-ridden by other commands.
You can see the first few lines of the “plot_style.txt” file used for this book below, or you can browse it yourself here.
!head -10 plot_style.txt
xtick.color: 323034
ytick.color: 323034
text.color: 323034
lines.markeredgecolor: black
patch.facecolor : bc80bd
patch.force_edgecolor : True
patch.linewidth: 0.8
scatter.edgecolors: black
grid.color: b1afb5
axes.titlesize: 19
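If you would like to experiment with a style file of your own, the sketch below writes a tiny one to a temporary folder and applies it with plt.style.context(), so that nothing changes globally. The file name demo_style.mplstyle and the three settings in it are assumptions chosen for illustration; the format is one "key: value" pair per line, as in the listing above.

```python
import tempfile
from pathlib import Path

import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so this runs outside a notebook
import matplotlib.pyplot as plt

style_text = "lines.linewidth: 5\naxes.titlesize: 19\nxtick.color: 323034\n"

with tempfile.TemporaryDirectory() as tmp_dir:
    style_path = Path(tmp_dir) / "demo_style.mplstyle"  # hypothetical file name
    style_path.write_text(style_text)

    # Apply the style only within this block
    with plt.style.context(str(style_path)):
        lw_with_style = plt.rcParams["lines.linewidth"]
        fig, ax = plt.subplots()
        ax.plot([0, 1, 2], [2, 0, 1])
        ax.set_title("Styled from a file")

lw_without_style = plt.rcParams["lines.linewidth"]
print(lw_with_style, lw_without_style)
```

Swapping plt.style.context() for plt.style.use() would make the change stick for the rest of the script or notebook.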
Review#
That concludes our tour of the basics of matplotlib. You should now feel comfortable making basic charts using data.
You can find much more advanced uses of matplotlib over at the matplotlib documentation and in the online book Scientific Visualization in Matplotlib [Rougier, 2021].