by cschmid 1532 days ago
If I understand https://github.com/numpy/numpy/blob/v1.22.0/numpy/core/einsu... and https://github.com/numpy/numpy/blob/v1.22.0/numpy/core/src/m... correctly, using einsum without the optimize flag seems to use a for loop in C to do the multiplication.

The optimizer clearly tries to improve performance, but in many cases it doesn't seem to change anything. Let's simply multiply two batches of 200 matrices, each 200×200:

  import numpy as np
  x, y = np.random.rand(200, 200, 200), np.random.rand(200, 200, 200)
I can do

  %timeit x@y
  40.3 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
or a naive

  %timeit np.einsum('bik,bkj->bij',x,y)
  1.53 s ± 21.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
But even with optimization, I see

  %timeit np.einsum('bik,bkj->bij',x,y, optimize=True)
  1.54 s ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I'm not sure if I'm doing something wrong.
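One way to check what the optimizer actually decides is `np.einsum_path`, which returns the contraction order plus a printable report. A minimal sketch, assuming smaller arrays than above so it runs quickly (the shapes and the seeded generator are my choice, not part of the original timings):

```python
import numpy as np

rng = np.random.default_rng(0)
# Smaller than the 200x200x200 arrays above, purely to keep this quick.
x = rng.random((20, 20, 20))
y = rng.random((20, 20, 20))

# Ask NumPy which contraction order it would pick for this expression.
path, report = np.einsum_path('bik,bkj->bij', x, y, optimize='optimal')
print(path)
print(report)

# Sanity check: the batched matmul and the einsum compute the same thing.
out_matmul = x @ y
out_einsum = np.einsum('bik,bkj->bij', x, y, optimize=True)
assert np.allclose(out_matmul, out_einsum)
```

With only two operands there is just one pairwise contraction to choose from, so the path cannot differ between `optimize=False` and `optimize=True`; whether that single step is handed to BLAS is a separate question.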
2 comments

The looped (batched) matrix multiply that you show is very hard to optimize for in the general case of einsum. The looped GEMM often shows up with permuted indices, such as `kbi,kjb->bij`. In that case, heuristics are needed to decide whether calling GEMM is worth it, given the unaligned memory copies involved.
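To make the permuted case concrete, here is a rough sketch showing that `kbi,kjb->bij` is the same batched matmul once both operands are transposed into `(b, i, k)` and `(b, k, j)` layout. The shapes and variable names are made up for illustration, and this is not a claim about what einsum does internally:

```python
import numpy as np

rng = np.random.default_rng(0)
K, B, I, J = 3, 4, 5, 6      # hypothetical sizes
p = rng.random((K, B, I))    # indexed as k, b, i
q = rng.random((K, J, B))    # indexed as k, j, b

# Reference result straight from einsum.
ref = np.einsum('kbi,kjb->bij', p, q)

# Same contraction as a batched matmul: move the batch axis first and
# put the contracted axis k where matmul expects it.
p_bik = p.transpose(1, 2, 0)   # (k, b, i) -> (b, i, k)
q_bkj = q.transpose(2, 0, 1)   # (k, j, b) -> (b, k, j)
out = p_bik @ q_bkj            # (b, i, j)

assert np.allclose(ref, out)
```

The transposes are views with non-contiguous strides, so a GEMM-based path may first have to copy the data into a contiguous layout. That copy is the cost the heuristics have to weigh.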

`optimize=True` is generally best when there are more than two tensors in the expression.
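As an illustration of the multi-tensor case, here is a small sketch with a three-matrix chain, where the order of the pairwise contractions can matter. The shapes are arbitrary assumptions, picked so that contracting the first two operands first gives a small intermediate:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((8, 64))
b = rng.random((64, 8))
c = rng.random((8, 64))

# With three operands there are several possible contraction orders;
# einsum_path reports the one the optimizer chooses.
path, report = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
print(report)

out_default = np.einsum('ij,jk,kl->il', a, b, c)
out_opt = np.einsum('ij,jk,kl->il', a, b, c, optimize=True)

# Both variants agree with the plain matmul chain.
assert np.allclose(out_default, out_opt)
assert np.allclose(out_opt, a @ b @ c)
```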

I just tried replicating the same experiment using JAX's numpy API, and einsum is still slower, but at least within the same order of magnitude:

  %timeit (x_jax @ y_jax).block_until_ready()
  579 µs ± 4.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax, optimize=True).block_until_ready()
  658 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij',x_jax,y_jax).block_until_ready()
  660 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
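For anyone wanting to reproduce this: the timings above don't show how `x_jax` and `y_jax` were created, so the sketch below assumes a plain `jnp.asarray` conversion of NumPy arrays, with small shapes of my own choosing. JAX defaults to float32 unless 64-bit mode is enabled, so the numbers are not directly comparable to the float64 NumPy runs.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = rng.random((4, 5, 6))
y = rng.random((4, 6, 7))

# Assumption: a straightforward conversion; the thread does not show this step.
x_jax = jnp.asarray(x)
y_jax = jnp.asarray(y)

# block_until_ready() waits for JAX's asynchronous dispatch to finish,
# which is why it appears in the %timeit calls above.
out_matmul = (x_jax @ y_jax).block_until_ready()
out_einsum = jnp.einsum('bik,bkj->bij', x_jax, y_jax).block_until_ready()

# Loose tolerance because the arrays are float32 by default.
assert np.allclose(np.asarray(out_matmul), np.asarray(out_einsum), rtol=1e-4)
```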