import warnings
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.patches import Circle, Polygon, Rectangle
# Ignore warnings just for this chapter
warnings.filterwarnings("ignore")

Narrative Data Visualisation
Introduction
In this chapter, we’ll examine the tricks and techniques of narrative data visualisation. This type of visualisation gets a lot more press than the others discussed in Intro to Data Visualisation ; literally, because journalists use it in their work. There are many books written on it too.
For narrative visualisation, it’s particularly helpful to bear this quote in mind:
The purpose of visualisation is insight, not pictures
—Ben Shneiderman, populariser of the highlighted text link
Narrative data visualisation requires the most thought in the step where you go from the first view to the end product. It’s a visualisation that doesn’t just show a picture, but gives an insight.
Let’s import the packages we’ll need for the rest of the chapter.
Narrative Data Visualisation
As discussed, the name of the game here is to communicate a particular narrative. Let’s see what tricks we can use to help do this. The three stages of this are: deciding the story, deciding the chart, and creating a chart that helps deliver that narrative.
Deciding the story
We are only human, and a digestible narrative goes a lot further than lots of data with no thread running through. If you are doing narrative visualisation, you must first be clear about the story you want to tell and why it’s important. Let’s say you are making a visualisation on topic ‘Y’, then some reasons why ‘Y’ might be important that you may want to think about when creating your narrative are:
- Y matters: when Y rises or falls, people are hurt or helped.
- Y is puzzling: it defies easy explanation.
- Y is controversial or political: some argue one thing while others say another.
- Y is big (like the service sector) or common (like traffic jams).
- Y helps someone do something they could not do before.
If you can identify why a story is important, you’re half way to designing a narrative visualisation that brings the story into focus. Later on, we’ll recreate a chart from the Financial Times. In that case, ‘Y’ is a high for air pollution in Beijing that hurts people (it’s also political because of efforts to tackle the problem). So, in this case, the creator of the narrative needs to convey that the pollution has hit a high (presumably relative to other points in time) and that this high is far above safe levels.
What plot should I use?
Once you know what story you want to tell, you need the right kind of chart for the job. Resources like the Financial Times’ visual vocabulary are extremely useful here. You need to ask yourself what element you’re trying to highlight: a point in time, the size relative to other units cross-sectionally, the distribution either in numbers or spatially, the difference between groups, how something has changed, etc.? There are charts that can help with all of these and it’s well worth looking at the link to get a sense.
Drawing Attention to Enhance Narrative Visualisation
According to data visualisation master Jon Schwabish’s book Better Data Visualizations (Schwabish 2021), there are 15 ways to draw an audience’s attention in a chart:
- Shape
- Enclosure
- Line width
- Saturation
- Colour
- Size
- Markings
- Orientation
- Position
- Sharpness
- Length
- 3D
- Curvature
- Density
- Closure
Sometimes people add a 16th entry to this list, Connection.
# Fill colour shared by the marks in the visual-encoding panels below
square_color = "#bc80bd"
# Draw all figure text in the Varta sans-serif typeface
plt.rcParams.update({"font.family": "sans-serif", "font.sans-serif": "Varta"})
class VisualEncodingChart:
    """Figure with a 3x5 grid of panels, one per attention-drawing visual cue.

    Every panel is a unit square (axes limits 0-1, no ticks or spines) that
    holds a small set of marks in which one mark differs from the rest,
    illustrating the cue named in the panel title. Marks are filled with the
    module-level ``square_color`` unless a panel says otherwise.
    """

    def __init__(self, figsize=(8.4, 6)):
        """Create the figure and its 3x5 array of blank panels.

        Args:
            figsize: Figure size in inches as ``(width, height)``.
        """
        self.fig, self.axes = plt.subplots(3, 5, figsize=figsize)
        self.fig.patch.set_facecolor("white")
        self.setup_axes()

    def setup_axes(self):
        """Give every panel 0-1 limits and remove its ticks and spines."""
        for ax in self.axes.flat:
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    @staticmethod
    def _grid(x_left=0.15, y_top=0.7, x_step=0.25):
        """Yield ``(row, col, x, y)`` for a 3x3 layout, top row first.

        ``x`` and ``y`` are the lower-left corner of each cell: columns are
        ``x_step`` apart starting at ``x_left`` and rows step down by 0.3
        from ``y_top``.
        """
        for row in range(3):
            for col in range(3):
                yield row, col, x_left + col * x_step, y_top - row * 0.3

    @staticmethod
    def _add_rect(ax, xy, width=0.15, height=0.15, **kwargs):
        """Add a rectangle patch to ``ax`` and return it.

        The patch is filled with ``square_color`` and has no edge unless
        ``facecolor``/``edgecolor`` are given in ``kwargs``; every other
        keyword is passed straight to ``Rectangle``.
        """
        # Looked up at call time so the module-level colour is always current.
        kwargs.setdefault("facecolor", square_color)
        kwargs.setdefault("edgecolor", "none")
        return ax.add_patch(Rectangle(xy, width, height, **kwargs))

    def add_title(self, ax, title, highlight=False):
        """Write a bold title above a panel.

        Args:
            ax: Panel axes to label.
            title: Title text.
            highlight: If true, draw a rounded orange box around the title
                and place it slightly higher to make room for the box.
        """
        text_kwargs = dict(
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=14,
            fontweight="bold",
        )
        if highlight:
            height = 1.15
            text_kwargs["bbox"] = dict(
                boxstyle="round,pad=0.3",
                facecolor="none",
                edgecolor="orange",
                linewidth=2,
            )
        else:
            height = 1.1
        ax.text(0.5, height, title, **text_kwargs)

    def create_shape_panel(self):
        """Shape: eight squares plus one cross (middle row, last column)."""
        ax = self.axes[0, 0]
        self.add_title(ax, "Shape")
        # Top row: three squares
        for i in range(3):
            self._add_rect(ax, (0.15 + i * 0.25, 0.7))
        # Middle row: two squares, then a cross built from two thin bars
        for i in range(2):
            self._add_rect(ax, (0.15 + i * 0.25, 0.4))
        cross_x = 0.15 + 2 * 0.25 + 0.075
        cross_y = 0.4 + 0.075
        self._add_rect(ax, (cross_x - 0.02, cross_y - 0.075), 0.04, 0.15)
        self._add_rect(ax, (cross_x - 0.075, cross_y - 0.02), 0.15, 0.04)
        # Bottom row: three squares
        for i in range(3):
            self._add_rect(ax, (0.15 + i * 0.25, 0.1))

    def create_enclosure_panel(self):
        """Enclosure: nine squares, the bottom-middle one boxed in orange.

        This panel's title is also highlighted with an orange box.
        """
        ax = self.axes[0, 1]
        self.add_title(ax, "Enclosure", highlight=True)
        corners = [(x, y) for y in (0.7, 0.4, 0.1) for x in (0.15, 0.4, 0.65)]
        for i, (x, y) in enumerate(corners):
            self._add_rect(ax, (x, y))
            if i == 7:  # middle square of the bottom row
                self._add_rect(
                    ax,
                    (x - 0.02, y - 0.02),
                    0.19,
                    0.19,
                    facecolor="none",
                    edgecolor="orange",
                    linewidth=2,
                )

    def create_line_width_panel(self):
        """Line width: outlined squares, the bottom-middle one drawn thicker."""
        ax = self.axes[0, 2]
        self.add_title(ax, "Line Width")
        for row, col, x, y in self._grid():
            self._add_rect(
                ax,
                (x, y),
                facecolor="none",
                edgecolor=square_color,
                linewidth=3 if (row == 2 and col == 1) else 1,
            )

    def create_saturation_panel(self):
        """Saturation: half-transparent squares, one fully opaque.

        The opaque square is in the middle row, first column.
        """
        ax = self.axes[0, 3]
        self.add_title(ax, "Saturation")
        for row, col, x, y in self._grid():
            self._add_rect(ax, (x, y), alpha=1 if (row == 1 and col == 0) else 0.5)

    def create_color_panel(self):
        """Colour: the top-left square uses a different fill (``#fb8072``)."""
        ax = self.axes[0, 4]
        self.add_title(ax, "Colour")
        for row, col, x, y in self._grid():
            odd_one_out = row == 0 and col == 0
            self._add_rect(
                ax, (x, y), facecolor="#fb8072" if odd_one_out else square_color
            )

    def create_size_panel(self):
        """Size: the bottom-middle square is smaller, centred in its cell."""
        ax = self.axes[1, 0]
        self.add_title(ax, "Size")
        for row, col, x, y in self._grid():
            if row == 2 and col == 1:
                size = 0.08
                # Shift so the small square shares its centre with a full cell
                inset = (0.15 - size) / 2
                x, y = x + inset, y + inset
            else:
                size = 0.15
            self._add_rect(ax, (x, y), size, size)

    def create_markings_panel(self):
        """Markings: the centre square carries a white X."""
        ax = self.axes[1, 1]
        self.add_title(ax, "Markings")
        for row, col, x, y in self._grid():
            self._add_rect(ax, (x, y))
            if row == 1 and col == 1:  # middle row, middle column
                ax.plot(
                    [x + 0.03, x + 0.12], [y + 0.03, y + 0.12], "white", linewidth=3
                )
                ax.plot(
                    [x + 0.12, x + 0.03], [y + 0.03, y + 0.12], "white", linewidth=3
                )

    def create_orientation_panel(self):
        """Orientation: the centre square is rotated 45 degrees about its centre."""
        ax = self.axes[1, 2]
        self.add_title(ax, "Orientation")
        for row, col, x, y in self._grid():
            centre_x = x + 0.075
            centre_y = y + 0.075
            self._add_rect(
                ax,
                (centre_x - 0.075, centre_y - 0.075),
                angle=45 if row == 1 and col == 1 else 0,
                rotation_point="center",
            )

    def create_position_panel(self):
        """Position: one square sits off the regular column grid.

        The middle row holds only two squares: one in the first column and
        one displaced to ``x = 0.5``.
        """
        ax = self.axes[1, 3]
        self.add_title(ax, "Position")
        corners = [
            (0.15, 0.7),
            (0.4, 0.7),
            (0.65, 0.7),
            (0.15, 0.4),
            (0.5, 0.4),  # displaced square
            (0.15, 0.1),
            (0.4, 0.1),
            (0.65, 0.1),
        ]
        for corner in corners:
            self._add_rect(ax, corner)

    def create_3d_panel(self):
        """3D: the bottom-right mark is a cube (front, top and right faces)."""
        ax = self.axes[1, 4]
        self.add_title(ax, "3D")
        for row, col, x, y in self._grid():
            if not (row == 2 and col == 2):
                self._add_rect(ax, (x, y))
                continue
            # Front face, slightly smaller to leave room for the other faces
            self._add_rect(ax, (x, y), 0.12, 0.12)
            # Top face (parallelogram)
            top = np.array(
                [
                    [x, y + 0.12],
                    [x + 0.03, y + 0.15],
                    [x + 0.15, y + 0.15],
                    [x + 0.12, y + 0.12],
                ]
            )
            ax.add_patch(Polygon(top, facecolor="#34495e", edgecolor="none"))
            # Right face (parallelogram)
            right = np.array(
                [
                    [x + 0.12, y],
                    [x + 0.15, y + 0.03],
                    [x + 0.15, y + 0.15],
                    [x + 0.12, y + 0.12],
                ]
            )
            ax.add_patch(Polygon(right, facecolor="#1a252f", edgecolor="none"))

    def create_length_panel(self):
        """Length: thin bars of equal length except in the bottom row."""
        ax = self.axes[2, 0]
        self.add_title(ax, "Length")
        lengths = [
            [0.2, 0.2, 0.2],
            [0.2, 0.2, 0.2],
            [0.25, 0.15, 0.2],  # bottom row varies
        ]
        for row, col, x, y in self._grid(y_top=0.75):
            self._add_rect(ax, (x, y), lengths[row][col], 0.05)

    def create_curvature_panel(self):
        """Curvature: straight bars, with an arched line at bottom-middle."""
        ax = self.axes[2, 1]
        self.add_title(ax, "Curvature")
        for row, col, x, y in self._grid(y_top=0.75):
            if row == 2 and col == 1:
                # Upper half of an ellipse, drawn as a line
                theta = np.linspace(0, np.pi, 20)
                curve_x = x + 0.1 + 0.08 * np.cos(theta)
                curve_y = y + 0.025 + 0.02 * np.sin(theta)
                ax.plot(curve_x, curve_y, color=square_color, linewidth=3)
            else:
                self._add_rect(ax, (x, y), 0.2, 0.05)

    def create_density_panel(self):
        """Density: three bands of equal-sized dots with 6, 4 and 3 dots."""
        ax = self.axes[2, 2]
        self.add_title(ax, "Density")
        dot_size = 0.06
        # Top and middle bands are two dots tall: (x start, x step, dots wide)
        two_row_bands = ((0.2, 0.15, 3), (0.25, 0.2, 2))
        for band, (x_start, x_step, n_wide) in enumerate(two_row_bands):
            y_base = 0.7 - band * 0.3
            for i in range(2):
                for j in range(n_wide):
                    self._add_rect(
                        ax,
                        (x_start + j * x_step, y_base + 0.05 + i * 0.08),
                        dot_size,
                        dot_size,
                    )
        # Bottom band: a single row of three dots
        y_base = 0.7 - 2 * 0.3
        for j in range(3):
            self._add_rect(ax, (0.2 + j * 0.2, y_base + 0.075), dot_size, dot_size)

    def create_closure_panel(self):
        """Closure: open arcs, with one closed circle at bottom-middle."""
        ax = self.axes[2, 3]
        self.add_title(ax, "Closure")
        # Each open arc leaves a gap of 0.5 rad either side of the +x direction
        theta = np.linspace(0.5, 2 * np.pi - 0.5, 40)
        for row, col, x, y in self._grid(x_left=0.2, x_step=0.2):
            centre_y = y + 0.075
            if row == 2 and col == 1:
                ax.add_patch(
                    Circle(
                        (x, centre_y),
                        0.06,
                        facecolor="none",
                        edgecolor=square_color,
                        linewidth=2,
                    )
                )
            else:
                ax.plot(
                    x + 0.06 * np.cos(theta),
                    centre_y + 0.06 * np.sin(theta),
                    color=square_color,
                    linewidth=2,
                )

    def create_sharpness_panel(self):
        """Sharpness: sharp squares, with a blurred one at bottom-middle."""
        ax = self.axes[2, 4]
        self.add_title(ax, "Sharpness")
        for row, col, x, y in self._grid():
            if row == 2 and col == 1:
                # Fake a blur by stacking faint copies at small diagonal offsets
                for offset in np.linspace(-0.01, 0.01, 5):
                    self._add_rect(ax, (x + offset, y + offset), alpha=0.15)
            else:
                self._add_rect(ax, (x, y), alpha=1.0)

    def create_chart(self):
        """Draw all fifteen panels, tidy the layout, and return the figure."""
        self.create_shape_panel()
        self.create_enclosure_panel()
        self.create_line_width_panel()
        self.create_saturation_panel()
        self.create_color_panel()
        self.create_size_panel()
        self.create_markings_panel()
        self.create_orientation_panel()
        self.create_position_panel()
        self.create_3d_panel()
        self.create_length_panel()
        self.create_curvature_panel()
        self.create_density_panel()
        self.create_closure_panel()
        self.create_sharpness_panel()
        # Act on this chart's own figure rather than whichever is "current"
        self.fig.tight_layout()
        self.fig.subplots_adjust(hspace=0.4, wspace=0.3)
        return self.fig
