JAX has a subsystem called Pallas [1], with a Triton-like programming model, and an example implementation of Flash Attention [2]. It is quite fast. On TPUs, I've heard that the XLA compiler already emits a flash-attention-like computation graph for a regular JAX implementation of attention, so there's no need for a specialized kernel in that case.
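
For anyone unsure what "regular" versus "flash-attention-like" means here, this is a rough sketch in plain JAX. It is not the Pallas kernel and not what XLA actually emits; the function names, the single-head 2-D shapes, and `block_size` are my own assumptions for illustration. The first function materializes the full score matrix; the second walks over K/V in blocks and keeps a running softmax, which is the core idea behind Flash Attention.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    """Regular attention: builds the full (seq_q, seq_k) score matrix."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.T) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v


def blockwise_attention(q, k, v, block_size=4):
    """Same math, but K/V are consumed block by block with an online softmax.

    Hypothetical illustration only: single head, no masking, and seq_k
    must be divisible by block_size.
    """
    seq_q = q.shape[0]
    seq_k, d = k.shape
    assert seq_k % block_size == 0
    scale = d ** -0.5
    k_blocks = k.reshape(-1, block_size, d)
    v_blocks = v.reshape(-1, block_size, v.shape[-1])

    def step(carry, kv_block):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv_block
        s = (q @ k_blk.T) * scale                       # (seq_q, block_size)
        new_max = jnp.maximum(row_max, s.max(axis=-1))  # running row maximum
        correction = jnp.exp(row_max - new_max)         # rescale old partial sums
        p = jnp.exp(s - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((seq_q, v.shape[-1]), dtype=q.dtype),  # unnormalized output
        jnp.full((seq_q,), -jnp.inf, dtype=q.dtype),     # running max per row
        jnp.zeros((seq_q,), dtype=q.dtype),              # running softmax denominator
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]
```

Both functions should agree up to floating-point tolerance on the same inputs; the difference is only that the blockwise version never holds the whole score matrix at once.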