I am making a scatter plot from three separate dataframes, plotting the points as well as their best-fit lines. I can do this with the following code:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
fig=plt.figure()
ax1=fig.add_subplot(111)
ax2=fig.add_subplot(111)
ax3=fig.add_subplot(111)
#create scatter plots from the dataframes
ax1.scatter(ex_x, ex_y, s=10, c='r', label='Fire Exclusion')
ax2.scatter(one_x,one_y, c='b', marker='s',label='One Fire')
ax3.scatter(two_x, two_y, s=10, c='g', marker='^', label='Two Fires')
#plot lines of best fit
ax1.plot(ex_x,ex_results.predict(), color = 'r',label = 'Linear (Fire Exclusion)')
ax2.plot(one_x,one_results.predict(), color = 'b',label = 'Linear (One Fire)')
ax3.plot(two_x,two_results.predict(), color = 'g',label = 'Linear (Two Fires)')
#add legend and axis labels
plt.xlabel('NDVI 2004/07/27')
plt.ylabel('NDVI 2005/07/14')
plt.title('NDVI in 2004 vs. 2005')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), scatterpoints=1)
which gives me the expected plot: the three sets of points, their best-fit lines, and one legend.
Now I want to add a second legend that displays the R2 of each line. I am attempting to do that like this:
fig=plt.figure()
ax1=fig.add_subplot(111)
ax2=fig.add_subplot(111)
ax3=fig.add_subplot(111)
scat1,=ax1.scatter(ex_x, ex_y, s=10, c='r', label='Fire Exclusion')
scat2,=ax2.scatter(one_x,one_y, c='b', marker='s',label='One Fire')
scat3,=ax3.scatter(two_x, two_y, s=10, c='g', marker='^', label='Two Fires')
lin1,=ax1.plot(ex_x,ex_results.predict(), color = 'r',label = 'Linear (Fire Exclusion)')
lin2,=ax2.plot(one_x,one_results.predict(), color = 'b',label = 'Linear (One Fire)')
lin3,=ax3.plot(two_x,two_results.predict(), color = 'g',label = 'Linear (Two Fires)')
l1 = plt.legend([scat1, scat2,scat3,lin1,lin2,lin3], ["Fire Exclusion", "One Fire", "Two Fires", "Linear (Fire Exclusion)", "Linear (One Fire)", "Linear (Two Fires)"], loc='upper left', scatterpoints=1)
#get r2 from regression results
r2ex=ex_results.rsquared
r2one=one_results.rsquared
r2two=two_results.rsquared
plt.legend([r2ex, r2one, r2two], ['R2 (Fire Exclusion)', 'R2 (One Fire)', 'R2 (Two Fires)'], loc='lower right')
plt.gca().add_artist(l1)
plt.xlabel('NDVI 2004/07/27')
plt.ylabel('NDVI 2005/07/14')
plt.title('NDVI in 2004 vs. 2005')
but this raises:
Traceback (most recent call last):
  File "<ipython-input-32-b6277bf27ded>", line 1, in <module>
    runfile('E:/prelim_codes/Fire.py', wdir='E:/prelim_codes')
  File "C:\Users\Stefano\Anaconda2_2\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 714, in runfile
    execfile(filename, namespace)
  File "C:\Users\Stefano\Anaconda2_2\lib\site-packages\spyderlib\widgets\externalshell\sitecustomize.py", line 74, in execfile
    exec(compile(scripttext, filename, 'exec'), glob, loc)
  File "E:/prelim_codes/Fire.py", line 539, in <module>
    scat1,=ax1.scatter(ex_x, ex_y, s=10, c='r', label='Fire Exclusion')
TypeError: 'PathCollection' object is not iterable
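I had the same error. It turns out you shouldn't put a comma after the variable name when assigning the result of scatter. So try
scat1=ax1.scatter(ex_x, ex_y, s=10, c='r', label='Fire Exclusion')
instead of
scat1,=ax1.scatter(ex_x, ex_y, s=10, c='r', label='Fire Exclusion')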
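This is because axes.scatter returns a single PathCollection, whereas axes.plot returns a list of the Line2D objects it plotted (see http://matplotlib.org/1.3.1/users/pyplot_tutorial.html#controlling-line-properties and the question "Python code. Is it comma operator?"). Writing scat1, = ... tells Python to unpack the return value as a one-element sequence, and since a PathCollection is not iterable, you get the TypeError.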
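A minimal sketch of the difference in return types, assuming a recent Matplotlib and the non-interactive Agg backend; the numbers are made up and the variable names are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# scatter returns one PathCollection object, not a sequence
scat = ax.scatter([0, 1, 2], [0, 1, 4])
print(type(scat).__name__, isinstance(scat, PathCollection))

# plot returns a list of Line2D objects, one per line drawn
lines = ax.plot([0, 1, 2], [0, 1, 2])
print(type(lines).__name__, len(lines), isinstance(lines[0], Line2D))

# "lin," unpacks that one-element list into a single Line2D
lin, = ax.plot([0, 1, 2], [0, 2, 4])

# the same trailing comma on a scatter result fails, as in the traceback
unpack_failed = False
try:
    bad, = ax.scatter([0, 1], [0, 1])
except TypeError as err:
    unpack_failed = True
    print(err)  # exact wording depends on the Python version

plt.close(fig)
```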
So for your lines you still need the comma, because you are unpacking the one-element list, but for the scatter calls you should drop it.
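Dropping the comma fixes the TypeError, but the second plt.legend call in the question still passes the R2 floats (r2ex, r2one, r2two) as legend handles, and Matplotlib can only draw artists there, so those entries won't show up. Below is a minimal sketch of one way to get both legends, under these assumptions: a hypothetical fake_group helper generates synthetic data in place of the dataframes, np.polyfit plus a hand-computed R2 stands in for the statsmodels results objects (ex_results.predict() and .rsquared), and a single Axes replaces the three add_subplot(111) calls, whose reuse behaviour has changed between Matplotlib versions. The R2 legend reuses the fitted lines as handles and puts the values in the labels.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)

def fake_group(slope, n=30):
    """Hypothetical stand-in for one dataframe: sorted x and a noisy linear y."""
    x = np.sort(rng.uniform(0.1, 0.9, n))
    y = slope * x + rng.normal(0, 0.05, n)
    return x, y

def linear_fit(x, y):
    """Least-squares line via np.polyfit; returns fitted values and R^2."""
    slope, intercept = np.polyfit(x, y, 1)
    fitted = slope * x + intercept
    ss_res = np.sum((y - fitted) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return fitted, 1 - ss_res / ss_tot

groups = [
    ("Fire Exclusion", "r", "o", fake_group(1.0)),
    ("One Fire", "b", "s", fake_group(0.8)),
    ("Two Fires", "g", "^", fake_group(0.6)),
]

fig, ax = plt.subplots()  # one Axes for all three groups
scatters, lines, r2_values, r2_labels = [], [], [], []
for name, color, marker, (x, y) in groups:
    fitted, r2 = linear_fit(x, y)
    # no comma: scatter returns a single PathCollection
    scat = ax.scatter(x, y, s=10, c=color, marker=marker)
    # comma: plot returns a one-element list of Line2D
    lin, = ax.plot(x, fitted, color=color)
    scatters.append(scat)
    lines.append(lin)
    r2_values.append(r2)
    r2_labels.append(f"R2 ({name}) = {r2:.2f}")

names = [g[0] for g in groups]
l1 = ax.legend(scatters + lines,
               names + [f"Linear ({n})" for n in names],
               loc="upper left", scatterpoints=1)
ax.add_artist(l1)  # keep the first legend when the second one is created
# legend handles must be artists, so reuse the fit lines and put R^2 in the labels
l2 = ax.legend(lines, r2_labels, loc="lower right")

ax.set_xlabel("NDVI 2004/07/27")
ax.set_ylabel("NDVI 2005/07/14")
ax.set_title("NDVI in 2004 vs. 2005")
```

With the real data, the fake_group(...) calls would be replaced by the dataframe columns, e.g. (ex_x, ex_y), and fitted and r2 would come from results.predict() and results.rsquared.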