# How to create a legend from your data
Matplotlib can easily generate a legend for our axes. It builds the legend from the `label=` argument passed to the `.bar()` call (or to any other plotting call on the axes):
# Stacked bar chart of the poll data. Each bar series is given a label=,
# which is what a legend can later be generated from.
import matplotlib.pyplot as plt

from data import polls

# Split the poll entries into parallel lists (items 0, 1 and 2 of each entry).
poll_titles = [entry[0] for entry in polls]
poll_men = [entry[1] for entry in polls]
poll_women = [entry[2] for entry in polls]
poll_x_coordinates = range(len(polls))  # one x position per poll

figure = plt.figure(figsize=(6, 6))
# Enlarge the bottom margin; presumably to make room for the rotated labels.
figure.subplots_adjust(bottom=0.35)
axes = figure.add_subplot()

axes.bar(poll_x_coordinates, poll_men, label="Men")
# bottom= starts each "Women" bar where the matching "Men" bar ends.
axes.bar(poll_x_coordinates, poll_women, bottom=poll_men, label="Women")

# Put the poll titles on the x axis, rotated and right-aligned.
plt.xticks(poll_x_coordinates, poll_titles, rotation=30, ha="right")
plt.show()
Next up, we tell the axes to draw a legend using the data available:
# Stacked bar chart of the poll data with a legend generated from the
# label= given to each bar series.
import matplotlib.pyplot as plt

from data import polls

# Split the poll entries into parallel lists (items 0, 1 and 2 of each entry).
poll_titles = [entry[0] for entry in polls]
poll_men = [entry[1] for entry in polls]
poll_women = [entry[2] for entry in polls]
poll_x_coordinates = range(len(polls))  # one x position per poll

figure = plt.figure(figsize=(6, 6))
# Enlarge the bottom margin; presumably to make room for the rotated labels.
figure.subplots_adjust(bottom=0.35)
axes = figure.add_subplot()

axes.bar(poll_x_coordinates, poll_men, label="Men")
# bottom= starts each "Women" bar where the matching "Men" bar ends.
axes.bar(poll_x_coordinates, poll_women, bottom=poll_men, label="Women")

# With no arguments, the legend entries come from the label= values above.
axes.legend()

# Put the poll titles on the x axis, rotated and right-aligned.
plt.xticks(poll_x_coordinates, poll_titles, rotation=30, ha="right")
plt.show()
That was the first and simplest way of defining the legend. However, sometimes it can be useful to manually assign a legend item to individual plots. We can do this by first storing each plot that we want to include in the legend in its own variable. We don't need the `label=` argument for this:
# Stacked bar chart of the poll data. The value returned by each bar() call
# is kept in a variable so it can be handed to a legend later; no label= is
# set here.
import matplotlib.pyplot as plt

from data import polls

# Split the poll entries into parallel lists (items 0, 1 and 2 of each entry).
poll_titles = [entry[0] for entry in polls]
poll_men = [entry[1] for entry in polls]
poll_women = [entry[2] for entry in polls]
poll_x_coordinates = range(len(polls))  # one x position per poll

figure = plt.figure(figsize=(6, 6))
# Enlarge the bottom margin; presumably to make room for the rotated labels.
figure.subplots_adjust(bottom=0.35)
axes = figure.add_subplot()

men_plot = axes.bar(poll_x_coordinates, poll_men)
# bottom= starts each bar of this series where the matching men bar ends.
women_plot = axes.bar(poll_x_coordinates, poll_women, bottom=poll_men)

# Put the poll titles on the x axis, rotated and right-aligned.
plt.xticks(poll_x_coordinates, poll_titles, rotation=30, ha="right")
plt.show()
Then we can call `axes.legend()`, but this time with some arguments: the plots and a label for each plot:
# Stacked bar chart of the poll data with a legend built explicitly from the
# values returned by bar() and a matching sequence of label strings.
import matplotlib.pyplot as plt

from data import polls

# Split the poll entries into parallel lists (items 0, 1 and 2 of each entry).
poll_titles = [entry[0] for entry in polls]
poll_men = [entry[1] for entry in polls]
poll_women = [entry[2] for entry in polls]
poll_x_coordinates = range(len(polls))  # one x position per poll

figure = plt.figure(figsize=(6, 6))
# Enlarge the bottom margin; presumably to make room for the rotated labels.
figure.subplots_adjust(bottom=0.35)
axes = figure.add_subplot()

men_plot = axes.bar(poll_x_coordinates, poll_men)
# bottom= starts each bar of this series where the matching men bar ends.
women_plot = axes.bar(poll_x_coordinates, poll_women, bottom=poll_men)

# First argument: the plots to list; second: their labels, in the same order.
axes.legend((men_plot, women_plot), ("Men", "Women"))

# Put the poll titles on the x axis, rotated and right-aligned.
plt.xticks(poll_x_coordinates, poll_titles, rotation=30, ha="right")
plt.show()
Doing this means that the legend is defined in one place, which might be preferable in terms of readability.