Chapter 2 Making beautiful charts in Python
This chapter contains the code for some of my most used charts and visualization techniques.
2.1 Importing Python packages
Let’s load in some libraries that we will use again and again when making charts.
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
import statistics
from scipy.stats import norm
from matplotlib.ticker import EngFormatter, StrMethodFormatter
2.2 Reading and cleaning data
Let’s start by importing data from a csv and making it usable. In this example, we’ll use the weather profile from 2019 in Melbourne, Australia.
We’ll also create a new column for a rolling average of the temperature.
# Load the 2019 Melbourne weather observations and tidy them for plotting.
#
# Note: non-ASCII characters in the csv (e.g. the degree sign in the column
# headers) will break the default import, so we pass encoding='unicode_escape'.
# Note: The full file location is this:
# /Users/charlescoverdale/Documents/2021/Python_code_projects/learning_journal_v0-1/MEL_weather_2019.csv

# Import csv
df_weather = pd.read_csv("MEL_weather_2019.csv", encoding='unicode_escape')

# Create a single date column and bind it to the dataframe
df_weather['Date'] = pd.to_datetime(df_weather[['Year', 'Month', 'Day']])

# Drop the original three-field date columns
df_weather = df_weather.drop(columns=['Year', 'Month', 'Day'])

# Give the measurement columns short, code-friendly names
df_weather = df_weather.rename(
    columns={
        'Daily global solar exposure (MJ/m*m)': 'Solar_exposure',
        'Rainfall amount (millimetres)': 'Rainfall',
        'Maximum temperature (°C)': 'Max_temp',
    }
)

# Add a 7-day rolling average of the maximum temperature.
# The first six rows are NaN because the window is not yet full.
df_weather['Rolling_avg'] = df_weather['Max_temp'].rolling(window=7).mean()

# Preview the cleaned data (displays in a notebook)
df_weather.head()
2.3 Line charts
Now that the data is in a reasonable format (e.g. there is a simple-to-use ‘Date’ column), let’s go ahead and make a line chart.
# Plot the daily maximum temperature as a faint line, with the 7-day moving
# average drawn on top so the trend stands out.
plt.plot(df_weather['Date'], df_weather['Max_temp'],
         label='Maximum temperature',
         color='blue',
         alpha=0.2,
         linewidth=1.0,
         marker='')

plt.plot(df_weather['Date'], df_weather['Rolling_avg'],
         label='7-day moving average',
         color='red',
         linewidth=1.0,
         marker='')

plt.title('Maximum temperature in Melbourne (2019)', fontsize=12)

# X-axis: no axis label, one tick per month shown as an abbreviated month name
plt.xlabel('', fontsize=10)
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%b'))
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=1))
#plt.margins(x=0)

# Y-axis: no axis label, tick values shown as whole degrees Celsius
plt.ylabel('', fontsize=10)
plt.gca().yaxis.set_major_formatter(StrMethodFormatter(u"{x:.0f}°C"))

# Keep only the bottom spine for a cleaner frame
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['bottom'].set_visible(True)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_visible(False)

plt.tick_params(
    axis='x',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,         # ticks along the top edge are off
    labelbottom=True)  # labels along the bottom edge are on

plt.tick_params(
    axis='y',          # changes apply to the y-axis
    which='both',      # both major and minor ticks are affected
    left=False,        # ticks along the left edge are off
    right=False,       # ticks along the right edge are off
    labelleft=True)    # labels along the left edge are on

# Turn the grid off everywhere, then re-enable horizontal gridlines only
plt.grid(False)
plt.gca().yaxis.grid(True)

plt.legend(fancybox=False, framealpha=1, shadow=False, borderpad=1)

# Save before showing: the figure may be cleared once it has been shown
plt.savefig('weather_chart_save.png', dpi=300, bbox_inches='tight')

plt.show()
2.4 Bar charts
# Chart 1: Bar plot

# Get data
country = ['USA', 'Canada', 'Germany', 'UK', 'France']
GDP_per_capita = [45, 40, 38, 16, 10]

