Make legend correspond to colors of scatter points

Posted 2019-03-19 02:25

Question:

I have a plot that I generate with the KMeans algorithm in scikit-learn. Each cluster is drawn in a different color. Here is the plot:

I need a legend for this plot which corresponds to the cluster number in the plot. Ideally, the legend should display the color of the cluster and the label should be the cluster number. Thanks.

EDIT: I think I should put some code since people are downvoting this

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
km = KMeans(n_clusters=20, init='random')   
km.fit(df)  #df is the dataframe which contains points as coordinates
labels = km.labels_
plt.clf()
fig = plt.figure()
ax = fig.add_subplot(111, facecolor='w', frame_on=True)  # 'axisbg' was removed in Matplotlib 2.2
fig.set_size_inches(18.5, 10.5)

# Plot the clusters on the map
# m is a basemap object
m.scatter(
    [geom.x for geom in map_points],
    [geom.y for geom in map_points],
    20, marker='o', lw=.25,
    c=labels.astype(float),
    alpha=0.9, antialiased=True,
    zorder=3)
m.fillcontinents(color='#555555')
plt.show()

Answer 1:

I was able to make the legend correspond to the colors. The key was to draw a separate scatter plot for each category in the data, as suggested by Rutger Kassies.

Here is the code:

import numpy as np
import matplotlib.pyplot as plt

# Setting various plot properties
plt.clf()
fig = plt.figure()
ax = fig.add_subplot(111, facecolor='w', frame_on=True)  # 'axisbg' was removed in Matplotlib 2.2
fig.set_size_inches(18.5, 10.5)

# Creating a discrete set of colors, one per cluster (20, matching n_clusters)
colors = plt.cm.rainbow(np.linspace(0, 1, 20))

# Draw one scatter call per cluster so that each gets its own legend entry.
# Points are selected by label, so map_points need not be sorted by cluster;
# map_points and labels (from km.labels_) are assumed to be in the same order.
for i, c in enumerate(colors):
    cluster_points = [geom for geom, lab in zip(map_points, labels) if lab == i]
    m.scatter(
        [geom.x for geom in cluster_points],
        [geom.y for geom in cluster_points],
        20, lw=.25, marker='o', color=c, label=i, alpha=0.9, antialiased=True,
        zorder=3)

plt.legend()
m.fillcontinents(color='#555555')
plt.show()

The result looks something like this:
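
For readers without the basemap setup, the same idea (one scatter call per cluster, each with its own label) can be sketched with plain Matplotlib axes. This is only an illustration under assumptions: the helper name scatter_with_cluster_legend is hypothetical, and the random x, y, and labels arrays stand in for the map coordinates and km.labels_ from the question.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def scatter_with_cluster_legend(ax, x, y, labels):
    """Draw one scatter call per cluster so each cluster gets a legend entry.

    Hypothetical helper for illustration; not part of Matplotlib or basemap.
    """
    x, y, labels = np.asarray(x), np.asarray(y), np.asarray(labels)
    cluster_ids = np.unique(labels)
    # One evenly spaced color per cluster, as in the answer above
    colors = plt.cm.rainbow(np.linspace(0, 1, len(cluster_ids)))
    for cluster_id, color in zip(cluster_ids, colors):
        mask = labels == cluster_id  # select points by label, not by position
        ax.scatter(x[mask], y[mask], s=20, lw=0.25, marker='o',
                   color=color, label=str(cluster_id), alpha=0.9)
    return ax.legend(title="Cluster")


# Synthetic stand-in data (assumption): 200 points in 5 clusters
rng = np.random.default_rng(0)
labels = rng.integers(0, 5, size=200)
x = rng.normal(size=200) + labels
y = rng.normal(size=200)

fig, ax = plt.subplots()
legend = scatter_with_cluster_legend(ax, x, y, labels)
# Call fig.savefig(...) or switch to an interactive backend and plt.show() to view it
```

Selecting points with a boolean mask keeps each legend color tied to the right cluster even when the points are not sorted by label.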