# How to add multiple subplots to a figure

We've learned that a Figure can contain multiple sets of Axes, each with their own plots.

The documentation often calls a set of Axes a "subplot".

Hence, a Figure can contain multiple subplots:

```python
import matplotlib.pyplot as plt

figure = plt.figure()
ax1 = figure.add_subplot()  # no position given
ax2 = figure.add_subplot()  # no position given here either

ax1.plot([1, 2, 3, 4], [3, 5, 9, 25])
ax2.plot([1, 2, 3, 4], [5, 7, 11, 17])

plt.show()
```

This code looks like it creates two subplots within the Figure object, but if we run it, we see this:

[Figure: single set of axes with two plots]

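The reason is that when we call .add_subplot(), we need to specify where the subplot goes within the figure. Without arguments, both calls ask for the same default position, (1, 1, 1): a grid of 1 row and 1 column, with the subplot in its only cell. (Older Matplotlib versions return the same Axes for both calls, so both lines land on one set of axes; newer versions create a second Axes stacked exactly on top of the first.)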
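A minimal sketch to confirm that both calls land in the same place, assuming Matplotlib 3.1 or newer (where add_subplot() with no arguments means (1, 1, 1)). It asks each Axes for its bounding box in figure coordinates with get_position() and compares them. The switch to the non-interactive Agg backend is an assumption so the sketch runs without a display; in a normal script you'd call plt.show() instead.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt

figure = plt.figure()
ax1 = figure.add_subplot()  # no position: defaults to (1, 1, 1)
ax2 = figure.add_subplot()  # same default, so the same cell

# get_position() returns the Axes' box in figure coordinates
pos1 = ax1.get_position().bounds  # (x0, y0, width, height)
pos2 = ax2.get_position().bounds

print(pos1 == pos2)  # True: both cover exactly the same area
```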

The location of each subplot is specified as three numbers: (nrows, ncols, index).

- nrows is the number of rows in the grid the figure is divided into.
- ncols is the number of columns in that grid.
- index is the cell the new subplot goes into, counting from 1.
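Because index counts cells in the grid, it has to fall between 1 and nrows * ncols. Here's a small sketch of what happens when it doesn't, assuming a reasonably recent Matplotlib. The exact error message differs between versions, so the sketch only relies on the exception type, ValueError.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt

figure = plt.figure()
figure.add_subplot(1, 2, 1)  # valid: a 1x2 grid has cells 1 and 2
figure.add_subplot(1, 2, 2)  # valid

error = None
try:
    figure.add_subplot(1, 2, 3)  # invalid: a 1x2 grid has no cell 3
except ValueError as exc:
    error = exc

print(repr(error))
print(len(figure.axes))  # only the two valid subplots were added
```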

We can modify the code above to place the two subplots in adjacent cells within one row:

```python
import matplotlib.pyplot as plt

figure = plt.figure()
ax1 = figure.add_subplot(1, 2, 1)  # 1 row, 2 columns, cell 1 (left)
ax2 = figure.add_subplot(1, 2, 2)  # same grid, cell 2 (right)

ax1.plot([1, 2, 3, 4], [3, 5, 9, 25])
ax2.plot([1, 2, 3, 4], [5, 7, 11, 17])

plt.show()
```

The first subplot is at (1, 2, 1). This means the figure has 1 row and 2 columns, and this subplot is in cell 1, the cell at the top left.

The second subplot is at (1, 2, 2). This keeps the grid unchanged (the number of rows and columns is the same as before) but places this subplot in cell 2, the next cell to the right.
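To check that the two subplots really sit side by side, we can compare their bounding boxes. This is a sketch that uses get_position() (a real Axes method returning a box with x0, x1, y0 and height attributes) and, as an assumption, the non-interactive Agg backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt

figure = plt.figure()
ax1 = figure.add_subplot(1, 2, 1)  # cell 1: left
ax2 = figure.add_subplot(1, 2, 2)  # cell 2: right

box1 = ax1.get_position()
box2 = ax2.get_position()

# Same row: both boxes start at the same height and are equally tall
print(box1.y0 == box2.y0, box1.height == box2.height)

# Different columns: ax2 starts to the right of where ax1 ends
print(box1.x1 < box2.x0)
```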

The index starts at 1 in the top-left cell and increases from left to right, continuing row by row, as shown by this diagram:

[Figure: row index diagram]
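The ordering in the diagram boils down to simple arithmetic: each full row uses up ncols indices. The cell_of function below is a hypothetical helper written for illustration (it is not part of Matplotlib) that turns a 1-based index into a (row, column) pair, also counted from 1.

```python
def cell_of(nrows, ncols, index):
    """Return the (row, column) of a 1-based subplot index, both counted from 1.

    Hypothetical helper for illustration only; not part of Matplotlib.
    """
    if not 1 <= index <= nrows * ncols:
        raise ValueError(f"index must be between 1 and {nrows * ncols}, got {index}")
    row = (index - 1) // ncols + 1  # how many full rows come before this cell
    col = (index - 1) % ncols + 1   # position within the current row
    return row, col


# Walk every cell of a 2x3 grid in index order
nrows, ncols = 2, 3
for index in range(1, nrows * ncols + 1):
    print(index, cell_of(nrows, ncols, index))
```

Index 4 comes out as (2, 1): once the first row's three cells are used, counting continues at the left of the second row.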

TIP

You'll often see this: figure.add_subplot(111).

Matplotlib allows us to pass a three-digit integer instead of three separate numbers, so 111 means (1, 1, 1). I strongly dislike this, and I would recommend avoiding it. However, it's good to know, as a lot of people do it!

Also, this only works when each of the three numbers is a single digit, so it can't place a subplot at index 10 or higher.
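A short sketch showing that the three-digit form and the spelled-out form put a subplot in the same place. Each version gets its own figure so the two Axes don't interfere, and the Agg backend is an assumption so it runs without a display. The digit-splitting line at the end only illustrates how 235 reads as (2, 3, 5); it isn't meant as Matplotlib's own code.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt

fig_code = plt.figure()
ax_code = fig_code.add_subplot(235)      # three-digit shorthand

fig_args = plt.figure()
ax_args = fig_args.add_subplot(2, 3, 5)  # the same subplot, spelled out

# Both land in the same place: 2 rows, 3 columns, cell 5
print(ax_code.get_position().bounds == ax_args.get_position().bounds)

# Conceptually, the digits are read off one at a time
nrows, ncols, index = (int(digit) for digit in str(235))
print(nrows, ncols, index)
```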

# Using plt.subplots() as a shorthand

We do this very often when working with multiple subplots:

```python
import matplotlib.pyplot as plt

figure = plt.figure()
ax1 = figure.add_subplot(1, 2, 1)
ax2 = figure.add_subplot(1, 2, 2)
```

So matplotlib has a shorthand that we can use:

```python
import matplotlib.pyplot as plt
figure, (ax1, ax2) = plt.subplots(1, 2)
```

This creates a figure with 1 row and 2 columns. It returns the figure and a NumPy array of Axes in their usual cell order, which we unpack into ax1 and ax2. Therefore, ax1 is the one at the top left, and ax2 is the next one to its right.
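A quick sketch to see what plt.subplots() actually hands back, keeping the container instead of unpacking it straight away. With one row, the default squeeze behaviour gives a 1-D NumPy array of two Axes. The Agg backend is an assumption so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt
import numpy as np

figure, axes = plt.subplots(1, 2)  # keep the container for a moment

print(isinstance(axes, np.ndarray))  # a NumPy array, not a tuple
print(axes.shape)                    # one row of two Axes: a 1-D array

ax1, ax2 = axes  # unpacking works because arrays are iterable
ax1.plot([1, 2, 3, 4], [3, 5, 9, 25])
ax2.plot([1, 2, 3, 4], [5, 7, 11, 17])
```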

If we did this:

```python
import matplotlib.pyplot as plt
figure, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
```

Then ax1 and ax2 would occupy the two cells in the first row, and ax3 and ax4 would occupy the two cells in the second row.

Note that with more than one row, the returned array is two-dimensional (one inner array per row), so we unpack it as a tuple of tuples!
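The same 2x2 grid can also be handled without unpacking. This sketch shows the shape of the returned array, how axes[row, col] (0-based NumPy indexing) reaches the same Axes, and how the NumPy attribute .flat walks the cells in the same order as the subplot index. The titles are just a visible marker, and the Agg backend is an assumption so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
import matplotlib.pyplot as plt

figure, axes = plt.subplots(2, 2)
print(axes.shape)  # (2, 2): one inner array per row

# Nested unpacking mirrors that shape: one inner tuple per row
(ax1, ax2), (ax3, ax4) = axes

# axes[row, col] reaches the same objects, counting from 0
print(axes[0, 1] is ax2, axes[1, 0] is ax3)

# .flat goes left to right, row by row, like the subplot index
for number, ax in enumerate(axes.flat, start=1):
    ax.set_title(f"Subplot {number}")
```

Looping over .flat is handy when every subplot gets the same treatment and the individual names don't matter.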