# Build the 15-panel chart; create_chart() returns the figure it drew on
chart = VisualEncodingChart()
fig = chart.create_chart()
plt.show()

But be warned: not all of these are equivalent! It’s much easier to perceive differences in length than it is differences in, say, volume. So if you want your audience to be able to make comparisons or quantitative assessments, you need to pick what techniques you use from this list carefully. Roughly in order of how easy they are to perceive quantitatively, the features are: one common axis, two axes, length, slope, angle, parts of whole, area, volume, saturation, and hue.
Case Studies in Narrative Visualisation
Good visualisation helps the viewer to grasp the narrative—rather than leaving them puzzling as to what the key message is. To that end, various adornments may be added to a plot to bring out the narrative. These adornments typically take the form of those we have seen already that draw the eye, but for a successful narrative visualisation they must come together to tell a story.
We’ll also make use of some other tricks:
- Use text annotations, which can be a useful addition to a chart because they further enhance the narrative
- Declutter the graph, removing lines that aren’t helping frame the story
- Use the title to tell the story, and put the y-axis label horizontally below
- Use faded text for text that isn’t contributing directly to the narrative
- If there are multiple lines, label them directly rather than via a legend
The Financial Times
Let’s see an example that brings together quite a few of these elements, recreating a chart from the Financial Times: a newspaper that is well-known for its impressive visualisations. The chart tells the story of extremely high levels of air pollution in Beijing at the start of 2021. (Note that the data here disagree with the original Financial Times source, which were unavailable; do not take the numbers too seriously.)
Let’s first grab the data:
# Load the Beijing PM2.5 series and index it by parsed date
df = pd.read_csv(
    "https://github.com/aeturrell/coding-for-economists/raw/main/data/beijing_pm.csv"
)
df = df.assign(date=pd.to_datetime(df["date"])).set_index("date")
# Keep only the period of interest (both end dates inclusive)
df = df.loc[(df.index >= "2020-02-28") & (df.index <= "2021-03-01")]
df.head()

