by gregjm
93 days ago
> I don't know JAX well enough to explain exactly why it's 3x faster than NumPy on the same matrix multiplications.

JAX is basically a frontend for the XLA compiler, as you note. The secret sauce is two insights: 1) if you have enough control, you can change the layout of tensor computations and permute them so that they no longer match the input program but have a more favorable memory access pattern; 2) most things are memory bound, so XLA creates fusion kernels that combine many computations between memory accesses. I don't know if the Apple BLAS library has fused kernels with GEMM + some output layer, but XLA is capable of writing GEMM fusions and might pick them if they autotune faster on the given input/output shapes.

> But I haven't verified that in detail. Might be time to learn.

If you set the environment variable XLA_FLAGS=--xla_dump_to=$DIRECTORY then you'll find out! There will be a "custom-call" op if it's dispatching to BLAS; otherwise there will be a "dot" op in the post-optimization XLA HLO for the module. See the docs: https://openxla.org/xla/hlo_dumps
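
For anyone who wants to try that check without digging through dump files, here is a rough sketch, not a definitive recipe. It assumes the dump flag is spelled `--xla_dump_to` as in the XLA docs, and it reads the post-optimization HLO through JAX's ahead-of-time API (`lower(...).compile().as_text()`). The function names, the shapes, and the `/tmp/xla_dump` directory are made up for illustration, and the string matching is only a heuristic, since a custom-call can target things other than BLAS.

```python
import os
import re

# Assumption: set XLA_FLAGS before jax is imported so the backend sees it.
# "/tmp/xla_dump" is an arbitrary directory picked for this sketch.
os.environ.setdefault("XLA_FLAGS", "--xla_dump_to=/tmp/xla_dump")


def classify_matmul(hlo_text: str) -> str:
    """Heuristically guess how a matmul was lowered, from optimized HLO text."""
    if "custom-call(" in hlo_text:
        # Usually a dispatch to an external library; the target is not checked.
        return "library call"
    if re.search(r"\bdot\(", hlo_text):
        # XLA kept the matmul as its own dot op (possibly inside a fusion).
        return "XLA dot"
    return "unknown"


def optimized_hlo() -> str:
    """Compile a small GEMM + bias + ReLU and return the post-optimization HLO."""
    import jax  # imported lazily so XLA_FLAGS above is already set
    import jax.numpy as jnp

    def dense_relu(x, w, b):  # hypothetical example function
        return jnp.maximum(x @ w + b, 0.0)

    x = jnp.ones((256, 512), dtype=jnp.float32)
    w = jnp.ones((512, 128), dtype=jnp.float32)
    b = jnp.zeros((128,), dtype=jnp.float32)
    return jax.jit(dense_relu).lower(x, w, b).compile().as_text()
```

Calling `print(classify_matmul(optimized_hlo()))` reports which path was taken on your machine; the answer depends on the backend and JAX version, so the full text from `optimized_hlo()` or the files in the dump directory are the thing to actually read.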