`import jax.numpy as np`, then we also get a JAX implementation after certain modifications, e.g. removing in-place index assignment, replacing unsupported functions, etc.
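
To illustrate the in-place assignment point, here is a minimal sketch (the function names are made up for this example, not taken from any particular codebase). JAX arrays are immutable, so `x[0] = 1.0` raises an error, and the usual replacement is the functional `.at[...].set(...)` update:

```python
import numpy as onp
import jax.numpy as jnp


def set_first_numpy(x):
    # NumPy version: copy, then mutate the copy in place.
    x = x.copy()
    x[0] = 1.0
    return x


def set_first_jax(x):
    # JAX version: arrays are immutable, so `x[0] = 1.0` would raise.
    # `.at[0].set(...)` returns a new array with the element replaced.
    return x.at[0].set(1.0)


a = set_first_numpy(onp.zeros(3))
b = set_first_jax(jnp.zeros(3))
```

Both calls give the same values; only the update syntax differs.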
JAX requires a bit more work to maintain fixed-size buffers as required by XLA, especially for caching and rotary embeddings. But yeah, overall the code can be pretty similar [1].
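
A rough sketch of what a fixed-size buffer looks like for a cache (this is not the code from [1]; `MAX_SEQ_LEN`, `HEAD_DIM` and `update_cache` are hypothetical names for illustration). The cache is allocated once at its maximum size and new entries are written at a dynamic position, so the array shape never changes under `jit`:

```python
import jax
import jax.numpy as jnp

MAX_SEQ_LEN = 8  # hypothetical maximum sequence length
HEAD_DIM = 4     # hypothetical per-head feature size


@jax.jit
def update_cache(cache, new_kv, pos):
    # cache:  (MAX_SEQ_LEN, HEAD_DIM), allocated once and never resized
    # new_kv: (chunk_len, HEAD_DIM), the entries for the new tokens
    # pos:    position along axis 0 where the new entries are written
    return jax.lax.dynamic_update_slice_in_dim(cache, new_kv, pos, axis=0)


cache = jnp.zeros((MAX_SEQ_LEN, HEAD_DIM))
# "Prefill" with two tokens, then append a single token at position 2.
cache = update_cache(cache, jnp.ones((2, HEAD_DIM)), 0)
cache = update_cache(cache, 2 * jnp.ones((1, HEAD_DIM)), 2)
```

In a NumPy or PyTorch version you could just concatenate new keys/values onto a growing array; here the shape stays `(MAX_SEQ_LEN, HEAD_DIM)` and the unused tail has to be masked out in attention.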
[1]: https://github.com/dfdx/fabrique/blob/main/fabrique/llama/mo...