# Create plot
plt.bar(country, GDP_per_capita, width=0.8, align='center',
        color='blue', edgecolor='black')

# Labels and titles
plt.title('GDP per capita of select OECD countries')
plt.xlabel('Test x label')
plt.ylabel('')

# Add bar annotations to the bar chart.
# j is the vertical gap between the top of a bar and its label.
j = 1.0

# Annotate each bar with its value. Categorical bars sit at x = 0, 1, 2, ...,
# so the label is nudged slightly left of the bar centre.
for i, value in enumerate(GDP_per_capita):
    plt.annotate(str(value), (-0.1 + i, value + j))

# Create the legend for the bars in the plot
plt.legend(labels=['GDP_per_capita'])

# Remove the y-axis ticks (the values are shown on the bars instead)
plt.yticks([])

# Save the plot as a 'png' (before showing it)
plt.savefig('test_bar_plot.png', dpi=300, bbox_inches='tight')

# Show plot
plt.show()
2.5 Stacked bar charts
# Stacked bar chart: women's scores are stacked on top of men's scores,
# with error bars showing each series' standard deviation.
labels = ['Group 1', 'Group 2', 'Group 3', 'Group 4', 'Group 5']
men_means = [20, 35, 30, 35, 27]
women_means = [25, 32, 34, 20, 25]
men_std = [2, 3, 4, 1, 2]
women_std = [3, 5, 2, 3, 3]
width = 0.7  # the width of the bars: can also be len(x) sequence

fig, ax = plt.subplots()

ax.bar(labels, men_means, width, yerr=men_std, label='Men')
# bottom=men_means starts each women's bar where the men's bar ends
ax.bar(labels, women_means, width, yerr=women_std, bottom=men_means,
       label='Women')

ax.set_ylabel('Scores')
ax.set_title('Scores by group and gender')
ax.legend()

plt.show()
2.6 Line charts (from raw data)
import matplotlib.ticker as mtick
# Line chart built from raw data typed directly into the script.
#
# Note: you can also get the same result without using a pandas dataframe
#Year = [1920,1930,1940,1950,1960,1970,1980,1990,2000,2010]
#Unemployment_Rate = [9.8,12,8,7.2,6.9,7,6.5,6.2,5.5,6.3]

# Using a pandas dataframe
Data = {
    'Year': [1920, 1930, 1940, 1950, 1960, 1970, 1980, 1990, 2000, 2010],
    'Unemployment_Rate': [9.8, 12, 8, 7.2, 6.9, 7, 6.5, 6.2, 5.5, 6.3],
}

df = pd.DataFrame(Data, columns=['Year', 'Unemployment_Rate'])

# Add in a % sign to a new variable
#df['Unemployment_Rate_Percent'] = df['Unemployment_Rate'].astype(str) + '%'

plt.plot(df['Year'], df['Unemployment_Rate'], color='blue', marker='o')
plt.title('Unemployment rate (1920-2010)', fontsize=12)
plt.xlabel('Year', fontsize=12)
plt.ylabel('', fontsize=12)
#plt.grid(False)
plt.gca().yaxis.grid(True)

# Show the y tick values as percentages rather than bare numbers
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter())

plt.show()
2.7 Scatter plot
# Simple scatter plot of two equal-length lists of values.
x = [5, 7, 8, 7, 2, 17, 2, 9,
     4, 11, 12, 9, 6]
y = [99, 86, 87, 88, 100, 86,
     103, 87, 94, 78, 77, 85, 86]

plt.scatter(x, y, c="blue")

plt.title('Scatterplot title', fontsize=12)
plt.xlabel('x label', fontsize=12)
plt.ylabel('y label', fontsize=12)

plt.show()
2.8 Histogram
# Histogram of 1000 draws from a standard normal distribution.
# Seeding the generator makes the chart reproducible.
np.random.seed(99)

# Using the format np.random.normal(mu, sigma, 1000)
x = np.random.normal(0, 1, size=1000)

# Use density=False for counts, and density=True for probability
# NOTE: with density=False the y-axis shows counts per bin, even though the
# axis label below reads 'Probability'.
plt.hist(x, density=False, bins=100)

