Replace numpy matrix elements with submatrices

Posted 2019-04-13 19:47

Question:

Given that I have a square matrix of indices, such as:

idxs = np.array([[1, 1],
                 [0, 1]])

and an array of square matrices of the same size as each other (not necessarily the same size as idxs):

mats = np.array([[[ 0. ,  0. ],
                  [ 0. ,  0.5]],

                 [[ 1. ,  0.3],
                  [ 1. ,  1. ]]])

I'd like to replace each index in idxs with the corresponding matrix in mats, to obtain:

array([[ 1. ,  0.3,  1. ,  0.3],
       [ 1. ,  1. ,  1. ,  1. ],
       [ 0. ,  0. ,  1. ,  0.3],
       [ 0. ,  0.5,  1. ,  1. ]])

mats[idxs] gives me a nested version of this:

array([[[[ 1. ,  0.3],
         [ 1. ,  1. ]],

        [[ 1. ,  0.3],
         [ 1. ,  1. ]]],


       [[[ 0. ,  0. ],
         [ 0. ,  0.5]],

        [[ 1. ,  0.3],
         [ 1. ,  1. ]]]])

and so I tried using reshape, but 'twas in vain! mats[idxs].reshape(4,4) returns:

array([[ 1. ,  0.3,  1. ,  1. ],
       [ 1. ,  0.3,  1. ,  1. ],
       [ 0. ,  0. ,  0. ,  0.5],
       [ 1. ,  0.3,  1. ,  1. ]])

If it helps, I found that skimage.util.view_as_blocks is the exact inverse of what I need (it can convert my desired result into the nested, mats[idxs] form).

Is there a (hopefully very) fast way to do this? For the application, my mats will still have just a few small matrices, but my idxs will be a square matrix of up to order 2^15, in which case I'll be replacing over a million indices to create a new matrix of order 2^16.
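For a sense of scale, the sizes mentioned above can be worked out with a few lines of arithmetic. This is only a rough sketch: it assumes 2x2 submatrices as in the example and a float64 result, neither of which is fixed by the question.

```python
import numpy as np

idx_order = 2 ** 15                        # order of idxs, taken from the question
block = 2                                  # assumption: 2x2 submatrices, as in the example
itemsize = np.dtype(np.float64).itemsize   # assumption: float64 output (8 bytes per element)

out_order = idx_order * block              # order of the assembled matrix (2**16)
n_indices = idx_order ** 2                 # number of indices to replace
out_bytes = out_order ** 2 * itemsize      # size of the final 2D array in bytes

print(n_indices)                           # 1073741824
print(out_bytes / 2 ** 30)                 # 32.0 (GiB)
```

Under these assumptions the result alone takes 32 GiB, so memory is likely to matter as much as speed at the upper end of that range.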

Thanks so much for your help!

Answer 1:

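We are indexing into the first axis of the input array with those indices, so mats[idxs] has shape idxs.shape + mats.shape[1:], i.e. its axes are (block row, block column, row within block, column within block). To get the 2D output, the two row axes must sit next to each other, and likewise the two column axes, so we swap the middle two axes and reshape afterwards. Thus, an approach would be with np.transpose/np.swapaxes and np.reshape, like so -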

mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
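As a check, here is a minimal sketch that applies the expression above to the data from the question, one step at a time. The intermediate names (nested, swapped, result, expected) are only for illustration.

```python
import numpy as np

idxs = np.array([[1, 1],
                 [0, 1]])
mats = np.array([[[0.0, 0.0],
                  [0.0, 0.5]],
                 [[1.0, 0.3],
                  [1.0, 1.0]]])

# Axes: (block row, block col, row in block, col in block)
nested = mats[idxs]
# Axes: (block row, row in block, block col, col in block)
swapped = nested.swapaxes(1, 2)
# Merge the two row axes and the two column axes
result = swapped.reshape(-1, mats.shape[-1] * idxs.shape[-1])

# The desired output stated in the question
expected = np.array([[1.0, 0.3, 1.0, 0.3],
                     [1.0, 1.0, 1.0, 1.0],
                     [0.0, 0.0, 1.0, 0.3],
                     [0.0, 0.5, 1.0, 1.0]])
print(np.array_equal(result, expected))  # True
```

A plain reshape(4, 4) without the swap merges (block col, row in block) into rows instead, which is why it produced the wrong layout in the question.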

Sample run -

In [83]: mats
Out[83]: 
array([[[1, 1],
        [7, 1]],

       [[6, 6],
        [5, 8]],

       [[7, 1],
        [6, 0]],

       [[2, 7],
        [0, 4]]])

In [84]: idxs
Out[84]: 
array([[2, 3],
       [0, 3],
       [1, 2]])

In [85]: mats[idxs].swapaxes(1,2).reshape(-1,mats.shape[-1]*idxs.shape[-1])
Out[85]: 
array([[7, 1, 2, 7],
       [6, 0, 0, 4],
       [1, 1, 2, 7],
       [7, 1, 0, 4],
       [6, 6, 7, 1],
       [5, 8, 6, 0]])
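The question mentions that skimage.util.view_as_blocks is the inverse of this operation. The same round trip can be sketched in plain NumPy by undoing the reshape and the swap, here on the sample data above. The names R, C, m, n, flat and blocks are chosen for illustration only.

```python
import numpy as np

mats = np.array([[[1, 1], [7, 1]],
                 [[6, 6], [5, 8]],
                 [[7, 1], [6, 0]],
                 [[2, 7], [0, 4]]])
idxs = np.array([[2, 3],
                 [0, 3],
                 [1, 2]])

R, C = idxs.shape        # grid of blocks: 3 x 2
m, n = mats.shape[1:]    # each block: 2 x 2

# Forward: nested blocks -> one flat 2D matrix
flat = mats[idxs].swapaxes(1, 2).reshape(R * m, C * n)

# Inverse: split rows and columns back into blocks, then swap the middle axes
blocks = flat.reshape(R, m, C, n).swapaxes(1, 2)

print(flat.shape)                          # (6, 4)
print(np.array_equal(blocks, mats[idxs]))  # True
```

Note that idxs does not have to be square here; only the block shape and the grid shape enter the reshape.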

Performance boost with np.take for repeated indices

With repeated indices, we are better off performance-wise using np.take along axis=0. Let's list out both approaches and time them with idxs having many repeated indices.

Function definitions -

def simply_indexing_based(mats, idxs):
    ncols = mats.shape[-1]*idxs.shape[-1]
    return mats[idxs].swapaxes(1,2).reshape(-1,ncols)

def take_based(mats, idxs):
    ncols = mats.shape[-1]*idxs.shape[-1]
    return np.take(mats,idxs,axis=0).swapaxes(1,2).reshape(-1,ncols)

Runtime test -

In [156]: mats = np.random.randint(0,9,(10,2,2))

In [157]: idxs = np.random.randint(0,10,(1000,1000))
                 # This ensures many repeated indices

In [158]: out1 = simply_indexing_based(mats, idxs)

In [159]: out2 = take_based(mats, idxs)

In [160]: np.allclose(out1, out2)
Out[160]: True

In [161]: %timeit simply_indexing_based(mats, idxs)
10 loops, best of 3: 41.2 ms per loop

In [162]: %timeit take_based(mats, idxs)
10 loops, best of 3: 27.3 ms per loop

Thus, we are seeing an overall improvement of 1.5x+.

Just to get a sense of the improvement with np.take, let's time the indexing part alone -

In [168]: %timeit mats[idxs]
10 loops, best of 3: 22.8 ms per loop

In [169]: %timeit np.take(mats,idxs,axis=0)
100 loops, best of 3: 8.88 ms per loop

For those data sizes, it's 2.5x+. Not bad!
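The timings above come from one machine and one NumPy version, so the ratio may differ elsewhere. To repeat the comparison outside IPython, a self-contained sketch using the standard timeit module could look like this. The seed, the repeat counts and the timings dictionary are arbitrary choices for illustration.

```python
import timeit
import numpy as np

def simply_indexing_based(mats, idxs):
    ncols = mats.shape[-1] * idxs.shape[-1]
    return mats[idxs].swapaxes(1, 2).reshape(-1, ncols)

def take_based(mats, idxs):
    ncols = mats.shape[-1] * idxs.shape[-1]
    return np.take(mats, idxs, axis=0).swapaxes(1, 2).reshape(-1, ncols)

rng = np.random.default_rng(0)                 # fixed seed, arbitrary
mats = rng.integers(0, 9, size=(10, 2, 2))     # 10 small matrices
idxs = rng.integers(0, 10, size=(1000, 1000))  # many repeated indices

# Both approaches must agree before their speed is compared
out1 = simply_indexing_based(mats, idxs)
out2 = take_based(mats, idxs)
assert np.array_equal(out1, out2)

timings = {}
for func in (simply_indexing_based, take_based):
    # Best of 3 runs, 5 calls per run, reported per call
    best = min(timeit.repeat(lambda: func(mats, idxs), number=5, repeat=3)) / 5
    timings[func.__name__] = best
    print(f"{func.__name__}: {best * 1e3:.2f} ms per call")
```

No expected numbers are given here on purpose, since they depend on the hardware and the NumPy build.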