by lnyan 761 days ago
`import jax.numpy as np`, and then we also get a JAX implementation after a few modifications, e.g. removing in-place index assignment, replacing unsupported functions, etc.
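Below is a minimal sketch of the first modification. It uses a causal-mask step as the example, and the function names are hypothetical, not taken from any particular implementation. JAX arrays are immutable, so NumPy-style index assignment has to become `jnp.where(...)` or `x.at[...].set(...)`, both of which return a new array. The sketch uses `jnp` rather than the `np` alias so it can sit next to real NumPy.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def causal_mask_numpy(scores):
    # NumPy style: copy, then mutate via boolean index assignment.
    scores = scores.copy()
    n = scores.shape[-1]
    upper = onp.triu(onp.ones((n, n), dtype=bool), k=1)
    scores[..., upper] = -onp.inf
    return scores


def causal_mask_jax(scores):
    # JAX arrays are immutable: `scores[..., upper] = ...` raises a TypeError.
    # jnp.where builds a new array; the mask shape is static, so this jits fine.
    n = scores.shape[-1]
    upper = jnp.triu(jnp.ones((n, n), dtype=bool), k=1)
    return jnp.where(upper, -jnp.inf, scores)


def set_first_token_jax(tokens, value):
    # Single-element update: `x[0] = v` in NumPy becomes `x.at[0].set(v)` in JAX,
    # which returns an updated copy and leaves `tokens` unchanged.
    return tokens.at[0].set(value)
```

Under `jax.jit`, XLA can often turn these functional updates back into in-place buffer writes, so the rewrite usually costs little at runtime.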

JAX requires a bit more work to maintain the fixed-size buffers that XLA needs, especially in the case of caching and rotary embeddings. But yeah, overall the code can be pretty similar [1].
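Here is a minimal sketch of what fixed-size buffers can look like. It assumes a hypothetical cache layout of `(max_len, n_heads, head_dim)`, and `init_cache`, `update_cache`, `rope_tables`, and `rope_slice` are illustrative names, not the fabrique API. The idea is to preallocate everything to a maximum length, then write or read at a runtime offset with `jax.lax.dynamic_update_slice` / `jax.lax.dynamic_slice_in_dim`. Growing arrays with concatenation would change shapes on every step and trigger recompilation under `jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_cache(max_len, n_heads, head_dim, dtype=jnp.float32):
    # Hypothetical layout (max_len, n_heads, head_dim), allocated once.
    # The shape never changes, so the jitted update compiles only once.
    return jnp.zeros((max_len, n_heads, head_dim), dtype=dtype)


@jax.jit
def update_cache(cache, new_kv, pos):
    # new_kv: (seq_len, n_heads, head_dim); pos: traced int32 scalar.
    # Instead of concatenate (shape grows -> recompile), write the new
    # entries at a runtime offset while the buffer size stays fixed.
    return jax.lax.dynamic_update_slice(cache, new_kv, (pos, 0, 0))


def rope_tables(max_len, head_dim, base=10000.0):
    # Precompute cos/sin for every position up to max_len (standard RoPE frequencies).
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(max_len), inv_freq)  # (max_len, head_dim // 2)
    return jnp.cos(angles), jnp.sin(angles)


@partial(jax.jit, static_argnums=2)
def rope_slice(table, pos, seq_len):
    # seq_len determines the output shape, so it must be static; pos may be traced.
    return jax.lax.dynamic_slice_in_dim(table, pos, seq_len, axis=0)
```

Both `dynamic_*` functions clamp the start index to keep the slice in bounds rather than raising an error, so the caller has to make sure `pos + seq_len` never exceeds `max_len`.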

[1]: https://github.com/dfdx/fabrique/blob/main/fabrique/llama/mo...

...which should also be much faster on CPU, I assume.