Let's say we have a 2-D array like this:
>>> a
array([[1, 1, 2],
       [0, 2, 2],
       [2, 2, 0],
       [0, 2, 0]])
For each row, I want to replace each element with the maximum of the two other elements in the same row.
I've found how to do it for each column separately, using numpy.amax and an identity array, like this:
>>> np.amax(a*(1-np.eye(3)[0]), axis=1)
array([ 2., 2., 2., 2.])
>>> np.amax(a*(1-np.eye(3)[1]), axis=1)
array([ 2., 2., 2., 0.])
>>> np.amax(a*(1-np.eye(3)[2]), axis=1)
array([ 1., 2., 2., 2.])
But I would like to know if there is a way to avoid a for loop and get the result directly, which in this case should look like this:
>>> numpy_magic(a)
array([[2, 2, 1],
       [2, 2, 2],
       [2, 2, 2],
       [2, 0, 2]])
Edit: after a few hours playing in the console, I've finally come up with the solution I was looking for. Be ready for a mind-blowing one-liner:
np.amax(a[[range(a.shape[0])]*a.shape[1],:][(np.eye(a.shape[1]) == 0)[:,[list(range(a.shape[1]))*a.shape[0]]].reshape(a.shape[1],a.shape[0],a.shape[1])].reshape((a.shape[1],a.shape[0],a.shape[1]-1)),axis=2).transpose()
array([[2, 2, 1],
       [2, 2, 2],
       [2, 2, 2],
       [2, 0, 2]])
Edit2: Paul has suggested a much more readable and faster alternative which is:
np.max(a[:, np.where(~np.identity(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)
After timing these 3 alternatives, both of Paul's solutions are 4 times faster in every context I tried (I benchmarked with 2, 3 and 4 columns and 200 rows). Congratulations on these amazing pieces of code!
Last Edit (sorry): after replacing np.identity with np.eye, which is faster, we now have the fastest and most concise solution:
np.max(a[:, np.where(~np.eye(a.shape[1], dtype=bool))[1].reshape(a.shape[1], -1)], axis=-1)
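For anyone trying to read the one-liner above, here is a rough step-by-step breakdown of what I believe it does, using the array `a` from the question. The intermediate names (`off_diag`, `others`, `candidates`) are only for illustration and are not part of the original code.

```python
import numpy as np

a = np.array([[1, 1, 2],
              [0, 2, 2],
              [2, 2, 0],
              [0, 2, 0]])

n_cols = a.shape[1]

# Boolean mask that is True everywhere except on the diagonal.
off_diag = ~np.eye(n_cols, dtype=bool)

# np.where lists the True positions in row-major order; keeping only the
# column indices and reshaping gives, for each column i, the indices of
# all the *other* columns.
others = np.where(off_diag)[1].reshape(n_cols, -1)

# Fancy indexing: candidates[r, i, :] holds the elements of row r
# with column i left out. Shape is (n_rows, n_cols, n_cols - 1).
candidates = a[:, others]

# Reduce over the last axis to get the maximum of the other elements.
result = np.max(candidates, axis=-1)
print(result)
```

Since the reduction only happens in the last step, swapping `np.max` for another reduction such as `np.sum` should work the same way.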
Similar to @Ethan's answer but with np.delete(), np.max(), and np.dstack():
delete() "filters" out each column successively; max() finds the row-wise maximum of the remaining two columns; dstack() stacks the resulting 1d arrays.
If you have more than 3 columns, note that this will find the maximum of "all other" columns rather than the "2-greatest" columns per row. For example:
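A minimal sketch of the delete/max/dstack approach described above, assuming the array `a` from the question. The function name `max_of_others_delete` and the 4-column array `b` are made up for illustration.

```python
import numpy as np

a = np.array([[1, 1, 2],
              [0, 2, 2],
              [2, 2, 0],
              [0, 2, 0]])

def max_of_others_delete(arr):
    # For each column i: drop column i, then take the row-wise max of the rest.
    cols = [np.max(np.delete(arr, i, axis=1), axis=1)
            for i in range(arr.shape[1])]
    # dstack turns the list of 1-D arrays into shape (1, n_rows, n_cols);
    # indexing with [0] drops the leading axis.
    return np.dstack(cols)[0]

result = max_of_others_delete(a)
print(result)

# With 4 columns, each entry is the max of *all* other columns in its row.
b = np.array([[1, 5, 3, 4]])
result_b = max_of_others_delete(b)
print(result_b)  # [[5 4 5 5]]
```

Note that this still loops over the columns in Python, so it is probably only reasonable when the number of columns is small.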
List comprehension solution.
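One way such a list comprehension could look, as a sketch only and assuming the array `a` from the question (the name `result_lc` is just for illustration):

```python
import numpy as np

a = np.array([[1, 1, 2],
              [0, 2, 2],
              [2, 2, 0],
              [0, 2, 0]])

# For every row, and for every position i in that row,
# take the max of the row with element i removed.
result_lc = np.array([[np.delete(row, i).max() for i in range(len(row))]
                      for row in a])
print(result_lc)
```

This is easy to read but runs a Python-level loop over every element, so it will likely be slower than the vectorized versions on large arrays.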
Here are two solutions, one that is specifically designed for max and a more general one that works for other operations as well.
Using the fact that all except possibly one of the maxima in each row are the maximum of the entire row, we can use argpartition to cheaply find the indices of the largest two elements. Then in the position of the largest we put the value of the second largest, and everywhere else the largest value. This also works for more than 3 columns.
This solution depends on specific properties of max.
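A sketch of the argpartition idea described above, assuming the array `a` from the question and at least 2 columns. The function name `max_of_others_argpartition` is hypothetical and this is not necessarily how the original code was written.

```python
import numpy as np

a = np.array([[1, 1, 2],
              [0, 2, 2],
              [2, 2, 0],
              [0, 2, 0]])

def max_of_others_argpartition(arr):
    n_rows, n_cols = arr.shape
    rows = np.arange(n_rows)
    # Partition around the second-largest position: afterwards the last two
    # index columns point at the second-largest and the largest element.
    top2 = np.argpartition(arr, n_cols - 2, axis=1)[:, -2:]
    second_idx, largest_idx = top2[:, 0], top2[:, 1]
    largest = arr[rows, largest_idx]
    second = arr[rows, second_idx]
    # Fill every position with the row maximum ...
    out = np.repeat(largest[:, None], n_cols, axis=1)
    # ... except where the maximum itself sits: its "others" max is the runner-up.
    out[rows, largest_idx] = second
    return out

result = max_of_others_argpartition(a)
print(result)
```

When a row has tied maxima, the runner-up equals the maximum, so every position in that row correctly ends up with the same value.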
A more general solution, which for example also works for sum instead of max, would be the following. Glue two copies of a together (side-by-side, not on top of each other), so the rows look something like a0 a1 a2 a3 a0 a1 a2 a3. For an index x we can get all but ax by slicing [x+1:x+4]. To do this vectorized we use stride_tricks:
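A sketch of the doubled-array trick, assuming the array `a` from the question. The function name `reduce_others` and its `op` parameter are made up for illustration, and `as_strided` is used here as one possible way to build the windows.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

a = np.array([[1, 1, 2],
              [0, 2, 2],
              [2, 2, 0],
              [0, 2, 0]])

def reduce_others(arr, op=np.max):
    n_rows, n_cols = arr.shape
    # Two copies side by side: rows look like a0 a1 a2 a0 a1 a2.
    doubled = np.concatenate([arr, arr], axis=1)
    s_row, s_col = doubled.strides
    # windows[r, x, :] is a read-only view of doubled[r, x+1 : x+n_cols],
    # i.e. every element of row r except arr[r, x].
    windows = as_strided(doubled[:, 1:],
                         shape=(n_rows, n_cols, n_cols - 1),
                         strides=(s_row, s_col, s_col),
                         writeable=False)
    # Any reduction that accepts an axis argument can be applied here.
    return op(windows, axis=-1)

result_max = reduce_others(a, np.max)
result_sum = reduce_others(a, np.sum)
print(result_max)
print(result_sum)
```

The windows are views into the doubled array, so no extra (n_rows, n_cols, n_cols - 1) copy is made before the reduction. `as_strided` does no bounds checking, so the shape and strides have to be right.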