Hacker News
by tasubotadas 987 days ago
Is there any benefit to using it instead of PyTorch?
3 comments

JAX handles higher-order differentiation much more nicely. PyTorch has functions to compute Hessians, and there are libraries that keep optimizers differentiable, but going outside their standard use cases gets tricky very fast. In contrast, JAX can compute nth derivatives very easily.
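To make the nth-derivative point concrete, here is a minimal sketch. The helper `nth_derivative` and the toy function `g` are names made up for illustration; `jax.grad` and `jax.hessian` are the actual JAX transformations.

```python
import jax
import jax.numpy as jnp

def nth_derivative(f, n):
    """Hypothetical helper: apply jax.grad n times to a scalar-to-scalar function."""
    for _ in range(n):
        f = jax.grad(f)
    return f

# Third derivative of sin(x) is -cos(x)
d3_sin = nth_derivative(jnp.sin, 3)
third = d3_sin(0.5)

# Hessian of a scalar-valued function of a vector (toy example)
def g(v):
    return jnp.sum(v ** 3)

H = jax.hessian(g)(jnp.array([1.0, 2.0]))  # diagonal matrix with entries 6 * v
```

Because `jax.grad` returns an ordinary function, it can be applied to its own output as many times as needed.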
In my experience, the main benefit is that distributed computation is much easier in JAX; the API is much nicer. For single-device computing there’s no advantage either way.
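As a rough illustration of that API, here is a small sketch using `jax.pmap` with a collective sum. The function `normalize` and the axis name `"devices"` are made up for this example, and it assumes whatever devices are locally visible (it also runs with a single CPU device). Newer JAX code often expresses the same thing with `jit` plus shardings instead.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 1 on a typical CPU-only machine

def normalize(x):
    # psum adds the per-device partial sums across the mapped axis "devices"
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total

# One row of data per device: shape (n_dev, 4)
data = jnp.arange(1.0, n_dev * 4 + 1.0).reshape(n_dev, 4)

# Each device runs normalize on its own row; psum handles the communication
out = jax.pmap(normalize, axis_name="devices")(data)
```

The per-device function looks like ordinary single-device code; the only distributed part is the named axis passed to `psum`.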
If you like functional languages, JAX will be a better fit for you. It provides a set of composable function transformations, e.g. grad, jit, etc.
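A short sketch of what composing these transformations looks like. The `loss` function and the data are invented for illustration, and `vmap` is an additional transformation not named in the comment above.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on a single example
    return (jnp.dot(w, x) - y) ** 2

grad_loss = jax.grad(loss)                               # gradient w.r.t. w
per_example = jax.vmap(grad_loss, in_axes=(None, 0, 0))  # batch over x and y
fast = jax.jit(per_example)                              # compile the result

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ys = jnp.array([0.0, 1.0])

grads = fast(w, xs, ys)  # one gradient row per example, shape (2, 2)
```

Each transformation takes a pure function and returns a new one, which is why they stack in any order.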