| date | pm25 |
|---|---|
| 2020-02-28 | 102.857143 |
| 2020-02-29 | 103.285714 |
| 2020-03-01 | 123.571429 |
| 2020-03-02 | 118.142857 |
| 2020-03-03 | 108.142857 |
Now let’s get on to the figure. The code that generates individual parts of the chart is annotated to explain what it’s doing.
# Opacity for supporting text, so it recedes behind the main story
fade_alpha = 0.8
fig, ax = plt.subplots(figsize=(9, 5))
# Plot a line
ax.plot(df.index, df["pm25"], lw=1.5, color="#12549a")
# Title that gives the narrative
fig.suptitle(
    "Severe dust storm causes sharp rise in Beijing's air pollution",
    size=16,
    ha="left",
    x=0.12,
)
# Horizontal y-axis title, faded
ax.set_title(
    "PM2.5 rolling 7-day average (micrograms per cubic metre)",
    loc="left",
    size=14,
    alpha=fade_alpha,
)
# Time is obvious, so no x-label needed: instead annotate sources
ax.set_xlabel(
    "* Based on annual mean exposure of 10 micrograms per cubic metre \n Source: AQICN",
    loc="left",
    size=9,
    alpha=fade_alpha,
)
# Remove chart clutter: no spines, no y tick marks
for spine in ax.spines.values():
    spine.set_visible(False)
