EinSum: Einstein Summation, Making Deep Learning Better
It's not from a mathematician's POV, but from an ML engineer's!
Einstein summation, or index notation, is a compact mathematical notation for combining products and sums in a general way. It simplifies the expression of complex equations involving vectors, matrices and tensors. More specifically, it is a concise way to express summations over repeated indices, which makes multidimensional equations easier to manipulate.
Einstein's notation has multiple rules, but the key ones are:
- Repeating letters between input arrays means that values along those axes will be multiplied together.
- Omitting a letter from the output means that values along that axis will be summed.
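To see each rule on its own, here is a tiny sketch with `torch.einsum` (the function used later in this post); the input values are arbitrary illustrative choices, not from the original.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

# Both rules: 'i' repeats across inputs (multiply) and is absent from the output (sum)
dot = torch.einsum('i,i->', x, y)      # 1*4 + 2*5 + 3*6 = 32

# Rule 1 only: 'i' repeats but stays in the output, so nothing is summed
prod = torch.einsum('i,i->i', x, y)    # tensor([ 4., 10., 18.])

# Rule 2 only: 'i' is dropped from the output, so the rows are summed away
m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
col_sums = torch.einsum('ij->j', m)    # tensor([4., 6.])
```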
SO DOES THIS MAKE CODE BETTER?
Let's implement a basic matrix multiplication in a few different ways to compare the difference this notation makes:
Using three loops: `CPU times: user 727 ms, sys: 155 µs, total: 727 ms Wall time: 736 ms`
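The code for this step isn't reproduced here, so what follows is a minimal sketch of a three-loop version. It assumes random `torch.randn` tensors as stand-ins for the author's `m1` (5, 784) and `m2` (784, 10); the function name `matmul_three_loops` is hypothetical.

```python
import torch

torch.manual_seed(0)

def matmul_three_loops(a, b):
    """Naive matrix multiply: c[i, k] = sum over j of a[i, j] * b[j, k]."""
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):          # rows of a
        for k in range(bc):      # columns of b
            for j in range(ac):  # shared inner axis, accumulated one term at a time
                c[i, k] += a[i, j] * b[j, k]
    return c

# Assumption: random tensors stand in for the author's m1 (5, 784) and m2 (784, 10)
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
res = matmul_three_loops(m1, m2)
print(res.shape)  # torch.Size([5, 10])
```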
Let's take it up a notch and reduce the loops to two by using a dot product:
Using two loops: `CPU times: user 2.89 ms, sys: 0 ns, total: 2.89 ms Wall time: 3.88 ms`
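A sketch of the two-loop version, under the same assumptions (random stand-in inputs, hypothetical function name `matmul_two_loops`): the innermost loop over `j` is replaced by `torch.dot` on a row of `a` and a column of `b`.

```python
import torch

def matmul_two_loops(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for k in range(bc):
            # The former j-loop is now one dot product: row i of a with column k of b
            c[i, k] = torch.dot(a[i, :], b[:, k])
    return c

# Assumption: random stand-ins for the author's m1 and m2
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
res = matmul_two_loops(m1, m2)
```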
Broadcasting might make it faster. Don’t you think?
Using a single loop: `CPU times: user 1.71 ms, sys: 0 ns, total: 1.71 ms Wall time: 1.48 ms`
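A possible single-loop version using broadcasting, again with random stand-in inputs and a hypothetical `matmul_one_loop` name: each row of `a` is reshaped so it broadcasts against all of `b` at once, and the shared `j` axis is then summed.

```python
import torch

def matmul_one_loop(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):
        # a[i] is (784,); unsqueeze(-1) makes it (784, 1), which broadcasts against
        # b (784, 10); summing over dim 0 then collapses the shared j axis
        c[i] = (a[i].unsqueeze(-1) * b).sum(dim=0)
    return c

# Assumption: random stand-ins for the author's m1 and m2
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
res = matmul_one_loop(m1, m2)
```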
Now, for Einstein summation, we will use `torch.einsum`:
Using einsum (no explicit loop): `CPU times: user 240 µs, sys: 0 ns, total: 240 µs Wall time: 259 µs`
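And a sketch of the einsum version, with the same hypothetical naming and random stand-in inputs. The whole computation is the single `torch.einsum` call, so there is no Python-level loop left.

```python
import torch

def matmul_einsum(a, b):
    # 'ij,jk->ik': j appears in both inputs (multiply) and not in the output (sum)
    return torch.einsum('ij,jk->ik', a, b)

# Assumption: random stand-ins for the author's m1 and m2
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
res = matmul_einsum(m1, m2)
```

In Jupyter or IPython, prefixing a call with the `%time` magic prints timings in the same `CPU times ... Wall time` format shown above; exact numbers will vary by machine.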
So, a single run like this isn't a big enough sample to give conclusive results, but you get the gist, don't you?
A BIT OF EXPLANATION:
Here, `i`, `j` and `k` are the axis labels, with sizes i = 5, j = 784 and k = 10 based on the input shapes. In `torch.einsum('ij,jk->ik', a, b)`:
- `ij` represents the shape of `a` (m1): (5, 784)
- `jk` represents the shape of `b` (m2): (784, 10)
- `ik`, after the arrow (->), represents the shape of the output `c` (res): (5, 10)
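To check the label-to-size mapping, here is a minimal sketch with random stand-in inputs (an assumption). The output shape follows the letters after the arrow, so reordering them to `ki` should give the transposed result.

```python
import torch

a = torch.randn(5, 784)   # 'ij' -> i = 5, j = 784
b = torch.randn(784, 10)  # 'jk' -> j = 784, k = 10

c = torch.einsum('ij,jk->ik', a, b)
print(c.shape)    # torch.Size([5, 10])

# Reordering the output letters reorders the output axes: 'ki' is the transpose
c_t = torch.einsum('ij,jk->ki', a, b)
print(c_t.shape)  # torch.Size([10, 5])
```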
That's all there is to the index notation. But, for a better understanding, let's look at those two critical rules in a bit more detail:
- Repeating letters between input arrays means that values along those axes will be multiplied together:
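What happens if I don't omit any letter from the output, i.e. write `ijk` after the arrow?
`torch.einsum('ij,jk->ijk', a, b)`
The result has shape (5, 784, 10). The new middle axis `j` (of size 784) holds the element-wise products `a[i, j] * b[j, k]` along the repeated axis `j`: they are multiplied, but not yet summed.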
- Omitting a letter from the output means that values along that axis will be summed.
So if I call `sum()` over the `j` axis (`dim=1`) of that result, the axis gets eliminated and the result is the same as `torch.einsum('ij,jk->ik', a, b)`.
In other words, the output is molded into the shape given on the right side of the arrow.
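Both rules can be checked side by side in a short sketch, assuming random stand-in inputs of the same shapes: `ijk` keeps the multiplied-but-unsummed products, which match plain broadcasting, and summing over `dim=1` recovers `ij,jk->ik`.

```python
import torch

a = torch.randn(5, 784)
b = torch.randn(784, 10)

# Rule 1 only: keep every letter, so the repeated axis j is multiplied but not summed
prods = torch.einsum('ij,jk->ijk', a, b)
print(prods.shape)  # torch.Size([5, 784, 10])

# The same tensor via explicit broadcasting: (5, 784, 1) * (1, 784, 10)
same = a[:, :, None] * b[None, :, :]

# Rule 2: summing away j (dim=1) gives the ordinary matrix product
summed = prods.sum(dim=1)
direct = torch.einsum('ij,jk->ik', a, b)
```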
TIDBIT:
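If you leave the right side of the arrow empty, the output of einsum is a scalar (a 0-dimensional tensor, not a 1x1 matrix).
Why? It follows from the second rule: with no letters after the arrow, every axis is omitted, so every axis gets summed.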
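A small check makes this concrete. The sketch below assumes tiny integer-valued tensors (chosen so the arithmetic is exact, not taken from the original) and compares an empty right side with summing every entry of the full matrix product.

```python
import torch

a = torch.arange(6.0).reshape(2, 3)    # [[0, 1, 2], [3, 4, 5]]
b = torch.arange(12.0).reshape(3, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

total = torch.einsum('ij,jk->', a, b)  # no output letters: i, j and k are all summed
print(total)        # tensor(394.)
print(total.dim())  # 0, i.e. a scalar tensor rather than a 1x1 matrix

# Same number as building the full product and summing every entry
print((a @ b).sum())  # tensor(394.)
```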