# How to create a custom legend with matplotlib

We've learned how to create legends for our charts in two ways: by giving each plot a label, or by manually assigning labels to plots.

Using a label in each plot we did something like this:

...

axes.bar(
    ...
    label="Men"
)
axes.bar(
    ...
    label="Women"
)

axes.legend()

...

Manually assigning labels to plots looked like this:

...

men = axes.bar(...)
women = axes.bar(...)

axes.legend((men, women), ("Men", "Women"))

...
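For reference, here is a minimal runnable sketch of both approaches side by side. The group names, values, and bar positions are made up purely for illustration, since the snippets above elide them, and the "Agg" backend is chosen only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical data, invented purely for illustration.
groups = ["A", "B", "C"]
men_values = [20, 35, 30]
women_values = [25, 32, 34]
width = 0.4
men_x = [p - width / 2 for p in range(len(groups))]
women_x = [p + width / 2 for p in range(len(groups))]

figure, (left, right) = plt.subplots(1, 2)

# Approach 1: give each plot a label, then call legend() with no arguments.
left.bar(men_x, men_values, width, label="Men")
left.bar(women_x, women_values, width, label="Women")
legend_from_labels = left.legend()

# Approach 2: keep what bar() returns and pair each plot with a label manually.
men = right.bar(men_x, men_values, width)
women = right.bar(women_x, women_values, width)
legend_manual = right.legend((men, women), ("Men", "Women"))
```

Both legends end up with the same two entries; the difference is only where the labels are supplied.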

# Creating custom legends

Sometimes, though, we want to create custom legends that don't necessarily map to individual plots.

For example, if we create a single bar chart but we pass in multiple colours so that different bars are coloured differently:

import matplotlib.pyplot as plt

figure = plt.figure()
axes = figure.add_subplot()

axes.bar(
    range(6),
    [150, 90, 78, 55, 123, 190],
    tick_label=["Apple", "Burberry", "Google", "Zara", "Microsoft", "Superdry"],
    color=["#5c44fd", "#ff5566", "#5c44fd", "#ff5566", "#5c44fd", "#ff5566"]
)

plt.xticks(rotation=30, ha="right")
plt.show()

This produces a bar chart like this one:

[Image: Bar chart of tech and clothing companies]

It would be nice if we could have a legend that maps blue to tech, and red to clothing companies. But since this is a single plot, neither of the approaches we covered earlier on would work.

That's where custom legends come into play!

# Creating a custom colour Patch

Each legend entry is composed of two things: a patch of colour (the 'handle') that tells us which plotted element the entry is for, and a label.

[Image: Legend example]

To create a custom legend, we first need to create those standalone patches of colour. We do this with the Patch class:

from matplotlib.patches import Patch

handles = [
    Patch(facecolor="#5c44fd", label="Tech"),
    Patch(facecolor="#ff5566", label="Clothing")
]
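To see that these patches really are standalone, the sketch below builds the same two handles and inspects them. The `summary` variable is just an illustrative name; `to_hex` from `matplotlib.colors` is used only to turn the stored RGBA colour back into a hex string.

```python
from matplotlib.colors import to_hex
from matplotlib.patches import Patch

handles = [
    Patch(facecolor="#5c44fd", label="Tech"),
    Patch(facecolor="#ff5566", label="Clothing"),
]

# Each Patch carries its own label and colour, and is not attached to any axes.
summary = [(h.get_label(), to_hex(h.get_facecolor()), h.axes) for h in handles]
print(summary)
```

Because a handle is never drawn on the chart itself, it only shows up once we hand it to a legend.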

# Using the patches in a custom legend

Then we can use those handles in a custom legend, like so:

axes.legend(handles=handles)

Easy enough!

The complete code looks like this:

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

figure = plt.figure()
axes = figure.add_subplot()

axes.bar(
    range(6),
    [150, 90, 78, 55, 123, 190],
    tick_label=["Apple", "Burberry", "Google", "Zara", "Microsoft", "Superdry"],
    color=["#5c44fd", "#ff5566", "#5c44fd", "#ff5566", "#5c44fd", "#ff5566"]
)

handles = [
    Patch(facecolor="#5c44fd", label="Tech"),
    Patch(facecolor="#ff5566", label="Clothing")
]

axes.legend(handles=handles)

plt.xticks(rotation=30, ha="right")
plt.show()

And the result:

[Image: Clothing vs. tech bar chart with legend]
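One possible refinement, sketched here rather than prescribed: keep a single mapping from category to colour and derive both the bar colours and the legend handles from it, so the two can't drift apart. The names `category_colours` and `companies` are hypothetical, and the "Agg" backend is chosen only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
from matplotlib.patches import Patch

# Hypothetical lookup tables; the names are illustrative only.
category_colours = {"Tech": "#5c44fd", "Clothing": "#ff5566"}
companies = {
    "Apple": "Tech",
    "Burberry": "Clothing",
    "Google": "Tech",
    "Zara": "Clothing",
    "Microsoft": "Tech",
    "Superdry": "Clothing",
}
values = [150, 90, 78, 55, 123, 190]

figure = plt.figure()
axes = figure.add_subplot()

# Colour each bar by looking up its company's category.
bars = axes.bar(
    range(len(companies)),
    values,
    tick_label=list(companies),
    color=[category_colours[category] for category in companies.values()],
)

# Build one legend handle per category from the same mapping.
handles = [
    Patch(facecolor=colour, label=category)
    for category, colour in category_colours.items()
]
legend = axes.legend(handles=handles)

plt.xticks(rotation=30, ha="right")
bar_colours = [to_hex(bar.get_facecolor()) for bar in bars]
```

Adding a third category then only means adding one entry to `category_colours` and tagging the relevant companies with it.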