ax.tick_params(axis="y", which="both", length=0)
# NOTE(review): [1, 0, 0, alpha] is a translucent red; confirm red x tick
# marks are intended rather than faded black ([0, 0, 0, alpha]).
ax.tick_params(axis="x", which="both", color=[1, 0, 0, fade_alpha])
# Set aesthetically pleasing limits
ax.set_ylim(0, 200)
ax.set_xlim(None, df.index.max())
# For time series, tick marks on the right help give a sense of right-ward motion
ax.yaxis.tick_right()
# Grid only in the y-direction, so the viewer can judge level, with few ticks
ax.yaxis.set_major_locator(ticker.MaxNLocator(4))
ax.grid(which="major", axis="y", lw=0.2)
# Minor ticks every month on the x-axis
ax.xaxis.set_minor_locator(mdates.MonthLocator())
# Major ticks every third month on the x-axis
ax.xaxis.set_major_locator(mdates.MonthLocator(bymonthday=1, interval=3))
# Label major ticks as abbreviated month plus two-digit year, e.g. "Jan 21"
ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %y"))
# Annotate the "news" here, the highest value in the window. Use an arrow.
# Take the peak from the pm25 column so xy is a pair of scalars rather than
# a pair of one-element Series.
ax.annotate(
    "Worst storm in a decade hits Beijing",
    xy=(df["pm25"].idxmax(), df["pm25"].max()),
    xycoords="data",
    xytext=(-30, -40),
    textcoords="offset points",
    ha="right",
    size=14,
    arrowprops=dict(arrowstyle="->", connectionstyle="angle3"),
)
# Add a hatch, with a label, that represents the WHO safe level: this helps
# put the current rise in context
ax.fill_between(
    x=df.index,
    y1=0,
    y2=10,
    hatch="///",
    facecolor="None",
    linewidth=0.1,
    alpha=0.2,
)
ax.annotate(
    "WHO safe level*",
    fontweight="heavy",
    xy=(0.6, 0.01),
    xycoords="axes fraction",
    xytext=(0, 0),
    textcoords="offset points",
    ha="right",
    size=12,
)
# Use the FT background colours
fig.set_facecolor("#fff1e4")
ax.set_facecolor("#fff1e4")
# Faded tick labels
ax.tick_params(axis="both", which="both", labelsize=12)
plt.setp(ax.get_xticklabels(), alpha=fade_alpha)
plt.setp(ax.get_yticklabels(), alpha=fade_alpha)
plt.show()

