Hacker News
by 6gvONxR4sf7o 1657 days ago
If numpy is a good fit for your problem, JAX is basically a good fit for accelerating it. I think of it as numpy plus composable program transformations: automatic differentiation, JIT compilation (JAX lowers your code to XLA, which compiles it for CPU, GPU, or TPU), parallelization across devices, and so on.

The magic of JAX is that it keeps all that stuff about as simple as writing numpy code.
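A minimal sketch of what that looks like, assuming a recent JAX install (CPU is fine). The `loss` function and the toy data are hypothetical, made up for illustration. `jax.grad`, `jax.jit`, and `jax.vmap` are real JAX transformations. `vmap` is automatic vectorization rather than multi-device parallelization. Its parallel counterpart, `jax.pmap`, is left out because it needs several devices to be interesting.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean squared error of a linear model,
# written exactly as you would with numpy (jnp mirrors the numpy API).
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Transformations wrap the plain function; no rewrite needed.
grad_loss = jax.grad(loss)      # gradient w.r.t. the first argument, w
fast_grad = jax.jit(grad_loss)  # traced on first call, compiled via XLA for the default backend

# Toy data (hypothetical): 3 examples, 2 features.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

print(loss(w, x, y))
print(fast_grad(w, x, y))

# vmap maps grad_loss over the leading axis of x and y (w is shared),
# giving one gradient per example without writing a Python loop.
per_example = jax.vmap(grad_loss, in_axes=(None, 0, 0))(w, x, y)
print(per_example.shape)
```

The function body never mentions gradients, compilation, or batching. Each of those is layered on afterwards, and the layers compose (here, `vmap` of `grad`).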