How to draw an animated legend for subplots?

Posted 2020-07-24 06:06

Question:

I would like to draw animated subplots with ArtistAnimation. Unfortunately, I cannot figure out how to get an animated legend. I tried several methods I found on StackOverflow. When I do manage to get a legend, it is not animated; instead, it shows the legend entries of all animation steps at once.

My code looks like this:

import numpy as np
import pylab as pl
import matplotlib.animation as anim

fig, (ax1, ax2, ax3) = pl.subplots(1,3,figsize=(11,4))
ims   = []
im1   = ['im11','im12','im13']
im2   = ['im21','im22','im23']
x = np.arange(0,2*np.pi,0.1)

n=50
for i in range(n):
    for sp in (1,2,3):
        pl.subplot(1,3,sp)

        y1 = np.sin(sp*x + i*np.pi/n)
        y2 = np.cos(sp*x + i*np.pi/n)

        im1[sp-1], = pl.plot(x,y1)
        im2[sp-1], = pl.plot(x,y2)

        pl.xlim([0,2*np.pi])
        pl.ylim([-1,1])

        lab = 'i='+str(i)+', sp='+str(sp)
        im1[sp-1].set_label([lab])
        pl.legend(loc=2, prop={'size': 6}).draw_frame(False)

    ims.append([ im1[0],im1[1],im1[2], im2[0],im2[1],im2[2] ])

ani = anim.ArtistAnimation(fig,ims,blit=True)
pl.show()
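
A likely reason the legend is not animated: ArtistAnimation only toggles the visibility of the artists listed in each frame, so the lines from earlier frames stay attached to the axes, and every `pl.legend()` call builds a new legend from all labeled lines currently in that axes. The legend itself is never part of a frame list, so the last one created (holding every label) stays on screen. The minimal sketch below reproduces this accumulation with one axes and 5 simulated frames; the Agg backend is an assumption so it runs without a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.arange(0, 2*np.pi, 0.1)

n_frames = 5
for i in range(n_frames):
    # A new labeled line is added in every frame, as in the question's loop
    line, = ax.plot(x, np.sin(x + i*np.pi/n_frames))
    line.set_label('i=' + str(i))
    # Each call replaces the axes legend with one listing ALL labeled lines
    leg = ax.legend(loc=2)

labels = [t.get_text() for t in leg.get_texts()]
print(len(ax.lines), labels)
```

The legend ends up with one entry per frame rather than one entry per line.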

I thought this code would be equivalent to the method used in "How to add legend/label in python animation", but obviously I am missing something.

I also tried to set the labels as suggested in "Add a legend for an animation (of Artists) in matplotlib", but I do not really understand how to apply it to my case. With this:

im2[sp-1].legend(handles='-', labels=[lab])

I get an AttributeError: 'Line2D' object has no attribute 'legend'.
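
The error occurs because `legend` is a method of `Axes` (and `Figure`), not of `Line2D`. A minimal sketch of passing explicit handles and labels, assuming one axes with a sine and a cosine line; note that `handles` must be a list of artists such as the `Line2D` objects, not a format string like `'-'`. The label strings are chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend for a non-interactive run
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.arange(0, 2*np.pi, 0.1)
line1, = ax.plot(x, np.sin(x))
line2, = ax.plot(x, np.cos(x))

# Call legend() on the Axes and pass the Line2D objects as handles
leg = ax.legend(handles=[line1, line2], labels=['sin, i=0', 'cos, i=0'],
                loc=2, prop={'size': 6})
leg.set_frame_on(False)  # hide the legend frame

print([t.get_text() for t in leg.get_texts()])
```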

[EDIT]: I did not state this clearly: I would like a legend entry for both lines in each subplot.
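
If you want to stay with ArtistAnimation, one workaround (a swapped-in technique, not a true per-frame legend object) is to create small text artists in every frame and add them to that frame's artist list, so they are shown and hidden together with the lines. The sketch below assumes two lines per subplot labeled "sin" and "cos" for illustration; the text positions, colors, and font size are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop it and call plt.show() to view
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(11, 4))
x = np.arange(0, 2*np.pi, 0.1)
n = 50
ims = []
for i in range(n):
    frame = []
    for sp, ax in enumerate(axes, start=1):
        l1, = ax.plot(x, np.sin(sp*x + i*np.pi/n), color='C0')
        l2, = ax.plot(x, np.cos(sp*x + i*np.pi/n), color='C1')
        # One text artist per line acts as this frame's "legend" entry
        lab = 'i=' + str(i) + ', sp=' + str(sp)
        t1 = ax.text(0.02, 0.95, 'sin ' + lab, color='C0', fontsize=6,
                     transform=ax.transAxes, va='top')
        t2 = ax.text(0.02, 0.88, 'cos ' + lab, color='C1', fontsize=6,
                     transform=ax.transAxes, va='top')
        # Lines and texts are toggled together by ArtistAnimation
        frame += [l1, l2, t1, t2]
    ims.append(frame)

for ax in axes:
    ax.set_xlim([0, 2*np.pi])
    ax.set_ylim([-1, 1])

# Keep a reference to the animation so it is not garbage-collected
ani = anim.ArtistAnimation(fig, ims, blit=True)
```

This still creates new artists in every frame; updating existing artists with FuncAnimation avoids that.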

Answer 1:

I do not know exactly what the legend should look like, but I imagine you simply want it to display the current value for the lines in the current frame. It is therefore better to update the data of existing lines instead of creating new ones in every frame (the question's loop adds two lines per subplot per frame, 300 lines in total).
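
The difference can be checked directly: calling `set_data` on an existing `Line2D` replaces its coordinates, so the number of lines attached to the axes does not grow from frame to frame. A minimal sketch with one axes and a simulated 50-frame loop, assuming the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.arange(0, 2*np.pi, 0.1)
n = 50

# Create the line once, with no data yet
line, = ax.plot([], [])

for i in range(n):
    # Replace the coordinates instead of calling ax.plot again
    line.set_data(x, np.sin(x + i*np.pi/n))

print(len(ax.lines))
```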

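import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim

fig, axes = plt.subplots(1, 3, figsize=(8, 3))
x = np.arange(0, 2*np.pi, 0.1)

# Create the two lines once per subplot; their data is updated in each frame
im1 = [ax.plot([], [], label="label")[0] for ax in axes]
im2 = [ax.plot([], [], label="label")[0] for ax in axes]

# Create one legend per subplot; its text entries are updated in each frame
legs = [ax.legend(loc=2, prop={'size': 6}) for ax in axes]

for ax in axes:
    ax.set_xlim([0, 2*np.pi])
    ax.set_ylim([-1, 1])
plt.tight_layout()

n = 50
def update(i):
    for sp in range(3):
        y1 = np.sin((sp+1)*x + i*np.pi/n)
        y2 = np.cos((sp+1)*x + i*np.pi/n)

        # Replace the line data instead of plotting new lines
        im1[sp].set_data(x, y1)
        im2[sp].set_data(x, y2)

        # Update both legend entries (one per line)
        lab = 'i=' + str(i) + ', sp=' + str(sp+1)
        legs[sp].texts[0].set_text(lab)
        legs[sp].texts[1].set_text(lab)

    # With blit=True, return every artist that changed
    return im1 + im2 + legs

# Keep a reference to the animation so it is not garbage-collected
ani = anim.FuncAnimation(fig, update, frames=n, blit=True)
plt.show()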
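
Since both legend entries above receive the same string, a variant that names the two curves separately may be easier to read. The sketch below is a reduced version of the answer's code (one subplot; the labels "sin" and "cos" are chosen for illustration). Calling `update` directly, outside the running animation, is a convenient way to check what a given frame will show.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop it to view interactively
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import numpy as np

fig, ax = plt.subplots(figsize=(4, 3))
x = np.arange(0, 2*np.pi, 0.1)
n = 50

# Give each line its own placeholder label so the legend has two entries
sin_line, = ax.plot([], [], label="sin")
cos_line, = ax.plot([], [], label="cos")
leg = ax.legend(loc=2, prop={'size': 6})
ax.set_xlim([0, 2*np.pi])
ax.set_ylim([-1, 1])

def update(i):
    sin_line.set_data(x, np.sin(x + i*np.pi/n))
    cos_line.set_data(x, np.cos(x + i*np.pi/n))
    # texts[k] belongs to the k-th legend entry, in the order the lines were plotted
    leg.texts[0].set_text('sin, i=' + str(i))
    leg.texts[1].set_text('cos, i=' + str(i))
    return [sin_line, cos_line, leg]

ani = anim.FuncAnimation(fig, update, frames=n, blit=True)

# Apply one frame by hand to inspect its legend text
update(10)
print([t.get_text() for t in leg.get_texts()])
```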