seaborn heatmap: get an array of the color code values

Posted 2019-06-11 13:22

Question:

I am trying to get the color codes associated with each cell of a heatmap:

import seaborn as sns
import numpy as np
import matplotlib.cm as cm

hm = sns.heatmap(
    np.random.randn(10, 10),
    cmap=cm.coolwarm)

# hm.<some function>[0][0] would return the color code of the cell indexed (0,0)

Answer 1:

Because sns.heatmap returns a Matplotlib Axes object, hm has no method that returns the cell colors directly. But we can call the cmap object itself to get the RGBA values of the data, once the data have been normalized to the [0, 1] range the colormap expects. Edit: the code has been updated to include normalization of the data.

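import numpy as np
import seaborn as sns
import matplotlib as mpl
from matplotlib.colors import Normalize

data = np.random.randn(10, 10)
# Colormap lookup by name (cm.get_cmap is deprecated since Matplotlib 3.7)
cmap = mpl.colormaps['Greens']
hm = sns.heatmap(data, cmap=cmap)

# Normalize data to [0, 1] with the limits seaborn uses by default
norm = Normalize(vmin=data.min(), vmax=data.max())
rgba_values = cmap(norm(data))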

All of the colors are now contained in rgba_values, an array of shape (10, 10, 4). So to get the color of the upper-left square in the heatmap you can simply do:

In [13]: rgba_values[0,0]
Out[13]: array([ 0.        ,  0.26666668,  0.10588235,  1.        ])
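Normalize is just a linear rescaling, norm(x) = (x - vmin) / (vmax - vmin), that maps the data range onto [0, 1]; the colormap then looks up one RGBA tuple per value. The sketch below uses only NumPy and Matplotlib with a small made-up array (not data from the question) to do that rescaling by hand and confirm it matches Normalize.

```python
import numpy as np
import matplotlib as mpl
from matplotlib.colors import Normalize

# Small made-up data standing in for the heatmap values
data = np.array([[0.0, 1.0], [2.0, 4.0]])

vmin, vmax = data.min(), data.max()
norm = Normalize(vmin=vmin, vmax=vmax)

# Normalize is a linear map from [vmin, vmax] onto [0, 1]
manual = (data - vmin) / (vmax - vmin)
assert np.allclose(np.asarray(norm(data)), manual)

# A Colormap is callable: floats in [0, 1] -> RGBA, output shape data.shape + (4,)
cmap = mpl.colormaps["Greens"]
rgba_values = cmap(norm(data))
print(rgba_values.shape)  # (2, 2, 4)
```

The largest value lands exactly on 1.0, so its color is the last entry of the colormap, cmap(1.0).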

For more, check out Getting individual colors from a color map in matplotlib


Update

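To reproduce the colors when the center and robust keywords are passed to sns.heatmap, you have to recompute the color limits and, if center is set, the recentered colormap. Looking at the relevant seaborn source code (http://github.com/mwaskom/seaborn/blob/master/seaborn/matrix.py#L202): robust sets vmin and vmax to the 2nd and 98th percentiles of the data instead of its minimum and maximum, while center (in seaborn 0.8 and later) leaves vmin and vmax unchanged and instead resamples the colormap so that center lands on the colormap's midpoint color. Older seaborn releases adjusted vmin and vmax around center instead. The code below mirrors the current logic:

import numpy as np
import seaborn as sns
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap, Normalize

data = np.random.randn(10, 10)
center = 2
robust = False
cmap = cm.coolwarm
hm = sns.heatmap(data, cmap=cmap, center=center, robust=robust)

# Color limits, as computed by seaborn
vmin = np.nanpercentile(data, 2) if robust else np.nanmin(data)
vmax = np.nanpercentile(data, 98) if robust else np.nanmax(data)

# Recenter the colormap around `center` (seaborn >= 0.8)
vrange = max(vmax - center, center - vmin)
cmin, cmax = Normalize(center - vrange, center + vrange)([vmin, vmax])
cmap = ListedColormap(cmap(np.linspace(cmin, cmax, 256)))

norm = Normalize(vmin=vmin, vmax=vmax)
rgba_values = cmap(norm(data))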
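As a sanity check, the recomputation can be wrapped in a small helper and compared with the colors held by the QuadMesh that seaborn actually draws (the approach of the next answer). This is a sketch under several assumptions: seaborn 0.8 or later, a Colormap object (not a name) for cmap, no explicit vmin/vmax, and NaN-free data. heatmap_rgba is a hypothetical name, and its logic follows seaborn's private _HeatMapper._determine_cmap_params, which may change between releases.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import seaborn as sns
from matplotlib.colors import ListedColormap, Normalize


def heatmap_rgba(data, cmap, center=None, robust=False):
    """Hypothetical helper: recompute the RGBA color of every heatmap cell.

    Assumes seaborn >= 0.8 semantics, a Colormap object for `cmap`,
    and that vmin/vmax were not passed to sns.heatmap explicitly.
    """
    data = np.asarray(data, dtype=float)
    vmin = np.nanpercentile(data, 2) if robust else np.nanmin(data)
    vmax = np.nanpercentile(data, 98) if robust else np.nanmax(data)
    if center is not None:
        # Resample the colormap so that `center` hits its midpoint color
        vrange = max(vmax - center, center - vmin)
        cmin, cmax = Normalize(center - vrange, center + vrange)([vmin, vmax])
        cmap = ListedColormap(cmap(np.linspace(cmin, cmax, 256)))
    return cmap(Normalize(vmin=vmin, vmax=vmax)(data))


rng = np.random.default_rng(0)
data = rng.standard_normal((10, 10))
cmap = matplotlib.colormaps["coolwarm"]

for kwargs in [{}, {"robust": True}, {"center": 2}, {"center": 0, "robust": True}]:
    ax = sns.heatmap(data, cmap=cmap, **kwargs)
    mesh = ax.collections[0]
    # Colors as the mesh itself would map them, laid out like `data`
    drawn = mesh.to_rgba(mesh.get_array()).reshape(data.shape + (4,))
    print(kwargs, np.allclose(heatmap_rgba(data, cmap, **kwargs), drawn))
    ax.figure.clf()  # start from an empty figure for the next case
```

If a printed comparison comes out False on some seaborn version, that version computes its limits or colormap differently, and reading the colors from the mesh is the safer route.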


Answer 2:

Without any knowledge of the input data or of the arguments passed to heatmap, you can get the colors from the underlying QuadMesh, since the heatmap should be the first and only collection in the Axes returned by heatmap.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

data = np.array([[0,-2],[10,5]])
ax = sns.heatmap(data, center=0, cmap="bwr", robust=False)
im = ax.collections[0]  # the QuadMesh drawn by pcolormesh
rgba_values = im.cmap(im.norm(im.get_array()))  # apply the mesh's own norm and colormap

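Also see this answer. In contrast to AxesImage, QuadMesh may store its values as a flat array (older Matplotlib releases flatten the data passed to pcolormesh). In that case the above code gives you a 2D array with one row per cell and the four RGBA channels as columns. If you need a 3D output whose first two dimensions match the input data, reshape it using the shape of the data:

rgba_values = rgba_values.reshape(data.shape + (4,))

This also works when get_array() already returns a 2D array. The private attributes im._meshHeight and im._meshWidth, which could be used for the same reshape, are not available in recent Matplotlib releases.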
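Since the question asks for color codes, a natural last step is turning the RGBA array into hex strings with matplotlib.colors.to_hex. The sketch below reuses the example data above, selects the mesh by type rather than by position (in case other collections, say a scatter plot, were drawn on the same Axes before the heatmap), and reshapes with data.shape so it works whether get_array() comes back flat or 2D.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import seaborn as sns
from matplotlib.collections import QuadMesh
from matplotlib.colors import to_hex

data = np.array([[0, -2], [10, 5]])
ax = sns.heatmap(data, center=0, cmap="bwr", robust=False)

# Pick the heatmap's mesh by type instead of assuming it is collections[0]
mesh = next(c for c in ax.collections if isinstance(c, QuadMesh))

# to_rgba applies the mesh's own norm and colormap;
# reshaping by data.shape gives (rows, cols, 4) either way
rgba_values = mesh.to_rgba(mesh.get_array()).reshape(data.shape + (4,))

# One '#rrggbb' string per cell, in the same layout as `data`
hex_codes = [[to_hex(c) for c in row] for row in rgba_values]
print(hex_codes)
```

Pass keep_alpha=True to to_hex if you also want the alpha channel in the string ('#rrggbbaa').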