Access the n-th dimension of a NumPy array in Python

Posted 2019-01-25 22:32

I want an easy-to-read way to access parts of a multidimensional NumPy array. For any array, accessing the first dimension is easy (b[index]). Accessing the sixth dimension, on the other hand, is "hard" (especially to read).

b[:,:,:,:,:,index] #the next person to read the code will have to count the :

Is there a better way to do this? In particular, is there a way that works when the axis is not known while writing the program?

Edit: The indexed dimension is not necessarily the last dimension.

5 answers
看我几分像从前
#2 · 2019-01-25 22:35

A middle ground (in readability and speed) between the answers of MSeifert and kazemakase is np.rollaxis:

np.rollaxis(b, axis=5)[index]

Testing the solutions:

import numpy as np

arr = np.random.random((10, 10, 10, 10, 10, 10, 10))

np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr.take(4, axis=5))
np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr[(slice(None), )*5 + (4, )])
np.testing.assert_array_equal(arr[:,:,:,:,:,4], np.rollaxis(arr, 5)[4])

%timeit arr.take(4, axis=5)
# 100 loops, best of 3: 4.44 ms per loop
%timeit arr[(slice(None), )*5 + (4, )]
# 1000000 loops, best of 3: 731 ns per loop
%timeit arr[:, :, :, :, :, 4]
# 1000000 loops, best of 3: 540 ns per loop
%timeit np.rollaxis(arr, 5)[4]
# 100000 loops, best of 3: 3.41 µs per loop
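Because rollaxis takes the axis as an argument, this approach also works when the axis is only known at runtime. A minimal sketch checking that, and that the result is a view sharing memory with the original array (checked with np.shares_memory), assuming a 7-D array with distinct dimension lengths so the resulting shape is easy to read. The helper name `index_axis_roll` is hypothetical.

```python
import numpy as np

def index_axis_roll(b, index, axis):
    # Hypothetical helper: move `axis` to the front, then index it like the first dimension
    return np.rollaxis(b, axis)[index]

b = np.random.random((2, 3, 4, 5, 6, 7, 8))
axis = 5  # could just as well be computed at runtime
sub = index_axis_roll(b, 4, axis)

print(sub.shape)                                 # (2, 3, 4, 5, 6, 8)
print(np.shares_memory(sub, b))                  # True: a view, no data is copied
print(np.array_equal(sub, b[:, :, :, :, :, 4]))  # True
```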
霸刀☆藐视天下
#3 · 2019-01-25 22:37

If you want a view and want it fast, you can just build the index tuple manually:

arr[(slice(None), )*5 + (your_index, )]
#                   ^---- This is equivalent to 5 colons: `:, :, :, :, :`

This is much faster than np.take and only marginally slower than indexing with literal colons:

import numpy as np

arr = np.random.random((10, 10, 10, 10, 10, 10, 10))

np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr.take(4, axis=5))
np.testing.assert_array_equal(arr[:,:,:,:,:,4], arr[(slice(None), )*5 + (4, )])
%timeit arr.take(4, axis=5)
# 18.6 ms ± 249 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit arr[(slice(None), )*5 + (4, )]
# 2.72 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr[:, :, :, :, :, 4]
# 2.29 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

It may not be as readable, though, so if you need it often, you should probably wrap it in a function with a meaningful name:

def index_axis(arr, index, axis):
    # `axis` full slices (the equivalent of `:, :, ...`) followed by the index; assumes axis >= 0
    return arr[(slice(None), )*axis + (index, )]

np.testing.assert_array_equal(arr[:,:,:,:,:,4], index_axis(arr, 4, axis=5))

%timeit index_axis(arr, 4, axis=5)
# 3.79 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
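The helper above assumes a non-negative axis: with axis=-1, (slice(None), )*-1 is an empty tuple, so it would silently index axis 0 instead of the last one. A minimal sketch of a variant that normalizes negative axes the way NumPy functions accept them; the bounds check and its error message are illustrative additions, not part of the original answer.

```python
import numpy as np

def index_axis(arr, index, axis):
    # Reject axes outside [-ndim, ndim) instead of silently indexing the wrong axis
    if not -arr.ndim <= axis < arr.ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array with {arr.ndim} dimensions")
    axis %= arr.ndim  # e.g. -1 becomes ndim - 1 (the last axis)
    # `axis` full slices followed by the index, so the result is still a view
    return arr[(slice(None),) * axis + (index,)]

arr = np.arange(2 * 3 * 4).reshape(2, 3, 4)
print(index_axis(arr, 1, axis=2).shape)                           # (2, 3)
print(index_axis(arr, 1, axis=-1).shape)                          # (2, 3), same as axis=2
print(np.array_equal(index_axis(arr, 1, axis=-1), arr[:, :, 1]))  # True
```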
ゆ 、 Hurt°
#4 · 2019-01-25 22:39

You can use np.take. For example:

b.take(index, axis=5)
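Since take receives the axis as a parameter, it also covers the case where the axis is only known at runtime. One difference worth knowing, which likely explains the millisecond timings for take in the other answers: take returns a copy, whereas indexing with colons returns a view. A minimal sketch, assuming a 7-D array:

```python
import numpy as np

b = np.random.random((2, 3, 4, 5, 6, 7, 8))
index, axis = 4, 5

taken = b.take(index, axis=axis)   # the axis can be a runtime variable
sliced = b[:, :, :, :, :, index]   # basic indexing with colons

print(np.array_equal(taken, sliced))  # True: same values and shape
print(np.shares_memory(taken, b))     # False: take copies the data
print(np.shares_memory(sliced, b))    # True: basic indexing returns a view
```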
贪生不怕死
#5 · 2019-01-25 22:39

In the spirit of @Jürg Merlin Spaak's rollaxis answer, but much faster and without rollaxis (which NumPy keeps only for backward compatibility, recommending moveaxis instead):

b.swapaxes(0, axis)[index]
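One caveat with swapaxes: it exchanges axis 0 with `axis` rather than just moving `axis` to the front, so for axis >= 2 the remaining dimensions come out in a different order than with b[:, :, :, :, :, index]. When all dimensions have the same length, the shapes still match but the values end up transposed. A minimal sketch, assuming a 7-D array with distinct dimension lengths so the difference is visible, comparing it with np.moveaxis:

```python
import numpy as np

b = np.random.random((2, 3, 4, 5, 6, 7, 8))
index, axis = 4, 5

reference = b[:, :, :, :, :, index]     # what the question asks for
swapped = b.swapaxes(0, axis)[index]    # original axis 0 ends up at position axis - 1
moved = np.moveaxis(b, axis, 0)[index]  # remaining axes keep their original order

print(reference.shape)  # (2, 3, 4, 5, 6, 8)
print(swapped.shape)    # (3, 4, 5, 6, 2, 8)
print(moved.shape)      # (2, 3, 4, 5, 6, 8)

# The swapaxes result matches the reference once the displaced axis is moved back
print(np.array_equal(np.moveaxis(swapped, axis - 1, 0), reference))  # True
print(np.array_equal(moved, reference))                              # True
```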
不美不萌又怎样
#6 · 2019-01-25 22:46

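You can use an Ellipsis (`...`):

sub = b[..., index]  # indexes the last axis; naming the result `slice` would shadow the built-in slice()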
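Note that `...` expands to as many full slices as needed to fill the leading dimensions, so b[..., index] indexes the last axis. For the question's case (the sixth axis of a 7-D array, which is not the last one), the Ellipsis still works if the index is followed by the right number of trailing colons. A minimal sketch, assuming a 7-D array:

```python
import numpy as np

b = np.random.random((2, 3, 4, 5, 6, 7, 8))
index = 4

last = b[..., index]      # Ellipsis fills the leading axes: indexes axis 6 (the last)
sixth = b[..., index, :]  # one trailing colon shifts the index to axis 5

print(last.shape)   # (2, 3, 4, 5, 6, 7)
print(sixth.shape)  # (2, 3, 4, 5, 6, 8)
print(np.array_equal(sixth, b[:, :, :, :, :, index]))  # True
```

This still means counting the trailing colons, so it mainly helps for axes near the end of the array.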