Representing 4D data in mplot3d using colormaps

Published 2019-01-27 01:30

Question:

Is there a way to change the value that the colormap is tied to in an mplot3d surface plot?
As an example, I'm trying to represent surface temperature for an object:

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

z = np.array([0,1,2,3,4,5,6,7,8,9,10])
radius = np.array([0,1,1.5,1,0,2,4,5,4,2,1])
temp = np.array([150,200,210,220,225,220,195,185,160,150,140])

angle = np.linspace(0,2*np.pi,20)
Z,ANG = np.meshgrid(z,angle)
# transform them to cartesian system
X,Y = radius*np.cos(ANG),radius*np.sin(ANG)

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='jet')
plt.show()

This generates a 3D representation of the object, but by default the colormap is tied to the z-axis value. Can the colormap be tied to the 'temp' value instead?

(In this example, 'temp' maps onto Z the same way the 'radius' values do: one value per z position.)

I'm aware of tools like Mayavi, but if possible I'm hoping for a solution within matplotlib.

Answer 1:

Try using facecolors in the call to plot_surface:

import matplotlib.pyplot as plt
import numpy as np

from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

z = np.array([0,1,2,3,4,5,6,7,8,9,10])
radius = np.array([0,1,1.5,1,0,2,4,5,4,2,1])
temp = np.array([150,200,210,220,225,220,195,185,160,150,140])

angle = np.linspace(0,2*np.pi,20)
Z,ANG = np.meshgrid(z,angle)
T,ANG = np.meshgrid(temp,angle)
# transform them to cartesian system
X,Y = radius*np.cos(ANG),radius*np.sin(ANG)

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=cm.jet(T/float(T.max())))
plt.show()
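
Note that dividing by T.max() only reaches the upper part of the colormap here, because the smallest temperature (140) still maps to about 0.62. If you want the colors to span the whole colormap, one option is matplotlib.colors.Normalize. This is a minimal sketch of that idea using the data from the question; the names scaled_by_max, scaled_full and facecolors are just illustrative.

```python
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize

temp = np.array([150, 200, 210, 220, 225, 220, 195, 185, 160, 150, 140])
angle = np.linspace(0, 2 * np.pi, 20)
T, ANG = np.meshgrid(temp, angle)  # T has the same shape as X, Y and Z

# Dividing by the maximum leaves the lower part of the colormap unused.
scaled_by_max = T / T.max()

# Normalize maps [temp.min(), temp.max()] linearly onto [0, 1].
norm = Normalize(vmin=temp.min(), vmax=temp.max())
scaled_full = norm(T)

# Calling a colormap on values in [0, 1] returns one RGBA color per value.
facecolors = cm.jet(scaled_full)
```

You would then pass facecolors=facecolors to plot_surface in place of the cm.jet(T/float(T.max())) expression.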



Answer 2:

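In case you are looking for a colorbar too, do the following. Here x, y and z are the 2D coordinate arrays of your surface, and the plotted quantity is the distance from the origin. The values are normalized to [0, 1] before being passed to the colormap, and the ScalarMappable gets the same norm so that the colorbar matches the face colors:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize
# x, y, z: 2D coordinate arrays of the surface (e.g. from np.meshgrid)
r = np.sqrt(x*x + y*y + z*z)  # scalar value shown as color
norm = Normalize(vmin=r.min(), vmax=r.max())
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
p_surf = ax.plot_surface(x, y, z, rstride=1, cstride=1, linewidth=0, antialiased=True, facecolors=cm.jet(norm(r)))
m = cm.ScalarMappable(norm=norm, cmap=cm.jet)
m.set_array(r)
fig.colorbar(m, ax=ax)
plt.show()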
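
Applied to the data from the question, the two answers could be combined roughly as follows. Treat this as a sketch rather than the definitive way to do it: the shared Normalize instance, the shade=False argument and the 'temp' colorbar label are my additions and not part of either answer.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import Normalize

z = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
radius = np.array([0, 1, 1.5, 1, 0, 2, 4, 5, 4, 2, 1])
temp = np.array([150, 200, 210, 220, 225, 220, 195, 185, 160, 150, 140])

angle = np.linspace(0, 2 * np.pi, 20)
Z, ANG = np.meshgrid(z, angle)
T, _ = np.meshgrid(temp, angle)
# transform to the cartesian system
X, Y = radius * np.cos(ANG), radius * np.sin(ANG)

# One norm shared by the surface colors and the colorbar keeps them consistent.
norm = Normalize(vmin=temp.min(), vmax=temp.max())

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# shade=False (an addition) turns off light-source shading of the face colors,
# so the plotted colors correspond directly to the colorbar.
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                facecolors=cm.jet(norm(T)), shade=False)

# The surface itself carries no data-to-color mapping when facecolors is used,
# so the colorbar is built from a separate ScalarMappable.
mappable = cm.ScalarMappable(norm=norm, cmap=cm.jet)
mappable.set_array(T)
fig.colorbar(mappable, ax=ax, label='temp')
```

Call plt.show() afterwards to display the figure.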