by cschmid
1532 days ago
I just tried replicating the same experiment using Jax's numpy API. einsum is still slower than @ (roughly 14% here), but at least it's within the same order of magnitude:

  %timeit (x_jax @ y_jax).block_until_ready()
  579 µs ± 4.54 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij', x_jax, y_jax, optimize=True).block_until_ready()
  658 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

  %timeit jnp.einsum('bik,bkj->bij', x_jax, y_jax).block_until_ready()
  660 µs ± 2.82 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)