Loop over a 2D subplot array as if it were 1-D

Posted 2020-07-06 06:30

Question:

I'm plotting a lot of data using subplots. My code works, but I'm wondering if there is a more convenient way to do this.

Below is the sample code:

import numpy as np    
import math 
import matplotlib.pyplot as plt

quantities=["sam_mvir","mvir","rvir","rs","vrms","vmax"
,"jx","jy","jz","spin","m200b","m200c","m500c","m2500c"
,"xoff","voff","btoc","ctoa","ax","ay","az"]

# len(quantities) = 21, just to make the second loop expression 
# shorter in this post.

ncol = 5
nrow = math.ceil(21 / ncol)

fig, axes = plt.subplots(nrows = nrow, ncols=ncol, figsize=(8,6))

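# `tree` is the asker's data (not shown in the post), indexed by quantity name.
for i in range(nrow):
    # panels to fill in row i: 5 for a full row, 21 % 5 for the last, partial row
    for j in range(((21-i*5)>5)*5 + ((21-i*5)<5)*(21%5)):
        axes[i, j].plot(tree[quantities[i*ncol + j]])
        axes[i, j].set_title(quantities[i*ncol + j])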

This code loops over a 2D array of subplots and stops at the 21st plot, leaving 4 panels empty. My question is: is there a built-in way to do this? For example, create the 2D subplot array, "flatten" it into 1D, and then loop over that 1D array from 0 to 20.

The expression in the second range() is very ugly, and I don't think I'm going to keep this code. The trivial way would be to count the plots and break once the count reaches 21, but I wonder if there is a better (or fancier) way.
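As an aside, the per-row range() arithmetic can be avoided by looping over the quantities rather than over the grid, and converting each flat index k into a (row, column) pair with divmod. A minimal sketch, assuming a hypothetical `tree` dict filled with random numbers as a stand-in for the asker's data (which the post doesn't show); the headless "Agg" backend is only there so the snippet runs without a display:

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

quantities = ["sam_mvir", "mvir", "rvir", "rs", "vrms", "vmax",
              "jx", "jy", "jz", "spin", "m200b", "m200c", "m500c", "m2500c",
              "xoff", "voff", "btoc", "ctoa", "ax", "ay", "az"]

# Hypothetical stand-in for the asker's `tree`: one random series per quantity.
rng = np.random.default_rng(0)
tree = {q: rng.random(50) for q in quantities}

ncol = 5
nrow = math.ceil(len(quantities) / ncol)
fig, axes = plt.subplots(nrows=nrow, ncols=ncol, figsize=(8, 6))

# Loop over the data, not the grid: divmod turns the flat index k into
# (row, column), so the loop naturally stops after the 21st quantity.
for k, name in enumerate(quantities):
    i, j = divmod(k, ncol)
    axes[i, j].plot(tree[name])
    axes[i, j].set_title(name)
```

The 4 unused panels are simply left untouched, exactly as in the original code.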

Answer 1:

Rather than creating your subplots in advance with plt.subplots, just create them as you go with plt.subplot(nrows, ncols, index), where index counts from 1, row by row. The small example below uses the Figure-method equivalent, fig.add_subplot(nrows, ncols, index); it sets up a 3x3 grid but only creates and plots the first 6 axes.

import numpy as np
import matplotlib.pyplot as plt

nrows, ncols = 3, 3

x = np.linspace(0,10,100)

fig = plt.figure()    
for i in range(1,7):
    ax = fig.add_subplot(nrows, ncols, i)
    ax.plot(x, x**i)

plt.show()

Of course, you could also create the final three axes with fig.add_subplot(nrows, ncols, i) and simply not plot anything in them, if that's what you want:

import numpy as np
import matplotlib.pyplot as plt

nrows, ncols = 3, 3

x = np.linspace(0,10,100)

fig = plt.figure()    
for i in range(1,10):
    ax = fig.add_subplot(nrows, ncols, i)
    if i < 7:
        ax.plot(x, x**i)

plt.show()

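You may also want to look at GridSpec (matplotlib.gridspec), which gives finer control over the grid layout, such as panels that span several cells.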
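For reference, a minimal GridSpec sketch, assuming matplotlib 3.0 or later (where Figure.add_gridspec is available). It puts the same six curves x**1 to x**6 from the examples above into the top two rows and turns the bottom row into one wide panel, a layout that plain add_subplot(nrows, ncols, i) cannot express. What the wide panel shows (all six curves overlaid) is purely illustrative:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 3, 3
x = np.linspace(0, 10, 100)

fig = plt.figure()
gs = fig.add_gridspec(nrows, ncols)  # 3x3 grid; axes exist only where we add them

# Six small panels in rows 0 and 1, filled row by row.
for k in range(6):
    row, col = divmod(k, ncols)
    ax = fig.add_subplot(gs[row, col])
    ax.plot(x, x**(k + 1))

# One axes spanning all three columns of the last row.
wide = fig.add_subplot(gs[2, :])
for p in range(1, 7):
    wide.plot(x, x**p)
```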



Answer 2:

plt.subplots returns a NumPy ndarray of Axes objects (whenever the grid has more than one row or column), so you can just flatten or ravel it:

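fig, axes = plt.subplots(nrows=nrow, ncols=ncol, figsize=(8,6))
# flatten() returns a 1-D copy of the 2-D array of axes, in row-major order;
# slicing to len(quantities) keeps the first 21 of the 25 axes (indices 0-20)
for i, ax in enumerate(axes.flatten()[:len(quantities)]):
    ax.plot(tree[quantities[i]])
    ax.set_title(quantities[i])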
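A minimal variation on the same idea, assuming hypothetical names q0 to q20 and random data in place of the question's quantities and `tree`: zip() pairs axes with data and stops at the shorter sequence, so no explicit slice is needed, and the leftover panels are then removed instead of being left as empty frames:

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in data: 21 named series, like the question's `tree`.
rng = np.random.default_rng(1)
names = [f"q{k}" for k in range(21)]
tree = {name: rng.random(50) for name in names}

ncol = 5
nrow = math.ceil(len(names) / ncol)
fig, axes = plt.subplots(nrows=nrow, ncols=ncol, figsize=(8, 6))
flat = axes.ravel()  # flattened 1-D array of the Axes, row-major order

# zip() stops at the shorter sequence, so only the first 21 axes are used.
for ax, name in zip(flat, names):
    ax.plot(tree[name])
    ax.set_title(name)

# Remove the 4 leftover panels from the figure.
for ax in flat[len(names):]:
    ax.remove()
```

If the grid could ever be a single panel, note that plt.subplots then returns a bare Axes rather than an array; passing squeeze=False always returns a 2-D array, so .ravel() keeps working.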