Hacker News
by jdeaton 499 days ago
The interesting thing about this comment is that JAX is actually, in general, even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
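A minimal sketch of that point, assuming a recent JAX install: the functions below only state the math of a GPT-2-style feed-forward block (the names `gelu` and `mlp` and the shapes are illustrative, not from the thread). `jax.jit` traces the function and hands the resulting program to XLA, which makes the execution decisions such as op fusion and kernel selection.

```python
import jax
import jax.numpy as jnp

def gelu(x):
    # Tanh approximation of GELU, as used in GPT-2.
    return 0.5 * x * (1 + jnp.tanh(jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)))

def mlp(x, w1, b1, w2, b2):
    # A purely logical description of a feed-forward block: no kernels,
    # no fusion decisions, no device placement.
    return gelu(x @ w1 + b1) @ w2 + b2

# jax.jit traces mlp once per combination of input shapes/dtypes and
# passes the traced program to XLA for compilation.
mlp_jit = jax.jit(mlp)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 8))
w1 = jax.random.normal(k2, (8, 32))
b1 = jnp.zeros(32)
w2 = jax.random.normal(k3, (32, 8))
b2 = jnp.zeros(8)

out = mlp_jit(x, w1, b1, w2, b2)
print(out.shape)  # (4, 8)

# The program XLA receives can be inspected as text.
lowered_text = mlp_jit.lower(x, w1, b1, w2, b2).as_text()
print(lowered_text[:300])
```

The Python code never mentions how the matmuls, the `tanh`, and the bias adds get scheduled; that is exactly the part left to the compiler.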

Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?


I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?

[1] https://github.com/jaymody/picoGPT/blob/main/gpt2.py

yeah it looks exactly like that file but replace "import numpy as np" with "import jax.numpy as np" :)
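To illustrate the reply, here is a hedged sketch of a few picoGPT-style helpers written against `jax.numpy` imported as `np`. The helpers follow the style of the linked gpt2.py but are not a verbatim copy of it; `causal_self_attention` is a hypothetical, projection-free single-head block chosen only to keep the example short. The one-line swap works for pure array code like this. JAX arrays are immutable, so code that assigns into arrays in place would need `x.at[i].set(...)` instead.

```python
import jax
import jax.numpy as np  # the one-line swap described in the comment

def softmax(x):
    # Subtract the max for numerical stability.
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b

def attention(q, k, v, mask):
    # Scaled dot-product attention with an additive mask.
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def causal_self_attention(x, g, b):
    # Hypothetical block: q = k = v = layer_norm(x), no learned projections.
    h = layer_norm(x, g, b)
    n = x.shape[0]
    # Large negative values above the diagonal block attention to future tokens.
    mask = (1 - np.tri(n, dtype=x.dtype)) * -1e10
    return x + attention(h, h, h, mask)

# Same NumPy-style code, now compiled by XLA.
block = jax.jit(causal_self_attention)

x = jax.random.normal(jax.random.PRNGKey(0), (5, 16))
g, b = np.ones(16), np.zeros(16)
y = block(x, g, b)
print(y.shape)  # (5, 16)
```

Because of the causal mask, the first token can only attend to itself, so its output row reduces to `x[0] + layer_norm(x, g, b)[0]`, which makes a convenient sanity check.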