Hacker News
by gregjm 93 days ago
> I don't know JAX well enough to explain exactly why it's 3x faster than NumPy on the same matrix multiplications.

JAX is basically a frontend for the XLA compiler, as you note. The secret sauce is two insights: 1) if you have enough control, you can change the layout of tensors and permute the computations on them, so they don’t have to match the layout in the input program and can use a more favorable memory access pattern; 2) most things are memory bound, so XLA creates fusion kernels that combine many computations into one kernel, avoiding round trips to memory between them. I don’t know if the Apple BLAS library has fused kernels for GEMM + some output layer, but XLA is capable of writing GEMM fusions itself and might pick them if they autotune faster for the given input/output shapes.
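To make point 2 concrete, here is a minimal sketch of the kind of "GEMM + output layer" computation I mean. The function name `dense_relu` and the shapes are made up for illustration. Under `jax.jit` the whole function is handed to XLA as one module, so the bias add and ReLU are candidates for fusion; whether they actually get fused with the matmul depends on your backend and XLA version, so treat this as something to inspect rather than a guarantee.

```python
import jax
import jax.numpy as jnp


def dense_relu(x, w, b):
    # A GEMM followed by a bias add and a ReLU. Without jit, each step
    # is dispatched as its own op; with jit, XLA sees all three at once.
    return jnp.maximum(jnp.dot(x, w) + b, 0.0)


# Hand the whole function to XLA as a single module.
fused = jax.jit(dense_relu)

# Illustrative shapes only, not taken from the article.
key = jax.random.PRNGKey(0)
kx, kw, kb = jax.random.split(key, 3)
x = jax.random.normal(kx, (128, 64))
w = jax.random.normal(kw, (64, 32))
b = jax.random.normal(kb, (32,))

eager_out = dense_relu(x, w, b)
jit_out = fused(x, w, b)

# The two paths should agree up to floating-point rounding.
max_abs_diff = float(jnp.max(jnp.abs(eager_out - jit_out)))
print("max abs difference:", max_abs_diff)
```

Timing `fused` against `dense_relu` (with `block_until_ready()` on the result) is the easy way to see whether the compiled version buys you anything on your machine.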

> But I haven't verified that in detail. Might be time to learn.

If you set the environment variable XLA_FLAGS=--xla_dump_to=$DIRECTORY then you’ll find out! If it’s dispatching to BLAS there will be a “custom-call” op in the post-optimization XLA HLO for the module; otherwise there will be a “dot” op. See the docs:

https://openxla.org/xla/hlo_dumps
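If you'd rather check from inside Python, something like the sketch below should work. It sets the dump flag to a temp directory and also pulls the compiled HLO text in memory via `lower(...).compile().as_text()`. Assumptions on my part: the flag has to be set before JAX initializes its backend, the 256x256 shapes are arbitrary, and the plain substring checks are a rough heuristic, not a proper HLO parser. What you see will vary by backend and version.

```python
import os
import tempfile

# Assumption: XLA_FLAGS is read when the backend starts up, so set it
# before importing jax. This overwrites any XLA_FLAGS already set.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = f"--xla_dump_to={dump_dir}"

import jax
import jax.numpy as jnp


@jax.jit
def matmul(a, b):
    return a @ b


# Arbitrary illustrative shapes.
a = jnp.ones((256, 256), dtype=jnp.float32)
b = jnp.ones((256, 256), dtype=jnp.float32)

# Run once so the module is compiled (and dumped, if the flag took effect).
matmul(a, b).block_until_ready()

# Post-optimization HLO for the compiled module, as text.
hlo_text = matmul.lower(a, b).compile().as_text()

# Rough heuristic: look for the op names mentioned above.
uses_custom_call = "custom-call" in hlo_text
has_dot = " dot(" in hlo_text
print("custom-call present:", uses_custom_call)
print("dot present:", has_dot)

# Whatever XLA wrote to the dump directory, for inspection by hand.
print(sorted(os.listdir(dump_dir)))
```

Reading the dumped files by hand is still the more reliable route, since the substring check can't tell you what a custom-call actually targets.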