draw many spheres efficiently

Posted 2019-01-20 15:47

Question:

I need to draw many spheres, small and large, in one picture. The following code works, but takes awfully long to run.

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy

fig = plt.figure()
# fig.gca(projection=...) is no longer accepted by newer Matplotlib releases
ax = fig.add_subplot(111, projection='3d')
ax.set_aspect('equal')

u = numpy.linspace(0, 2*numpy.pi, 100)
v = numpy.linspace(0, numpy.pi, 100)
x = numpy.outer(numpy.cos(u), numpy.sin(v))
y = numpy.outer(numpy.sin(u), numpy.sin(v))
z = numpy.outer(numpy.ones(numpy.size(u)), numpy.cos(v))

for k in range(200):
    c = numpy.random.rand(3)
    r = numpy.random.rand(1)
    ax.plot_surface(
        r*x + c[0], r*y + c[1], r*z + c[2],
        color='#1f77b4',
        alpha=0.5,
        linewidth=0
        )

plt.show()

I'm looking for a more efficient solution. Perhaps there is a native sphere artist in matplotlib that I didn't find?

Answer 1:

No, there is no such thing as a "sphere artist". And even if there were, it would not take less time to draw.

The solution you present in the question is a sensible way to draw many spheres. However, you might want to consider using far fewer points on each sphere:

u = numpy.linspace(0, 2*numpy.pi, 12)
v = numpy.linspace(0, numpy.pi, 7)
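
To make the saving concrete, here is a minimal sketch that builds the same unit-sphere grid at a configurable resolution and counts the quadrilateral patches in it. The helper names `unit_sphere` and `quad_count` are made up for this illustration. The count refers to the grid itself; as far as I know, `plot_surface` may additionally subsample large grids through its `rcount`/`ccount` parameters, so the patches actually drawn can be fewer.

```python
import numpy as np

def unit_sphere(n_u, n_v):
    """Return x, y, z grids of shape (n_u, n_v) lying on the unit sphere."""
    u = np.linspace(0, 2 * np.pi, n_u)  # azimuthal angle
    v = np.linspace(0, np.pi, n_v)      # polar angle
    x = np.outer(np.cos(u), np.sin(v))
    y = np.outer(np.sin(u), np.sin(v))
    z = np.outer(np.ones(n_u), np.cos(v))
    return x, y, z

def quad_count(n_u, n_v):
    """Number of quadrilateral patches in an n_u-by-n_v grid."""
    return (n_u - 1) * (n_v - 1)

fine = quad_count(100, 100)   # grid used in the question
coarse = quad_count(12, 7)    # grid suggested above
print(fine, coarse)           # 9801 66
print(200 * fine, 200 * coarse)  # totals for 200 spheres
```

The coarse grid can be passed to the loop from the question unchanged, e.g. `x, y, z = unit_sphere(12, 7)`.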

Another option always worth considering is not to use matplotlib for 3D plotting, since it was not really designed for it, and to use Mayavi instead. In Mayavi, the above would look like:

from mayavi import mlab
import numpy as np

[phi,theta] = np.mgrid[0:2*np.pi:12j,0:np.pi:12j]
x = np.cos(phi)*np.sin(theta)
y = np.sin(phi)*np.sin(theta)
z = np.cos(theta)

def plot_sphere(p):
    r,a,b,c = p
    return mlab.mesh(r*x+a, r*y+b, r*z+c)  


for k in range(200):
    c = np.random.rand(4)
    c[0] /= 10.
    plot_sphere(c)

mlab.show()

While the calculation takes a similar time, interactively zooming or panning is much faster in Mayavi.

Furthermore, Mayavi actually provides something like a "sphere artist", which is called points3d:

from mayavi import mlab
import numpy as np

c = np.random.rand(200,3)
r = np.random.rand(200)/10.

mlab.points3d(c[:,0],c[:,1],c[:,2],r)

mlab.show()
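
One detail worth knowing about points3d: by default the glyph size is derived automatically from the spacing between points, so the scalar argument only sets relative sizes. The sketch below passes `scale_factor=1` so that, to my understanding, each scalar is used directly as the sphere's diameter in data units; `resolution` controls how finely each glyph is tessellated. The helper names `random_spheres` and `show_spheres` are hypothetical, and the Mayavi import is kept inside the function so the data preparation can run without Mayavi installed.

```python
import numpy as np

def random_spheres(n, max_diameter=0.1, seed=0):
    """Return n random centers in the unit cube and n random diameters."""
    rng = np.random.default_rng(seed)
    centers = rng.random((n, 3))
    diameters = rng.random(n) * max_diameter
    return centers, diameters

def show_spheres(centers, diameters, resolution=8):
    """Draw all spheres with a single points3d call (requires Mayavi)."""
    from mayavi import mlab  # imported lazily; only needed for display
    mlab.points3d(
        centers[:, 0], centers[:, 1], centers[:, 2], diameters,
        scale_factor=1,         # assumption: scalar is then the glyph diameter
        resolution=resolution,  # lower values give coarser but cheaper spheres
    )
    mlab.show()
```

Calling `show_spheres(*random_spheres(200))` should open a window with 200 spheres, drawn as one glyph object rather than 200 separate meshes.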