by SleekEagle 1652 days ago
JAX is really exciting! JAX is mentioned in the Research subsection of the "Which should I pick" section. Do you think that the fundamental under-the-hood differences of JAX compared to TensorFlow and PyTorch will affect its adoption?

Haiku is really cool - I haven't used Flax. It'll be really interesting to see how JAX develops over time. I also saw some benchmarks showing it neck-and-neck with PyTorch as the fastest of the three, but I think that with more optimization its ceiling is higher than PyTorch's.


> Do you think that the fundamental under-the-hood differences of JAX compared to TensorFlow and PyTorch will affect its adoption?

Of course. It's the only library that can be explained from first principles: https://jax.readthedocs.io/en/latest/autodidax.html

Wow that's a really cool resource. Thanks for linking!

Even still, do you think researchers will want to take the time to learn all of that when PyTorch gives them no real reason to switch? Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.

It depends on what you want to do obviously.

PyTorch historically hasn't focused much on forward-mode automatic differentiation: https://github.com/pytorch/pytorch/issues/10223

This definitely limits its generality relative to JAX, which makes it less than ideal for anything other than 'typical' deep neural networks.
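To make the forward/reverse distinction concrete, here is a minimal JAX sketch; the function `f` and the input values are made up for illustration. `jax.jvp` computes a Jacobian-vector product in a single forward pass, while `jax.jacfwd` and `jax.jacrev` build the full Jacobian in forward and reverse mode respectively.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical vector-valued function R^3 -> R^2
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, 0.0])  # tangent direction

# Forward mode: f(x) and J(x) @ v, computed together in one pass
y, jvp_out = jax.jvp(f, (x,), (v,))

# Full Jacobian via forward mode (jacfwd) and via reverse mode (jacrev)
J_fwd = jax.jacfwd(f)(x)
J_rev = jax.jacrev(f)(x)
```

Forward mode costs roughly one pass per input dimension and reverse mode one pass per output dimension, so forward mode tends to win when a function has few inputs and many outputs.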

This is especially true when the research in question involves physics, or combining physical models with machine learning, which imho is very interesting. Those are use cases that PyTorch just isn't good at.
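As a rough illustration of the physics use case, the sketch below differentiates straight through a toy simulator. It assumes a unit mass on a spring, integrated with semi-implicit Euler inside `jax.lax.fori_loop`; `simulate`, `loss`, the spring constant `k`, and the `observed` value are all hypothetical. Because `k` is a single scalar, forward mode (`jax.jacfwd`) is a natural fit, and it agrees with reverse mode (`jax.grad`).

```python
import jax
import jax.numpy as jnp

def simulate(k, steps=500, dt=0.01):
    """Final position of a unit mass on a spring (x'' = -k x), starting at x=1, v=0.

    Semi-implicit Euler steps run inside lax.fori_loop, so the whole
    simulation is a differentiable JAX function of k.
    """
    def step(_, state):
        x, v = state
        v = v - dt * k * x  # update velocity from the spring force
        x = x + dt * v      # then update position from the new velocity
        return (x, v)

    init = (jnp.float32(1.0), jnp.float32(0.0))
    x_final, _ = jax.lax.fori_loop(0, steps, step, init)
    return x_final

observed = 0.3  # hypothetical measured final position

def loss(k):
    # Squared mismatch between simulated and observed final position
    return (simulate(k) - observed) ** 2

# Reverse mode and forward mode give the same scalar derivative d(loss)/dk
dloss_dk_rev = jax.grad(loss)(2.0)
dloss_dk_fwd = jax.jacfwd(loss)(2.0)
```

With this derivative in hand, `k` could be fit to data by any gradient-based optimizer, which is the kind of workflow that combines physical models with learning.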

Interesting - I didn't realize that it was that important for computational physics. Very cool, I'll have to read up!

> Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.

Every day spent learning JAX is also a day not spent trying to fit a round peg into the square hole of some other library. I made the leap when I was doing things that were painful in PyTorch. In terms of time, I think I came out ahead.

Not everything is a nail: PyTorch is better for some things, and JAX is better for others. "Every day spent learning the screwdriver is a day spent not using your hammer."

Totally agree! Always a cost/benefit analysis to consider, so it's nice to hear that it was worth it for someone who made the jump.

> Every day spent learning JAX

Getting started with JAX is just a matter of knowing Python and adding `grad`, `jit` and `vmap` to the mix; it takes about 5 minutes to get going.

To me this is the real power of JAX: it can be viewed as a few functions that make it easy to take numerical Python code you've written (using `jax.numpy` in place of NumPy) and work with its derivatives. This gives it tremendous flexibility in helping you solve problems.
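A minimal sketch of the three transformations working together, assuming a hypothetical one-dimensional linear model and made-up data: `jax.grad` differentiates an ordinary Python function, `jax.vmap` maps that gradient over a batch, and `jax.jit` compiles the result.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a hypothetical 1-D linear model: y ≈ w[0] * x + w[1]
    pred = w[0] * x + w[1]
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument (w)
grad_loss = jax.grad(loss)

# vmap: apply grad_loss to each (x, y) pair while sharing the same w
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the batched gradient function with XLA
fast_grads = jax.jit(per_example_grads)

w = jnp.array([2.0, 0.5])
xs = jnp.array([0.0, 1.0, 2.0])
ys = jnp.array([1.0, 2.0, 3.0])
g = fast_grads(w, xs, ys)  # shape (3, 2): one gradient per example
```

`in_axes=(None, 0, 0)` tells `vmap` to share `w` across the batch while mapping over the leading axis of `xs` and `ys`, which yields per-example gradients without writing a loop.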

As an example, I mostly use it for statistical work rather than neural networks. It took me maybe a few minutes to implement a GLM with custom priors over all the parameters, and then use the Hessian for a Laplace approximation of the parameter uncertainty. The proper way to do this would have been with PyMC, but this worked well enough for me, and building the model from scratch in JAX took less time than refreshing my memory of the PyMC API.
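Here is a minimal sketch of that kind of workflow. The comment doesn't say which GLM or priors were used, so this assumes a logistic-regression GLM with independent Normal priors; `neg_log_posterior`, `prior_scale`, `fit_map`, and the synthetic data are illustrative. The MAP estimate comes from Newton steps built with `jax.grad` and `jax.hessian`, and the inverse Hessian at the mode serves as the Laplace-approximation covariance.

```python
import jax
import jax.numpy as jnp

def neg_log_posterior(beta, X, y, prior_scale=2.0):
    """Logistic-regression GLM with independent Normal(0, prior_scale^2) priors."""
    logits = X @ beta
    # Bernoulli log-likelihood, written with log_sigmoid for numerical stability
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits)
                      + (1.0 - y) * jax.nn.log_sigmoid(-logits))
    log_prior = -0.5 * jnp.sum((beta / prior_scale) ** 2)
    return -(log_lik + log_prior)

grad_fn = jax.jit(jax.grad(neg_log_posterior))
hess_fn = jax.jit(jax.hessian(neg_log_posterior))

def fit_map(X, y, iters=25):
    # Newton's method: beta <- beta - H^{-1} g (the objective is convex here)
    beta = jnp.zeros(X.shape[1])
    for _ in range(iters):
        beta = beta - jnp.linalg.solve(hess_fn(beta, X, y), grad_fn(beta, X, y))
    return beta

# Synthetic data, made up for illustration: intercept plus one covariate
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n = 200
X = jnp.column_stack([jnp.ones(n), jax.random.normal(key_x, (n,))])
true_beta = jnp.array([-0.5, 1.5])
y = jax.random.bernoulli(key_y, jax.nn.sigmoid(X @ true_beta)).astype(jnp.float32)

beta_map = fit_map(X, y)
# Laplace approximation: posterior ≈ Normal(beta_map, H^{-1}) at the mode
cov = jnp.linalg.inv(hess_fn(beta_map, X, y))
std_err = jnp.sqrt(jnp.diag(cov))
```

Swapping in different priors only means editing `log_prior`; the gradient and Hessian follow automatically.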

The autodidax section of the JAX docs is such a wonderful thing. I wish every library had that.

So cool! I'm a bit surprised they took the time to put it together, but I'm definitely not complaining! LOL