What do we learn from this exercise? Well, there’s a big last mile issue with making narrative visualisations. Normally, when coding, you should do your best to avoid hard-coding numbers (for example x=0.15 above). But in this case getting things just so requires a lot of manual adjustment—and that’s after you’ve decided what story you’re going to tell, and how to show it.
The example above demonstrates many of the pain points of the last mile, like decluttering, adjusting text, careful use of saturation and colour, and ensuring dates are displayed in an aesthetically pleasing way.
Direct Labelling of Lines
One commonly used trick that we didn’t see in the above example is labelling lines directly (rather than using a legend).
Labelling Line Ends
So let’s now see an example of labelling line ends that was originally posted on the Library of Statistical Techniques (LOST).
# Read the Google Trends CSV hosted in the LOST-STATS repository, parsing the
# "date" column as datetimes
df = pd.read_csv(
    "https://raw.githubusercontent.com/LOST-STATS/LOST-STATS.github.io/master/Presentation/Figures/Data/Line_Graph_with_Labels_at_the_Beginning_or_End_of_Lines/Research_Nobel_Google_Trends.csv",
    parse_dates=["date"],
)
df.head()

| | date | hits | geo | keyword | name |
|---|---|---|---|---|---|
| 0 | 2019-09-21 | 1 | world | physics nobel | Physics |
| 1 | 2019-09-22 | 1 | world | physics nobel | Physics |
| 2 | 2019-09-23 | 1 | world | physics nobel | Physics |
| 3 | 2019-09-24 | 1 | world | physics nobel | Physics |
| 4 | 2019-09-25 | 1 | world | physics nobel | Physics |
# Opacity for the lines and the supporting text
fade_alpha = 0.7
# Create the column we wish to plot
title = "Log of Google Trends Index"
df[title] = np.log(df["hits"])
df = df.dropna(subset=[title])
# Make a plot
fig, ax = plt.subplots()
# Add lines to it, one per name, without a legend
sns.lineplot(
    ax=ax,
    data=df,
    x="date",
    y=title,
    hue="name",
    palette="deep",
    legend=None,
    hue_order=df["name"].unique(),
    alpha=fade_alpha,
)
# Add the text: for each line, find the end, annotate it with a label, and
# adjust the chart axes so that everything fits on.
for line, name in zip(ax.lines, df["name"].unique()):
    xdata = line.get_xdata()
    # Convert to a plain float array with masked entries turned into NaN, so
    # masked, NaN and infinite values can all be skipped the same way.
    ydata = np.ma.filled(np.ma.asarray(line.get_ydata(), dtype=float), np.nan)
    finite_idx = np.flatnonzero(np.isfinite(ydata))
    if finite_idx.size == 0:
        # Nothing drawable on this line, so there is nothing to label
        continue
    # NB: to label line starts instead, use finite_idx[0] and xdata[0]
    y = ydata[finite_idx[-1]]  # last finite value
    x = xdata[-1]
    if not np.isfinite(x):
        continue
    text = ax.annotate(
        name,
        xy=(x, y),
        xytext=(2, -2),
        color=line.get_color(),
        xycoords=(ax.get_xaxis_transform(), ax.get_yaxis_transform()),
        textcoords="offset points",
        fontweight="bold",
    )
    # Width of the rendered label, converted to data units
    text_width = (
        text.get_window_extent(fig.canvas.get_renderer())
        .transformed(ax.transData.inverted())
        .width
    )
    if np.isfinite(text_width):
        # Widen the x-axis so the label fits, with 5% breathing room
        ax.set_xlim(ax.get_xlim()[0], text.xy[0] + text_width * 1.05)
