I'm struggling to understand exactly how `einsum` works. I've looked at the documentation and a few examples, but it doesn't seem to stick.

Here's an example we went over in class:

```python
C = np.einsum("ij,jk->ki", A, B)
```

for two arrays `A` and `B`. I think this would take `A^T * B`, but I'm not sure (it's taking the transpose of one of them, right?). Can anyone walk me through exactly what's happening here (and in general when using `einsum`)?
I found *NumPy: The tricks of the trade (Part II)* instructive.

- Notice there are three axes, `i`, `j`, `k`, and that `j` is repeated (on the left-hand side).
- `i,j` represent rows and columns for `a`; `j,k` for `b`.
- In order to calculate the product and align the `j` axis we need to add an axis to `a` (`b` will be broadcast along the first axis).
- `j` is absent from the right-hand side, so we sum over `j`, which is the second axis of the 3x3x3 array of products.
- Finally, the indices are (alphabetically) reversed on the right-hand side, so we transpose.
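Putting those steps together as code (a sketch; `a` and `b` here are assumed to be 3x3 arrays, matching the 3x3x3 intermediate above):

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
b = np.arange(9).reshape(3, 3)

# add an axis to a so the j axes align; the products form a 3x3x3 array (i, j, k)
products = a[:, :, np.newaxis] * b[np.newaxis, :, :]

# j is absent from the output, so sum over the second axis (j) ...
summed = products.sum(axis=1)

# ... and the output indices are reversed ('ki'), so transpose
result = summed.T

assert np.array_equal(result, np.einsum('ij,jk->ki', a, b))
```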
(Note: this answer is based on a short blog post about `einsum` I wrote a while ago.)

**What does `einsum` do?**

Imagine that we have two multi-dimensional arrays, `A` and `B`. Now let's suppose we want to...

- **multiply** `A` with `B` in a particular way to create a new array of products; and then maybe
- **sum** this new array along particular axes; and then maybe
- **transpose** the axes of the new array in a particular order.

There's a good chance that `einsum` will help us do this faster and more memory-efficiently than combinations of the NumPy functions like `multiply`, `sum` and `transpose` will allow.

**How does `einsum` work?**

Here's a simple (but not completely trivial) example. Take the following two arrays (the values are illustrative):
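```python
import numpy as np

# illustrative values (assumed): a 1D array A and a 2D array B
# whose first axis matches A's length
A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])
```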
We will multiply `A` and `B` element-wise and then sum along the rows of the new array. In "normal" NumPy we'd write:
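```python
>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])
```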
So here, the indexing operation on `A` lines up the first axes of the two arrays so that the multiplication can be broadcast. The rows of the array of products are then summed to return the answer.

Now if we wanted to use `einsum` instead, we could write:
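```python
>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])
```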
The signature string `'i,ij->i'` is the key here and needs a little bit of explaining. You can think of it in two halves. On the left-hand side (left of the `->`) we've labelled the two input arrays. To the right of `->`, we've labelled the array we want to end up with.

Here is what happens next:

- `A` has one axis; we've labelled it `i`. And `B` has two axes; we've labelled axis 0 as `i` and axis 1 as `j`.
- By repeating the label `i` in both input arrays, we are telling `einsum` that these two axes should be multiplied together. In other words, we're multiplying array `A` with each column of array `B`, just like `A[:, np.newaxis] * B` does.
- Notice that `j` does not appear as a label in our desired output; we've just used `i` (we want to end up with a 1D array). By omitting the label, we're telling `einsum` to sum along this axis. In other words, we're summing the rows of the products, just like `.sum(axis=1)` does.

That's basically all you need to know to use `einsum`. It helps to play about a little; if we leave both labels in the output, `'i,ij->ij'`, we get back a 2D array of products (same as `A[:, np.newaxis] * B`). If we specify no output labels, `'i,ij->'`, we get back a single number (same as doing `(A[:, np.newaxis] * B).sum()`).
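Trying those variants with the arrays above:

```python
>>> np.einsum('i,ij->ij', A, B)   # keep both labels: 2D array of products
array([[ 0,  0,  0,  0],
       [ 4,  5,  6,  7],
       [16, 18, 20, 22]])

>>> np.einsum('i,ij->', A, B)     # no output labels: sum everything (0 + 22 + 76)
98
```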
The great thing about `einsum`, however, is that it does not build a temporary array of products first; it just sums the products as it goes. This can lead to big savings in memory use.

**A slightly bigger example**
To explain the dot product, here are two new arrays (again with illustrative values):
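```python
# illustrative values (assumed)
A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])
```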
We will compute the dot product using `np.einsum('ij,jk->ik', A, B)`. The labelling works like this: `A`'s two axes are labelled `i` and `j`, `B`'s are labelled `j` and `k`, and the output array's axes are labelled `i` and `k`:
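```python
>>> np.einsum('ij,jk->ik', A, B)
array([[ 2,  3,  1],
       [ 4,  6,  2],
       [10, 15,  5]])
```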
You can see that label `j` is repeated - this means we're multiplying the rows of `A` with the columns of `B`. Furthermore, the label `j` is not included in the output - we're summing these products. Labels `i` and `k` are kept for the output, so we get back a 2D array.

It might be even clearer to compare this result with the array where the label `j` is *not* summed. Writing `np.einsum('ij,jk->ijk', A, B)` (i.e. keeping label `j`) builds the 3D array of products; summing along axis `j` then gives the expected dot product:
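```python
>>> C3 = np.einsum('ij,jk->ijk', A, B)   # keep label j
>>> C3.shape
(3, 3, 3)
>>> C3.sum(axis=1)                       # summing axis j recovers the dot product
array([[ 2,  3,  1],
       [ 4,  6,  2],
       [10, 15,  5]])
```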
**Some exercises**

To get more of a feel for `einsum`, it can be useful to implement familiar NumPy array operations using the subscript notation. Anything that involves combinations of multiplying and summing axes can be written using `einsum`.

Let `A` and `B` be two 1D arrays with the same length. For example, `A = np.arange(10)` and `B = np.arange(5, 15)`.

- The sum of `A` can be written: `np.einsum('i->', A)`
- Element-wise multiplication, `A * B`, can be written: `np.einsum('i,i->i', A, B)`
- The inner product or dot product, `np.inner(A, B)` or `np.dot(A, B)`, can be written: `np.einsum('i,i->', A, B)`
- The outer product, `np.outer(A, B)`, can be written: `np.einsum('i,j->ij', A, B)`
For 2D arrays, `C` and `D`, provided that the axes are compatible lengths (both the same length or one of them has length 1), here are a few examples:

- The trace of `C` (sum of the main diagonal), `np.trace(C)`, can be written: `np.einsum('ii->', C)`
- Element-wise multiplication of `C` and the transpose of `D`, `C * D.T`, can be written: `np.einsum('ij,ji->ij', C, D)`
- Multiplying each element of `C` by the array `D` (to make a 4D array), `C[:, :, None, None] * D`, can be written: `np.einsum('ij,kl->ijkl', C, D)`

Let's make two arrays with different, but compatible, dimensions to highlight their interplay:
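A sketch of those arrays and the calculation in question (the values are assumed, but they are consistent with the sums worked out below):

```python
>>> import numpy as np
>>> A = np.arange(6).reshape(2, 3)    # shape (2, 3)
>>> B = np.arange(12).reshape(3, 4)   # shape (3, 4)
>>> np.einsum('ij,jk->ki', A, B)
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])
```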
Your calculation takes a 'dot' (sum of products) of a (2,3) with a (3,4) to produce a (4,2) array. `i` is the 1st dim of `A` and the last of `C`; `k` is the last of `B` and the 1st of `C`. `j` is 'consumed' by the summation.

This is the same as `np.dot(A,B).T` - it's the final output that's transposed.

To see more of what happens to `j`, change the `C` subscripts to `ijk`:
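```python
>>> np.einsum('ij,jk->ijk', A, B)
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])
```

This can also be produced with:

```python
>>> A[:, :, None] * B[None, :, :]   # same (2,3,4) array of products as above
```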
That is, add a `k` dimension to the end of `A`, and an `i` to the front of `B`, resulting in a (2,3,4) array of products: `0 + 4 + 16 = 20`, `9 + 28 + 55 = 92`, etc. Sum on `j` and transpose to get the earlier result:
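```python
>>> np.einsum('ij,jk->ijk', A, B).sum(axis=1).T
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])
```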
Grasping the idea of `numpy.einsum()` is very easy if you understand it intuitively. As an example case, let's start with a simple description involving matrix multiplication.

To use `numpy.einsum()`, you have to pass the so-called *subscripts string* as an argument, followed by your *input arrays*.

Let's say you have two 2D arrays, `A` and `B`, and you want to do matrix multiplication. So, you do:
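```python
np.einsum("ij, jk -> ik", A, B)
```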
Here the subscript string `ij` corresponds to array `A` while the subscript string `jk` corresponds to array `B`. Also, the most important thing to note here is that the *number of characters in each subscript string must match the dimensions of the array* (i.e. two chars for 2D arrays, three chars for 3D arrays, and so on). And if you repeat the chars between subscript strings (`j` in our case), then that means you want the *ein*sum to happen along those dimensions. Thus, they will be sum-reduced (i.e. that dimension will be gone).

The subscript string after the `->` will be our resultant array. If you leave it empty, then everything will be summed and a scalar value is returned as the result. Otherwise the resultant array will have dimensions according to the subscript string. In our example, it'll be `ik`. This is intuitive because we know that for matrix multiplication the number of columns in array `A` has to match the number of rows in array `B`, which is what is happening here (i.e. we encode this knowledge by repeating the char `j` in the subscript string).

Here are some more examples illustrating the use of `np.einsum()` in implementing some common tensor or nd-array operations.

**Inputs**
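A sketch of inputs matching the names used below (the specific values and shapes are assumptions):

```python
import numpy as np

vec1 = np.arange(4)                      # a 1D vector
vec2 = np.arange(4, 8)                   # another 1D vector
arr  = np.arange(16).reshape(4, 4)       # a 2D array
arr1 = np.arange(16).reshape(4, 4)       # a 2D array
arr2 = np.arange(16, 32).reshape(4, 4)   # another 2D array
arr3d = np.arange(24).reshape(2, 3, 4)   # a 3D array (examples 11-12)
barr  = np.arange(24).reshape(2, 4, 3)   # another 3D array (example 11)
nd_arr = np.ones((2,) * 8)               # an 8D array (example 14)
```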
1) Matrix multiplication (similar to `np.matmul(arr1, arr2)`): `np.einsum("ij, jk -> ik", arr1, arr2)`

2) Extract elements along the main diagonal (similar to `np.diag(arr)`): `np.einsum("ii -> i", arr)`

3) Hadamard product (i.e. element-wise product of two arrays) (similar to `arr1 * arr2`): `np.einsum("ij, ij -> ij", arr1, arr2)`

4) Element-wise squaring (similar to `np.square(arr)` or `arr ** 2`): `np.einsum("ij, ij -> ij", arr, arr)`

5) Trace (i.e. sum of main-diagonal elements) (similar to `np.trace(arr)`): `np.einsum("ii -> ", arr)`

6) Matrix transpose (similar to `np.transpose(arr)`): `np.einsum("ij -> ji", arr)`

7) Outer product (of vectors) (similar to `np.outer(vec1, vec2)`): `np.einsum("i, j -> ij", vec1, vec2)`

8) Inner product (of vectors) (similar to `np.inner(vec1, vec2)`): `np.einsum("i, i -> ", vec1, vec2)`

9) Sum along axis 0 (similar to `np.sum(arr, axis=0)`): `np.einsum("ij -> j", arr)`

10) Sum along axis 1 (similar to `np.sum(arr, axis=1)`): `np.einsum("ij -> i", arr)`

11) Batch matrix multiplication (for two 3D arrays): `np.einsum("bij, bjk -> bik", arr3d, barr)`

12) Sum along axis 2 (similar to `np.sum(arr3d, axis=2)`): `np.einsum("ijk -> ij", arr3d)`

13) Sum all the elements in an array (similar to `np.sum(arr)`): `np.einsum("ij -> ", arr)`

14) Sum over multiple axes (i.e. marginalization) (similar to `np.sum(nd_arr, axis=(0, 1, 2, 3, 4, 6, 7))`, keeping only axis 5): `np.einsum("ijklmnop -> n", nd_arr)`

15) Double dot product (similar to `np.sum(arr1 * arr2)`, i.e. the sum of the Hadamard product, cf. 3): `np.einsum("ij, ij -> ", arr1, arr2)`

16) 2D and 3D array multiplication:
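A sketch, assuming a (3,3) array `a` and a (3,3,3) array `b`:

```python
a = np.random.rand(3, 3)
b = np.random.rand(3, 3, 3)

# multiply the 2D array into the first axis of the 3D array
result = np.einsum("ij, jkl -> ikl", a, b)   # shape (3, 3, 3)
```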
Such a multiplication could be very useful when solving a linear system of equations (Ax = b) where you want to verify the result.

On the contrary, if one has to use `np.matmul()` for this verification, we have to do a couple of `reshape`s to achieve the same thing, like `np.matmul(a, b.reshape(3, 9)).reshape(3, 3, 3)` for the (3,3) and (3,3,3) arrays above.

Bonus: Read more math here: Einstein-Summation and definitely here: Tensor-Notation