How to produce a MATLAB plot (interpolation) in Matplotlib

Posted 2020-04-14 12:12

Question:

I am trying to follow a MATLAB example of meshgrid + interpolation. The example code is found HERE. On that site, I am going through the following example: Example – Displaying Nonuniform Data on a Surface.

Now I would like to produce a similar plot in Python (NumPy + Matplotlib). This is the plot that MATLAB produces:

I am having trouble doing this in Python. Here is my code, run under Python 2.7:

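import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
from scipy.interpolate import griddata  # matplotlib.mlab.griddata was removed in Matplotlib 3.1

# 200 scattered samples of sin(r)/r on the square [-8, 8) x [-8, 8)
x = np.random.rand(200)*16 - 8
y = np.random.rand(200)*16 - 8
r = np.sqrt(x**2 + y**2)
z = np.sin(r)/r

# Regular grid covering the data: 100 points in x, 200 in y
xi = np.linspace(x.min(), x.max(), 100)
yi = np.linspace(y.min(), y.max(), 200)

X, Y = np.meshgrid(xi, yi)

# Linear interpolation of the scattered values onto the grid
Z = griddata((x, y), z, (X, Y), method='linear')

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')  # fig.gca(projection=...) is not supported in recent Matplotlib
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
plt.show()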

Here is the result of my attempt with Matplotlib and NumPy:

Could someone please help me recreate the MATLAB plot in matplotlib, as either a mesh or a surface plot?
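For reference, a quick check of what np.meshgrid builds from axes of the sizes used in the question (100 x-values, 200 y-values; the -8 to 8 endpoints are chosen here only for illustration). With the default indexing the grids come out as 200 rows by 100 columns, the same orientation MATLAB's meshgrid uses, and that is also the shape of the interpolated Z, so rstride steps along y and cstride along x.

```python
import numpy as np

# Same axis lengths as in the question: 100 points in x, 200 in y
xi = np.linspace(-8, 8, 100)
yi = np.linspace(-8, 8, 200)

# With the default indexing='xy', rows follow y and columns follow x,
# so both arrays have shape (len(yi), len(xi)), as with MATLAB's meshgrid
X, Y = np.meshgrid(xi, yi)
print(X.shape, Y.shape)  # (200, 100) (200, 100)

# Each row of X repeats xi; each column of Y repeats yi
print(np.allclose(X[0], xi), np.allclose(Y[:, 0], yi))  # True True
```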

Answer 1:

So it seems that the major difference in appearance has to do with the number of lines MATLAB draws by default; in Matplotlib you can get closer to that look by increasing rstride and cstride (the step, in rows and columns, between the grid lines that are drawn). For the colors, it is probably best in this case to set the color limits vmin and vmax yourself so that the colormap is scaled properly. When they are set automatically they are taken from the min and max of Z, but Z contains NaN wherever the grid extends outside the convex hull of the scattered points, so both limits come out as NaN. Use np.nanmin and np.nanmax instead, which ignore NaN.
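A small sketch of why the automatic limits fail. It assumes SciPy's scipy.interpolate.griddata (not used in the answer's original code), which with method='linear' gives query points outside the convex hull of the samples the default fill_value, NaN. The five sample points are made up for illustration.

```python
import numpy as np
from scipy.interpolate import griddata

# Hypothetical samples of z = x + y at the corners and centre of the unit square
pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])
vals = pts[:, 0] + pts[:, 1]

# Three query points on the line y = 0.5: two inside the square, one outside it
xq = np.array([0.25, 0.75, 1.5])
yq = np.full_like(xq, 0.5)
Z = griddata(pts, vals, (xq, yq), method='linear')
# Inside the hull, linear interpolation reproduces z = x + y (about 0.75 and 1.25);
# the third point lies outside the convex hull, so it gets fill_value=NaN
print(Z)

# Plain min/max propagate NaN, which is what breaks automatic color limits
print(np.min(Z), np.max(Z))  # nan nan
# nanmin/nanmax skip NaN and give a usable range for vmin/vmax
print(np.nanmin(Z), np.nanmax(Z))
```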

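import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
from scipy.interpolate import griddata  # replaces matplotlib.mlab.griddata, removed in Matplotlib 3.1

x = np.random.rand(200)*16 - 8
y = np.random.rand(200)*16 - 8
r = np.sqrt(x**2 + y**2)
z = np.sin(r)/r

xi = np.linspace(x.min(), x.max(), 100)
yi = np.linspace(y.min(), y.max(), 200)

X, Y = np.meshgrid(xi, yi)

# Grid points outside the convex hull of (x, y) are filled with NaN
Z = griddata((x, y), z, (X, Y), method='linear')

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Larger strides draw fewer grid lines; nanmin/nanmax give finite color limits despite the NaNs
surf = ax.plot_surface(X, Y, Z, rstride=5, cstride=5, cmap=cm.jet,
                       vmin=np.nanmin(Z), vmax=np.nanmax(Z), shade=False)
# Overlay the original scattered samples
scat = ax.scatter(x, y, z)
plt.show()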
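The question also asked about a mesh plot. As a sketch that is not part of the original answer, Matplotlib's closest analogue to MATLAB's mesh is Axes3D.plot_wireframe, which draws only the grid lines. This assumes Matplotlib 3.2 or later and SciPy's griddata in place of matplotlib.mlab.griddata; the 40 by 40 grid is an arbitrary choice to keep the number of lines low, and the fixed random seed only makes the sketch reproducible.

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# Same kind of scattered data as in the question (seeded for reproducibility)
rng = np.random.default_rng(0)
x = rng.random(200)*16 - 8
y = rng.random(200)*16 - 8
r = np.sqrt(x**2 + y**2)
z = np.sin(r)/r

# Coarser square grid (40 by 40 is an arbitrary choice) so fewer mesh lines are drawn
xi = np.linspace(x.min(), x.max(), 40)
yi = np.linspace(y.min(), y.max(), 40)
X, Y = np.meshgrid(xi, yi)
Z = griddata((x, y), z, (X, Y), method='linear')

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# plot_wireframe draws only the grid lines, similar in spirit to MATLAB's mesh
wire = ax.plot_wireframe(X, Y, Z, rstride=1, cstride=1, color='k', linewidth=0.5)
# Overlay the original samples as markers
pts = ax.scatter(x, y, z, color='r', s=10)
fig.canvas.draw()  # render once; with an interactive backend use plt.show() instead
```

Lowering the grid resolution, rather than only raising the strides, keeps the interpolation and the drawn mesh at the same density.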

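Unfortunately, in Matplotlib I get some annoying overlapping/'clipping' problems, where Axes3D does not always determine correctly the order in which objects should be drawn. This is a limitation of mplot3d: it projects everything to 2D and sorts whole artists (here, the surface and the scatter) by depth instead of using a true depth buffer, so one artist can end up drawn entirely in front of the other.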
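One possible workaround, sketched under the assumption of Matplotlib 3.5 or later (where Axes3D gained the computed_zorder option): switch off the automatic depth sort and fix the drawing order with zorder. The trade-off is that the points are then always drawn on top, even where the surface should hide them. The surface and point coordinates below are made up for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend for the sketch
import matplotlib.pyplot as plt

# 40 points per axis, so 0 is not a grid node and sin(R)/R stays finite
X, Y = np.meshgrid(np.linspace(-8, 8, 40), np.linspace(-8, 8, 40))
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)/R

# A few illustrative points lying on the same surface
px = np.array([-4.0, 0.5, 3.0])
py = np.array([2.0, -1.0, 3.5])
pr = np.sqrt(px**2 + py**2)
pz = np.sin(pr)/pr

fig = plt.figure()
# computed_zorder=False disables mplot3d's per-artist depth sorting,
# so artists are drawn in ascending zorder instead
ax = fig.add_subplot(projection='3d', computed_zorder=False)
surf = ax.plot_surface(X, Y, Z, cmap='jet', zorder=1)
pts = ax.scatter(px, py, pz, color='k', zorder=2)  # drawn after (on top of) the surface
fig.canvas.draw()
```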