I am trying to get the color codes associated with each cell of a heatmap:
import seaborn as sns
import numpy as np
import matplotlib.cm as cm
hm = sns.heatmap(np.random.randn(10, 10), cmap=cm.coolwarm)
# hm.<some function>[0][0] would return the color code of the cell indexed (0,0)
Because sns.heatmap returns a matplotlib Axes object, we can't really use hm directly. But we can use the cmap object itself to return the RGBA values of the data. Edit: the code has been updated to include normalization of the data.
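import numpy as np
import seaborn as sns
import matplotlib
from matplotlib.colors import Normalize
data = np.random.randn(10, 10)
# matplotlib.cm.get_cmap is deprecated in newer Matplotlib; use the colormap registry instead
cmap = matplotlib.colormaps['Greens']
hm = sns.heatmap(data, cmap=cmap)
# Normalize data to [0, 1] over the full data range, matching heatmap's default limits
norm = Normalize(vmin=data.min(), vmax=data.max())
rgba_values = cmap(norm(data))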
All of the colors are now contained in rgba_values, so to get the color of the upper-left square in the heatmap you could simply do:
In [13]: rgba_values[0,0]
Out[13]: array([ 0. , 0.26666668, 0.10588235, 1. ])
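If "color code" means a hex string rather than an RGBA tuple, each entry of rgba_values can be passed to matplotlib.colors.to_hex. A minimal sketch, using a small made-up 2x2 array and skipping the seaborn plot, since the colors depend only on the colormap and the normalization:

```python
import numpy as np
import matplotlib
from matplotlib.colors import Normalize, to_hex

# Small made-up example; the plot itself is not needed to compute the colors
data = np.array([[0.0, 1.0], [2.0, 4.0]])
cmap = matplotlib.colormaps["Greens"]
norm = Normalize(vmin=data.min(), vmax=data.max())
rgba_values = cmap(norm(data))  # shape (2, 2, 4)

# to_hex converts a single RGBA tuple, so apply it cell by cell
hex_codes = np.array([[to_hex(rgba) for rgba in row] for row in rgba_values])
print(hex_codes)
```

The result has the same shape as the data, so hex_codes[0, 0] is the hex color of the cell indexed (0, 0).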
For more, check out "Getting individual colors from a color map in matplotlib".
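As for why the normalization step matters: a Matplotlib Colormap reads float input as a position in [0, 1] (values outside that range are clipped to the end colors unless under/over colors are set) and integer input as a direct index into its lookup table of cmap.N colors. Raw heatmap values therefore need to be normalized before being passed to cmap. A small check of that behavior, using the Greens colormap as an example:

```python
import numpy as np
import matplotlib

cmap = matplotlib.colormaps["Greens"]

# Floats are read as positions in [0, 1] along the colormap
low, high = cmap(0.0), cmap(1.0)

# Integers are read as indices into the lookup table of cmap.N colors,
# so cmap(1) is the second table entry, not the top of the colormap
second_entry = cmap(1)

# Out-of-range floats are clipped to the end colors (no under/over colors set)
below, above = cmap(-3.5), cmap(7.2)

print(cmap.N, np.allclose(second_entry, high), np.allclose(below, low), np.allclose(above, high))
```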
Update
To account for the center and robust keywords passed to sns.heatmap, you basically just have to redefine vmin and vmax. Looking at the relevant seaborn source code (http://github.com/mwaskom/seaborn/blob/master/seaborn/matrix.py#L202), the changes to vmin and vmax below should do the trick.
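import numpy as np
import seaborn as sns
import matplotlib.cm as cm
from matplotlib.colors import Normalize
data = np.random.randn(10, 10)
center = 2
robust = False
cmap = cm.coolwarm
hm = sns.heatmap(data, cmap=cmap, center=center, robust=robust)
# Data limits as seaborn computes them: 2nd/98th percentiles if robust, else the full range
vmin = np.nanpercentile(data, 2) if robust else np.nanmin(data)
vmax = np.nanpercentile(data, 98) if robust else np.nanmax(data)
# With a center, the color limits are made symmetric around it
vrange = max(vmax - center, center - vmin)
norm = Normalize(vmin=center - vrange, vmax=center + vrange)
rgba_values = cmap(norm(data))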
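A reusable version of this limit logic can be written as a small helper that also covers the default case (center=None). heatmap_norm is a hypothetical name, not a seaborn function. The sketch uses the 2nd/98th percentiles when robust=True and makes the limits symmetric around center when one is given, which is how seaborn's matrix.py treats a center value; it does not reproduce every seaborn version exactly (newer releases re-center by resampling the colormap), so treat matching colors as approximate.

```python
import numpy as np
from matplotlib.colors import Normalize


def heatmap_norm(data, center=None, robust=False):
    """Hypothetical helper: build a Normalize approximating sns.heatmap's color limits."""
    data = np.asarray(data, dtype=float)
    # Robust mode uses the 2nd/98th percentiles instead of the full range
    if robust:
        vmin, vmax = np.nanpercentile(data, [2, 98])
    else:
        vmin, vmax = np.nanmin(data), np.nanmax(data)
    if center is not None:
        # Make the limits symmetric around the center value
        vrange = max(vmax - center, center - vmin)
        vmin, vmax = center - vrange, center + vrange
    return Normalize(vmin=vmin, vmax=vmax)


data = np.array([[-1.0, 0.0], [2.0, 3.0]])
norm = heatmap_norm(data, center=0)
print(norm.vmin, norm.vmax, norm(0.0))
```

With center=0 the limits become -3 and 3, so the center value lands in the middle of the colormap (norm(0.0) is 0.5); the colors are then cmap(norm(data)) as before.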
Without any knowledge of the input data and the arguments passed to heatmap, you can get the colors from the underlying QuadMesh, since the heatmap should be the first and only collection inside the Axes returned by heatmap.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
data = np.array([[0,-2],[10,5]])
ax = sns.heatmap(data, center=0, cmap="bwr", robust=False)
im = ax.collections[0]
rgba_values = im.cmap(im.norm(im.get_array()))
Also see this answer. In contrast to AxesImage, QuadMesh may store its array flattened (this depends on the Matplotlib version), in which case the code above gives you a 2D array with one row per cell and the RGBA color channels as columns. If you need a 3D output whose first two dimensions match the input data, you need to reshape it:
# _meshHeight and _meshWidth are private attributes of older QuadMesh versions
if rgba_values.ndim == 2:
    rgba_values = rgba_values.reshape((im._meshHeight, im._meshWidth, 4))
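A version-tolerant sketch of the QuadMesh approach reshapes only when the array comes back flattened. It rests on two assumptions: it uses plain ax.pcolormesh as a stand-in for sns.heatmap (both leave a single QuadMesh in ax.collections), and it reads the grid size from the private _coordinates attribute, which holds the (rows + 1, cols + 1, 2) mesh vertices but is not public API and could change.

```python
import numpy as np
from matplotlib.figure import Figure

data = np.array([[0, -2], [10, 5]])

# Stand-in for sns.heatmap: pcolormesh also draws a single QuadMesh
fig = Figure()
ax = fig.add_subplot()
ax.pcolormesh(data, cmap="bwr")

im = ax.collections[0]
rgba_values = im.cmap(im.norm(im.get_array()))

if rgba_values.ndim == 2:
    # Flattened array: recover the grid size from the mesh vertices.
    # _coordinates is private (shape (rows + 1, cols + 1, 2)) and may change.
    n_rows, n_cols = (s - 1 for s in im._coordinates.shape[:2])
    rgba_values = rgba_values.reshape((n_rows, n_cols, 4))

print(rgba_values.shape)
```

Since the norm is autoscaled to the data range, the smallest value (-2) maps to the first color of bwr (blue) and the largest (10) to the last (red).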