Matplotlib colormap bug with length-4 arrays

Posted 2019-02-25 06:45

Question:

I have some arrays that I need to plot in a loop with a certain colormap. However, one of my arrays has length 4, and with that one I run into this problem:

import numpy as np
import matplotlib.pyplot as plt  # scatter/legend/show live in pyplot, not the top-level package

ns = range(2,8)
# 'spectral' was removed in Matplotlib 2.2; 'nipy_spectral' is its replacement
cm = plt.get_cmap('nipy_spectral')
cmap = [cm(1.*i/len(ns)) for i in range(len(ns))]
for i,n in enumerate(ns):
    x = np.linspace(0, 10, num=n)
    y = np.zeros(n) + i
    # cmap[i] is a single RGBA 4-tuple; for n == 4 the colors come out wrong
    plt.scatter(x, y, c=cmap[i], edgecolor='none', s=50, label=n)
plt.legend(loc='lower left')
plt.show()

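For n=4, it looks like Matplotlib treats the four elements of the RGBA tuple as four data values, one per point, and maps them through the default colormap, so those four points get different colors instead of the one I asked for. For every other array length, all points get the single expected color.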
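One way to confirm this diagnosis is to ask the scatter collection whether it was colormapped: when scatter treats c as data values it stores them on the collection, so get_array() returns them; when c is read as a color, get_array() returns None. The sketch below assumes Matplotlib 3.x with the headless Agg backend; the helper is_value_mapped is hypothetical, and 'nipy_spectral' stands in for the removed 'spectral' colormap.

```python
import warnings

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def is_value_mapped(n, c):
    """Hypothetical helper: True if scatter() treated c as data to colormap."""
    fig, ax = plt.subplots()
    x = np.linspace(0, 10, num=n)
    with warnings.catch_warnings():
        # Recent Matplotlib warns when c looks like a lone RGB(A) tuple
        # whose length does not match x; silence that for this check.
        warnings.simplefilter("ignore")
        sc = ax.scatter(x, np.zeros(n), c=c)
    plt.close(fig)
    # Value-mapped collections keep the mapped data; otherwise this is None.
    return sc.get_array() is not None


# One RGBA 4-tuple, as produced by the colormap in the question.
rgba = plt.get_cmap("nipy_spectral")(0.5)

for n in range(2, 8):
    print(n, is_value_mapped(n, rgba))
```

If the diagnosis is right, only n = 4 should report True: that is the one length where the tuple's size matches the number of points.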

My actual code is much more complicated, and I do not want to spend time rewriting the loop. Is there a workaround for this?

Answer 1:

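It looks like you've bumped into an unfortunate API design in the handling of the c argument: c can be either a single color or a sequence of numbers to be colormapped, and an RGBA tuple is itself a sequence of four numbers. When x has exactly four points, the two readings are indistinguishable, and Matplotlib picks the colormapping one. One way to work around the problem is to make c an array with shape (len(x), 4) containing len(x) copies of the desired color. For example: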

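import numpy as np
import matplotlib.pyplot as plt

ns = range(2,8)
cm = plt.get_cmap('nipy_spectral')  # 'spectral' no longer exists in Matplotlib >= 2.2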
cmap = [cm(1.*i/len(ns)) for i in range(len(ns))]
for i,n in enumerate(ns):
    x = np.linspace(0, 10, num=n)
    y = np.zeros(n) + i
    # len(x) copies of the RGBA tuple, shape (len(x), 4): never mistaken for data values
    c = np.tile(cmap[i], (len(x), 1))
    plt.scatter(x, y, c=c, edgecolor='none', s=50, label=n)
plt.legend(loc='lower left')
plt.show()
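To check that the tiled array really sidesteps the ambiguity, the sketch below draws four points with it and compares two lighter variants that the current Matplotlib scatter documentation describes for the same purpose: a 2-D array with a single row (c=[rgba]) and the color keyword, which is never value-mapped. It assumes Matplotlib 3.x with the Agg backend; older releases may not treat the single-row form specially.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rgba = plt.get_cmap("nipy_spectral")(0.5)  # one RGBA 4-tuple
n = 4                                      # the length that triggers the problem
x = np.linspace(0, 10, num=n)
y = np.zeros(n)

variants = {
    # The workaround above: len(x) copies of the color, shape (n, 4).
    "tile": {"c": np.tile(rgba, (n, 1))},
    # A 2-D array with a single row, shape (1, 4), applied to every point.
    "single_row": {"c": [rgba]},
    # The color keyword is always read as a color, never as data values.
    "color_kw": {"color": rgba},
}

fig, ax = plt.subplots()
results = {}
for name, kwargs in variants.items():
    sc = ax.scatter(x, y, edgecolor="none", s=50, **kwargs)
    mapped = sc.get_array() is not None                 # True would mean colormapped
    uniform = np.allclose(sc.get_facecolors(), rgba)    # every face equals rgba
    results[name] = (mapped, uniform)
    print(f"{name}: colormapped={mapped}, uniform color={uniform}")
plt.close(fig)
```

The single-row and color= variants change only the c=cmap[i] argument in the loop, which may suit the wish not to rewrite the loop; np.tile works the same way but also depends on len(x).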

Another alternative is to convert the RGB values into a hex string, which is always interpreted as a single color, and pass the alpha channel of the color separately through the alpha argument, because the hex string produced this way does not carry it. As @ali_m pointed out in a comment, the function matplotlib.colors.rgb2hex makes this easy. If you know the alpha channel of the color is always 1.0, you can remove the code that creates the alpha argument.

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

ns = range(2,8)
cm = plt.get_cmap('nipy_spectral')  # 'spectral' no longer exists in Matplotlib >= 2.2
cmap = [cm(1.*i/len(ns)) for i in range(len(ns))]
for i,n in enumerate(ns):
    x = np.linspace(0, 10, num=n)
    y = np.zeros(n) + i
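    # a hex string is always read as one color, but rgb2hex drops the alpha channel,
    c = mpl.colors.rgb2hex(cmap[i])
    # so the alpha channel is passed separately
    alpha = cmap[i][3]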
    plt.scatter(x, y, c=c, edgecolor='none', s=50, label=n, alpha=alpha)
plt.legend(loc='lower left')
plt.show()
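rgb2hex drops the alpha channel by default, which is why the code above passes alpha separately. The sketch below, assuming a recent Matplotlib and using a made-up semi-transparent color, shows the default output next to rgb2hex(..., keep_alpha=True), whose 8-digit '#rrggbbaa' string can carry the alpha into scatter on its own, so the separate alpha argument may become unnecessary.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

rgba = (1.0, 0.5, 0.0, 0.25)  # made-up semi-transparent orange, for illustration

hex_rgb = mcolors.rgb2hex(rgba)                    # default: alpha is dropped
hex_rgba = mcolors.rgb2hex(rgba, keep_alpha=True)  # 8 digits: '#rrggbbaa'
print(hex_rgb, hex_rgba)

n = 4  # the length that was ambiguous with an RGBA tuple
x = np.linspace(0, 10, num=n)
fig, ax = plt.subplots()
# A single hex string is always parsed as one color, so no alpha= is passed.
sc = ax.scatter(x, np.zeros(n), c=hex_rgba, edgecolor="none", s=50)
plt.close(fig)
print(sc.get_array() is None, sc.get_facecolors()[0])
```

With this input the two strings are '#ff8000' and '#ff800040', and the face color keeps the alpha of 0x40/255 (about 0.25).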