# Plot mean line
plt.axvline(x.mean(), color='k', linestyle='dashed', linewidth=1)

plt.ylabel('Probability')
plt.xlabel('Mean')

plt.show()
2.9 Multiple charts in single plot
# Two charts side by side in one figure, plotting the same data with two
# different ways of adding a unit to the y tick labels.
fig, (ax, ax2) = plt.subplots(ncols=2)

# Left: EngFormatter appends the unit to each tick value
ax.plot([0, 1], [-35, 30])
ax.yaxis.set_major_formatter(EngFormatter(unit=u"°C"))

# Right: StrMethodFormatter applies a str.format template to each tick value
ax2.plot([0, 1], [-35, 30])
ax2.yaxis.set_major_formatter(StrMethodFormatter(u"{x:.0f} °C"))

plt.tight_layout()
plt.show()
2.10 Annotating charts
Example taken from the wonderful blog at Practical Economics.
# Annotated supply-and-demand diagram built from plain lines, text labels,
# arrows and filled shapes.
plt.title('Employment Impact of a Minimum Wage')

# Set limits of chart
plt.xlim(10, 70)
plt.ylim(130, 200)

# Wage supply floor
plt.plot([10, 30], [150, 150], color='orange')
plt.text(10.5, 140.0, "Marginal\nDisutility\nof Labour", size=8, color='black')

# Dashed guide lines meeting at the point (40, 160)
plt.plot([10, 40], [160, 160], color='lightgrey', linestyle='--')
plt.plot([40, 40], [130, 160], color='lightgrey', linestyle='--')

# Double-headed arrow spanning x = 30 to x = 40
plt.annotate('', xy=(30, 138), xytext=(40, 138), arrowprops=dict(arrowstyle='<->'))
plt.text(31, 140, "Employment\nLoss", size=8, color='k')

# Shaded band between y = 150 and y = 170; xmin/xmax are fractions of the
# axes width, so 20/60 ends the band at x = 30 given the x-limits above.
plt.axhspan(170, 150, xmin=0.0, xmax=20/60, alpha=0.9, color='dodgerblue')
plt.annotate('Additional Surplus to\nEmployed', xy=(20, 162), xytext=(30, 185),
             arrowprops=dict(arrowstyle='->'))

# Deadweight loss triangle (the first vertex is repeated to close the outline)
trianglex = [30, 30, 40, 30]
triangley = [150, 170, 160, 150]
plt.plot(trianglex, triangley, color='grey')
plt.fill(trianglex, triangley, color='grey')

# Main box
plt.plot([10, 30], [170, 170], 'tab:orange')
plt.plot([30, 30], [130, 170], 'tab:green')
#plt.plot([50,50],[130,170],'tab:red')
plt.text(11, 171, "Wage Rate", size=8, color='black')

plt.annotate('Deadweight\nLoss', xy=(32, 162), xytext=(38, 175),
             arrowprops=dict(arrowstyle='->'))

# Labour Demand Curve
plt.plot([20, 60], [180, 140], color='tab:grey')
plt.text(61, 135, "Marginal\nProduct\nof Labour\nDemand", size=8, color='black')

# Labour Supply Curve
plt.plot([20, 60], [140, 180], color='tab:grey')
plt.text(61, 180, "Labour\nSupply", size=8, color='k')

plt.show()
2.11 Mimicking The Economist
The visual storytelling team at The Economist is absolutely world class. Their team is quite public about how they use both R and Python in their data science.
Robert Ritz has done an outstanding job at documenting how you can use their style when making charts.
The dataset we’ll use is the GDP records from 1960-2020.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Horizontal bar chart of the nine largest economies in 2020, styled after
# The Economist.

# This makes our plots higher resolution, which makes them easier to see while building
plt.rcParams['figure.dpi'] = 100

gdp = pd.read_csv('data/gdp_1960_2020.csv')
gdp.head()

# The GDP numbers here are very long. To make them easier to read we can divide the GDP number by 1 trillion.
gdp['gdp_trillions'] = gdp['gdp'] / 1_000_000_000_000

