How to add a legend to a scatterplot

Posted 2019-08-19 18:15

I am doing an exercise for a machine learning course. I appended a dataset of images, in the form of arrays, to the matrix datamatrix; then I standardized it and computed the principal components. labels is an array containing the label of each image (the name of the subdirectory that contained it). I need to visualize pairs of principal components, in this part the first two. The professor suggested using the matplotlib.pyplot.scatter function. I found seaborn.scatterplot, which seems better, but with neither of them did I manage to add a legend showing the label names.

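from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

# datamatrix: one standardized image array per row
# labels: the subdirectory name of each image
pca = PCA()
X_t = pca.fit_transform(datamatrix)
X_r = pca.inverse_transform(X_t)

plt.figure(figsize=(25,5))

colours = ['r','g','b','purple']  # 'p' is not a matplotlib colour code
plt.subplot(1, 3, 1)
sns.scatterplot(x=X_t[:,0], y=X_t[:,1], hue=labels, palette=colours, legend='full')
plt.title('PC 1 and 2')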

I am new to Python and to machine learning libraries.

Edit: As suggested, I tried modifying the code:

data = {"x" : X_t[:,0], "y" : X_t[:,1], "label" : labels}
sns.scatterplot(x="x", y="y", hue="label", palette=colours, data=data, legend='full')

But I get the same result: the legend appears, but without the label names.

2 answers
我欲成王,谁敢阻挡
Post #2 · 2019-08-19 18:18

Before showing the plot, add the legend with:

plt.legend()
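plt.legend() only lists artists that were given a label, so after a single scatter call without one it has nothing to show. A minimal sketch, assuming the coordinates and string labels are NumPy arrays (the random x, y and labels here are hypothetical stand-ins, not the asker's data), draws one scatter call per label so that each group gets its own colour and legend entry:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical stand-in data: 30 points, each with one of three string labels
rng = np.random.default_rng(0)
x = rng.random(30)
y = rng.random(30)
labels = rng.choice(["cat", "dog", "bird"], size=30)

fig, ax = plt.subplots()
# One scatter call per label: each call gets the next colour in the cycle
# and a label= that legend() can pick up
for name in np.unique(labels):
    mask = labels == name
    ax.scatter(x[mask], y[mask], label=name)

ax.legend()
plt.show()
```

The trade-off is one artist per class instead of one for the whole dataset, which is usually fine for a handful of labels.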
劳资没心,怎么记你
Post #3 · 2019-08-19 18:43

Seaborn's scatterplot creates a legend automatically, as shown in the second example of its documentation. It does, however, require the data to be in a dictionary-like structure, as is common with pandas DataFrames.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

data = {"x" : np.random.rand(10),
        "y" : np.random.rand(10),
        "label" : np.random.choice(["Label 1", "Label 2"], size=10)}

sns.scatterplot(x="x", y="y", hue="label", data=data)
plt.show()
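Applied to the question's setting, the same idea could look like this sketch. It assumes scikit-learn's PCA and pandas; random numbers stand in for datamatrix, and the four class names are hypothetical. Putting the first two components and the labels into one DataFrame lets hue refer to the "label" column by name, and a dict palette maps each label to a colour explicitly:

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

# Hypothetical stand-in for the image data: 40 rows of 64 features,
# each tagged with one of four made-up class names
rng = np.random.default_rng(0)
datamatrix = rng.normal(size=(40, 64))
labels = np.repeat(["class_a", "class_b", "class_c", "class_d"], 10)

X_t = PCA().fit_transform(datamatrix)

# First two principal components plus the label column in one DataFrame
df = pd.DataFrame({"PC1": X_t[:, 0], "PC2": X_t[:, 1], "label": labels})

# Explicit label -> colour mapping (all valid matplotlib colour names)
palette = dict(zip(np.unique(labels), ["red", "green", "blue", "purple"]))

ax = sns.scatterplot(data=df, x="PC1", y="PC2", hue="label", palette=palette)
ax.set_title("PC 1 and 2")
plt.show()
```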


To achieve the same with matplotlib's scatter, you need to create the legend yourself, which is indeed a bit more cumbersome but may help with understanding.

import numpy as np
import matplotlib.pyplot as plt

data = {"x" : np.random.rand(10),
        "y" : np.random.rand(10),
        "label" : np.random.choice(["Label 1", "Label 2"], size=10)}

labels, inv = np.unique(data["label"], return_inverse=True)
scatter = plt.scatter(x="x", y="y", c = inv, data=data)

handles = [plt.Line2D([],[],marker="o", ls="", 
                      color=scatter.cmap(scatter.norm(yi))) for yi in np.unique(inv)]
plt.legend(handles, labels)

plt.show()
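Since matplotlib 3.1, the PathCollection returned by scatter also offers legend_elements(), which builds one handle per unique value of c from the collection's colormap and norm, so the hand-made Line2D proxies are not needed. A sketch on the same kind of random data (seeded only so the result is reproducible):

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
data = {"x": rng.random(10),
        "y": rng.random(10),
        "label": rng.choice(["Label 1", "Label 2"], size=10)}

# Sorted unique label names, and each point's label encoded as 0..k-1
labels, inv = np.unique(data["label"], return_inverse=True)

fig, ax = plt.subplots()
scatter = ax.scatter(data["x"], data["y"], c=inv)

# One handle per unique value of c, in ascending order (0, 1, ...),
# which matches the order of the sorted label names
handles, _ = scatter.legend_elements()
ax.legend(handles, labels)
plt.show()
```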


Also see Add legend to scatter plot (PCA)