# Title that gives the narrative
fig.suptitle(
    "Economics overtakes other STEM subjects in searches",
    size=12,
    ha="left",
    x=0.12,
)
# Horizontal y-axis title, faded
ax.set_title(
    title,
    loc="left",
    size=10,
    alpha=fade_alpha,
)
ax.set_xlabel("")
ax.set_ylabel("")
# Remove chart clutter: no spines, no y tick marks
for spine in ax.spines.values():
    spine.set_visible(False)
ax.tick_params(axis="y", which="both", length=0)
# NOTE(review): [1, 0, 0, alpha] is a translucent red; confirm red x tick
# marks are intended rather than faded black ([0, 0, 0, alpha]).
ax.tick_params(axis="x", which="both", color=[1, 0, 0, fade_alpha])
# Format the date axis to be prettier.
ax.xaxis.set_major_formatter(mdates.DateFormatter("%d %b"))
ax.xaxis.set_minor_locator(mdates.DayLocator())
ax.xaxis.set_major_locator(mdates.AutoDateLocator(interval_multiples=False))
# For time series, tick marks on the right help give a sense of right-ward motion
ax.yaxis.tick_right()
ax.grid(which="major", axis="y", lw=0.2)
ax.set_ylim(0, None)
plt.show()

Labelling On Lines
Another powerful way to remove the need for a legend is to put labels directly on the lines. Here’s an example from the online book Scientific Visualization in Matplotlib (Rougier 2021).
def plot_helper_func(ax, X, C, S):
    """Style ``ax`` for a [-pi, pi] trigonometric plot and draw two curves.

    Fixes the axis limits, ticks and tick labels, hides the top and right
    spines, and moves the left and bottom spines slightly outside the data
    range. ``C`` is plotted against ``X`` with the label "cosine" and ``S``
    with the label "sine"; neither curve is clipped to the axes.

    Returns:
        The two line artists as ``(cosine_line, sine_line)``.
    """
    # x-axis: -pi to pi, ticks at multiples of pi/2
    ax.set_xlim([-np.pi, np.pi])
    ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
    ax.set_xticklabels(["-π", "-π/2", "0", "+π/2", "+π"])
    # y-axis: -1 to 1
    ax.set_ylim([-1, 1])
    ax.set_yticks([-1, 0, 1])
    ax.set_yticklabels(["-1", "0", "+1"])
    # Hide two spines and detach the other two from the plotting area
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    for side, position in (("left", -3.25), ("bottom", -1.25)):
        ax.spines[side].set_position(("data", position))
    cosine_line = ax.plot(X, C, label="cosine", clip_on=False, lw=3)[0]
    sine_line = ax.plot(X, S, label="sine", clip_on=False, lw=3)[0]
    return cosine_line, sine_line