# Now we can filter for only 2020 and grab the bottom 9. We do this instead of sorting by descending
# because Matplotlib plots from the bottom to top, so we actually want our data in reverse order.
gdp[gdp['year'] == 2020].sort_values(by='gdp_trillions')[-9:]

# Setup plot size.
# Create the figure and axes explicitly: everything below draws on `ax` and
# positions elements with `fig.transFigure`, so both must belong to this chart
# rather than to whatever figure an earlier section left behind.
fig, ax = plt.subplots(figsize=(3, 6))

# Create grid
# Zorder tells it which layer to put it on. We are setting this to 1 and our data to 2 so the grid is behind the data.
ax.grid(which="major", axis='x', color='#758D99', alpha=0.6, zorder=1)

# Remove splines. Can be done one at a time or can slice with a list.
ax.spines[['top', 'right', 'bottom']].set_visible(False)

# Make left spine slightly thicker
ax.spines['left'].set_linewidth(1.1)

# Setup data
gdp['country'] = gdp['country'].replace('the United States', 'United States')
gdp_bar = gdp[gdp['year'] == 2020].sort_values(by='gdp_trillions')[-9:]

# Plot data
ax.barh(gdp_bar['country'], gdp_bar['gdp_trillions'], color='#006BA2', zorder=2)

# Set custom labels for x-axis
ax.set_xticks([0, 5, 10, 15, 20])
ax.set_xticklabels([0, 5, 10, 15, 20])

# Reformat x-axis tick labels
ax.xaxis.set_tick_params(labeltop=True,      # Put x-axis labels on top
                         labelbottom=False,  # Set no x-axis labels on bottom
                         bottom=False,       # Set no ticks on bottom
                         labelsize=11,       # Set tick label size
                         pad=-1)             # Lower tick labels a bit

# Reformat y-axis tick labels
ax.set_yticklabels(gdp_bar['country'],       # Set labels again
                   ha='left')                # Set horizontal alignment to left
ax.yaxis.set_tick_params(pad=100,            # Pad tick labels so they don't go over y-axis
                         labelsize=11,       # Set label size
                         bottom=False)       # Set no ticks on bottom/left

# Shrink y-lim to make plot a bit tighter
ax.set_ylim(-0.5, 8.5)

# Add in line and tag
ax.plot([-.35, .87],                 # Set width of line
        [1.02, 1.02],                # Set height of line
        transform=fig.transFigure,   # Set location relative to plot
        clip_on=False,
        color='#E3120B',
        linewidth=.6)
ax.add_patch(plt.Rectangle((-.35, 1.02),               # Set location of rectangle by lower left corner
                           0.12,                       # Width of rectangle
                           -0.02,                      # Height of rectangle. Negative so it goes down.
                           facecolor='#E3120B',
                           transform=fig.transFigure,
                           clip_on=False,
                           linewidth=0))

# Add in title and subtitle
ax.text(x=-.35, y=.96, s="The big leagues", transform=fig.transFigure, ha='left', fontsize=13, weight='bold', alpha=.8)
ax.text(x=-.35, y=.925, s="2020 GDP, trillions of USD", transform=fig.transFigure, ha='left', fontsize=11, alpha=.8)

# Set source text
ax.text(x=-.35, y=.08, s="Source: GDP of all countries (1960-2020)", transform=fig.transFigure, ha='left', fontsize=9, alpha=.7)

# Export plot as high resolution PNG.
# This must happen before plt.show(): once the figure has been shown it may be
# closed, and saving afterwards can write an empty image.
plt.savefig('docs/economist_bar.png',    # Set path and filename
            dpi=300,                     # Set dots per inch
            bbox_inches="tight",         # Remove extra whitespace around plot
            facecolor='white')           # Set background color to white

plt.show()
We can do a similar process for line charts.
# Line chart of GDP over time for the nine largest economies of 2020, styled
# after The Economist. Relies on the `gdp` dataframe (with its 'gdp_trillions'
# column) prepared for the bar chart.
countries = gdp[gdp['year'] == 2020].sort_values(by='gdp_trillions')[-9:]['country'].values
countries