X = np.linspace(-np.pi, np.pi, 400, endpoint=True)
C, S = np.cos(X), np.sin(X)
kw_settings = dict(
size="large",
bbox=dict(facecolor="white", edgecolor="None", alpha=0.85),
ha="center",
va="center",
rotation=60,
)
fig, ax = plt.subplots()
plot1, plot2 = plot_helper_func(ax, X, C, S)
ax.text(
X[100],
C[100],
" " + plot1.get_label(),
color=plot1.get_color(),
**kw_settings,
)
ax.text(
X[200],
S[200],
" " + plot2.get_label(),
color=plot2.get_color(),
**kw_settings,
);The Economist
This example comes from a blog post called Making Economist-Style Plots in Matplotlib by Robert Ritz. Our objective is to recreate a plot from The Economist.
In terms of narrative bells and whistles, this chart uses a less is more approach. But this is on purpose; there’s very little chart clutter so all we see is the order of GDP levels.
We’ll use the Varta font, which is somewhat similar to the proprietary font used by The Economist. If you are going to use a custom font for a chart, you need to either download and install it on your system or tell matplotlib where the file is (we’ll use the former approach below).
The data is from the World Bank and is available using the code NY.GDP.MKTP.CD.
from datetime import datetime
import wbgapi as wb
# Request the three most recent calendar years, ending with the current one
end_year = datetime.now().year
countries = ["GBR", "USA", "CHN", "IND", "FRA", "CAN", "KOR", "DEU", "ITA"]
indicator_code = "NY.GDP.MKTP.CD"
# One row per country, one column per year, plus a country-name column
df = wb.data.DataFrame(
    indicator_code,
    countries,
    time=range(end_year - 2, end_year + 1),
    labels=True,
    numericTimeKeys=True,
)
df = df.rename(columns={"Country": "country"}).reset_index(drop=True)
# Reshape to long form: one row per (country, year)
df = df.melt(id_vars="country", var_name="year", value_name=indicator_code)
df.head()

| | country | year | NY.GDP.MKTP.CD |
|---|---|---|---|
| 0 | Italy | 2024 | 2.380825e+12 |
| 1 | Germany | 2024 | 4.685593e+12 |
| 2 | Korea, Rep. | 2024 | 1.875388e+12 |
| 3 | Canada | 2024 | 2.243637e+12 |
| 4 | France | 2024 | 3.160443e+12 |
Let’s do some quick tidying and prep of the data ready for plotting.
# Express GDP in trillions of US dollars
df["gdp_trillions"] = df["NY.GDP.MKTP.CD"] / 1e12
# The requested window can include years with no published figures yet
# (all-NaN rows), so take the latest year among rows that have data, then keep
# up to nine rows ordered from smallest to largest GDP. The trailing copy()
# makes gdp an independent frame that is safe to modify later.
gdp = (
    df.dropna(subset=["gdp_trillions"])
    .pipe(lambda d: d[d["year"] == d["year"].max()])
    .sort_values(by="gdp_trillions")
    .tail(9)
    .copy()
)
gdp.head()