# Turn the integer year into a real date so the x-axis is a time axis
gdp['date'] = pd.to_datetime(gdp['year'], format='%Y')

# Setup plot size.
fig, ax = plt.subplots(figsize=(8, 4))

# Create grid
# Zorder tells it which layer to put it on. We are setting this to 1 and our data to 2 so the grid is behind the data.
ax.grid(which="major", axis='y', color='#758D99', alpha=0.6, zorder=1)

# Plot data
# Loop through country names and plot each one in a muted grey.
for country in countries:
    country_rows = gdp[gdp['country'] == country]
    ax.plot(country_rows['date'],
            country_rows['gdp_trillions'],
            color='#758D99',
            alpha=0.8,
            linewidth=3)

# Plot US and China separately, in colour, so they stand out from the pack
us_rows = gdp[gdp['country'] == 'United States']
ax.plot(us_rows['date'],
        us_rows['gdp_trillions'],
        color='#006BA2',
        linewidth=3)

china_rows = gdp[gdp['country'] == 'China']
ax.plot(china_rows['date'],
        china_rows['gdp_trillions'],
        color='#3EBCD2',
        linewidth=3)

# Remove splines. Can be done one at a time or can slice with a list.
ax.spines[['top', 'right', 'left']].set_visible(False)

# Shrink y-lim to make plot a bit tighter
ax.set_ylim(0, 23)

# Set xlim to fit data without going over plot area.
# pd.Timestamp is used because the pd.datetime alias was removed in pandas 2.0.
ax.set_xlim(pd.Timestamp(1958, 1, 1), pd.Timestamp(2023, 1, 1))

# Reformat x-axis tick labels
ax.xaxis.set_tick_params(labelsize=11)        # Set tick label size

# Reformat y-axis tick labels.
# Fix the tick positions first so the labels are guaranteed to line up with them.
ax.set_yticks(np.arange(0, 25, 5))
ax.set_yticklabels(np.arange(0, 25, 5),       # Set labels again
                   ha='right',                # Set horizontal alignment to right
                   verticalalignment='bottom')  # Set vertical alignment to make labels on top of gridline

ax.yaxis.set_tick_params(pad=-2,              # Pad tick labels so they don't go over y-axis
                         labeltop=True,       # Put x-axis labels on top
                         labelbottom=False,   # Set no x-axis labels on bottom
                         bottom=False,        # Set no ticks on bottom
                         labelsize=11)        # Set tick label size

# Add labels for USA and China
ax.text(x=.63, y=.67, s='United States', transform=fig.transFigure, size=10, alpha=.9)
ax.text(x=.7, y=.4, s='China', transform=fig.transFigure, size=10, alpha=.9)

# Add in line and tag
ax.plot([0.12, .9],                  # Set width of line
        [.98, .98],                  # Set height of line
        transform=fig.transFigure,   # Set location relative to plot
        clip_on=False,
        color='#E3120B',
        linewidth=.6)
ax.add_patch(plt.Rectangle((0.12, .98),                # Set location of rectangle by lower left corner
                           0.04,                       # Width of rectangle
                           -0.02,                      # Height of rectangle. Negative so it goes down.
                           facecolor='#E3120B',
                           transform=fig.transFigure,
                           clip_on=False,
                           linewidth=0))

# Add in title and subtitle
ax.text(x=0.12, y=.91, s="Ahead of the pack", transform=fig.transFigure, ha='left', fontsize=13, weight='bold', alpha=.8)
ax.text(x=0.12, y=.86, s="Top 9 GDP's by country, in trillions of USD, 1960-2020", transform=fig.transFigure, ha='left', fontsize=11, alpha=.8)

# Set source text
ax.text(x=0.12, y=0.01, s="""Source: GDP of all countries (1960-2020)""", transform=fig.transFigure, ha='left', fontsize=9, alpha=.7)

# Export plot as high resolution PNG
plt.savefig('docs/economist_line.png',    # Set path and filename
            dpi=300,                      # Set dots per inch
            bbox_inches="tight",          # Remove extra whitespace around plot
            facecolor='white')            # Set background color to white

plt.show()