| | country | year | NY.GDP.MKTP.CD | gdp_trillions |
|---|---|---|---|---|
| 9 | Italy | 2025 | NaN | NaN |
| 10 | Germany | 2025 | NaN | NaN |
| 11 | Korea, Rep. | 2025 | NaN | NaN |
| 12 | Canada | 2025 | NaN | NaN |
| 13 | France | 2025 | NaN | NaN |
Before making the chart, let’s set the font:
# Select the generic sans-serif font family for figure text
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Varta"

Now let’s make the chart. Note that we need to paint on the red rectangle and, to get everything in just the right spot, specify some points individually.
# Set up the plot size.
fig, ax = plt.subplots(figsize=(3, 6))
# Create grid (vertical lines only).
# zorder sets the drawing layer: grid on 1 and bars on 2, so the grid sits
# behind the data.
ax.grid(which="major", axis="x", color="#758D99", alpha=0.6, zorder=1)
# Remove spines. Can be done one at a time or several at once with a list.
ax.spines[["top", "right", "bottom"]].set_visible(False)
# Make left spine slightly thicker
ax.spines["left"].set_linewidth(1.1)
# Set up data. assign() builds a new frame instead of writing a column into a
# filtered one.
gdp = gdp.assign(
    country=gdp["country"].replace("the United States", "United States")
)
# Latest year only, keeping the seven largest values in ascending order
gdp_bar = gdp[gdp["year"] == gdp["year"].max()].sort_values(by="gdp_trillions")[-7:]
# Plot data
ax.barh(gdp_bar["country"], gdp_bar["gdp_trillions"], color="#006BA2", zorder=2, lw=0)
# Set custom labels for x-axis: 0, 5, 10, 15, 20
labels = np.arange(0, 25, 5)
ax.set_xticks(labels)
ax.set_xticklabels(labels)
# Reformat x-axis tick labels
ax.xaxis.set_tick_params(
    labeltop=True,  # Put x-axis labels on top
    labelbottom=False,  # Set no x-axis labels on bottom
    bottom=False,  # Set no ticks on bottom
    labelsize=11,  # Set tick label size
    pad=-1,  # Pull the tick labels in towards the axis
)
# Reformat y-axis tick labels
ax.set_yticklabels(
    gdp_bar["country"],  # Set labels again
    ha="left",  # Set horizontal alignment to left
)
ax.yaxis.set_tick_params(
    pad=100,  # Pad tick labels so they don't go over y-axis
    labelsize=11,  # Set label size
    bottom=False,  # Set no ticks on bottom/left
)
# Figure-fraction coordinates shared by the header and footer decorations
leftmost_pos = -0.35
top_most_point = 1.02
# Add in line and tag
ax.plot(
    [leftmost_pos, 0.87],  # Set width of line
    [top_most_point, top_most_point],  # Set height of line
    transform=fig.transFigure,  # Coordinates are fractions of the figure
    clip_on=False,
    color="#E3120B",
    linewidth=0.6,
)
ax.add_patch(
    Rectangle(
        (leftmost_pos, top_most_point),  # Anchor corner of the rectangle
        0.12,  # Width of rectangle
        -0.02,  # Height of rectangle. Negative so it goes down.
        facecolor="#E3120B",
        transform=fig.transFigure,
        clip_on=False,
        linewidth=0,
    )
)
# Add in title and subtitle
ax.text(
    x=leftmost_pos,
    y=0.96,
    s="The big leagues",
    transform=fig.transFigure,
    ha="left",
    fontsize=13,
    weight="bold",
)
ax.text(
    x=leftmost_pos,
    y=0.925,
    s=f"{gdp['year'].max()} GDP, trillions of USD",
    transform=fig.transFigure,
    ha="left",
    fontsize=11,
)
# Set source text
ax.text(
    x=leftmost_pos,
    y=0.08,
    s="Source: World Bank",
    transform=fig.transFigure,
    ha="left",
    fontsize=9,
    alpha=0.7,
)
plt.show()