Hacker News
by tehf0x 1652 days ago
As a researcher in RL & ML in a big industry lab, I would say most of my colleagues are moving to JAX [https://github.com/google/jax], which this article kind of ignores. JAX is XLA-accelerated NumPy; it's cool beyond just machine learning, but only provides low-level linear algebra abstractions. However, you can put something like Haiku [https://github.com/deepmind/dm-haiku] or Flax [https://github.com/google/flax] on top of it and get what the cool kids are using :)

> but only provides low-level linear algebra abstractions.

Just to make sure people aren't scared off by this: jax provides a lot more than just low level linear algebra. It has some fundamental NN functions in its lax submodule, and the numpy API itself goes way way past linear algebra. Numpy plus autodiff, plus automatic vectorization, plus automatic parallelization, plus some core NN functions, plus a bunch more stuff.
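Here is a minimal sketch of a few of the pieces beyond linear algebra mentioned above. It uses `jax.nn`, `jax.scipy.special`, and `jax.numpy.fft`, and the input values are arbitrary:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0, -1.0])

# Core NN functions live in jax.nn
relu_x = jax.nn.relu(x)
log_probs = jax.nn.log_softmax(x)

# ...and they agree with the "by hand" version built from jax.scipy
manual_log_probs = x - logsumexp(x)

# The NumPy API goes well past linear algebra, e.g. FFTs
roundtrip = jnp.fft.ifft(jnp.fft.fft(x)).real

# Autodiff works on all of it: the gradient of logsumexp is softmax
grad_lse = jax.grad(logsumexp)(x)
```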

Jax plus optax (for common optimizers and easily making new optimizers) is plenty sufficient for a lot of NN needs. After that, the other libraries are really just useful for initialization and state management (which is still very useful; I use haiku myself).
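To illustrate the "plain JAX is often enough" point, here is a minimal training loop that uses no NN library at all. The parameters are an ordinary dict (a pytree), and the update rule is hand-written SGD via `jax.tree_util.tree_map`. optax would normally supply the optimizer, but it is left out so the sketch depends only on `jax`. The data, learning rate, and step count are made up.

```python
import jax
import jax.numpy as jnp

# Made-up, noise-free regression data: y = 3x - 1
xs = jnp.linspace(-1.0, 1.0, 32)
ys = 3.0 * xs - 1.0

params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}

def loss_fn(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)  # same dict structure as params
    # Apply the same update rule to every leaf of the parameter pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

for _ in range(500):
    params = sgd_step(params, xs, ys)
```

Swapping in an optax optimizer would replace only the `tree_map` line. The rest of the loop stays the same.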

Would you mind commenting on Haiku vs Flax? I'm partial to Haiku because I'm a fan of Sonnet/DeepMind, but I've not looked into Flax much!
I haven't used flax, but it seems more like pytorch. I like haiku because it's relatively minimal: the simplest transform (`hk.transform`) just gives you init and apply functions, and that's all. I like that.
Got it. Awesome, thanks for the info! I'll check them both out this week.
I have a preference for Flax: you basically get PyTorch, but streamlined thanks to the JAX foundation.
Awesome, thanks for the suggestion. I'll check out both for sure!
Indeed! Sorry, I was thinking more about the layers, but of course JAX is way more than NumPy on steroids (although it is also that: https://dionhaefner.github.io/2021/12/supercharged-high-reso...). JAX has a very nice vmap for easy vectorization on SIMD accelerators, and pmap even allows parallelizing across devices with a single line, which is just beautiful!
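Here is a small sketch of both transforms. `vmap` adds a batch dimension to a function written for one example, and `pmap` maps over a leading axis whose size must equal the number of devices. `predict` and the data are hypothetical. On a single-CPU machine `jax.local_device_count()` is 1, so the `pmap` call runs on one device, but the code does not change.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for a single example: x has shape (features,)
    return jnp.tanh(jnp.dot(w, x))

w = jnp.array([0.5, -0.25, 1.0])
n_dev = jax.local_device_count()

# Two examples per device, three features each
batch = jnp.arange(2 * n_dev * 3, dtype=jnp.float32).reshape(2 * n_dev, 3)

# vmap: batch the single-example function, broadcasting w (in_axes=None)
batched = jax.vmap(predict, in_axes=(None, 0))(w, batch)

# pmap: split the leading axis across devices, vmap within each device
sharded = batch.reshape(n_dev, 2, 3)
per_device = jax.pmap(
    jax.vmap(predict, in_axes=(None, 0)), in_axes=(None, 0)
)(w, sharded)
```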
What I love about JAX is that it essentially just makes Python into a performant, differentiable programming language.

I'm a pretty big fan of moving away from thinking about ML/stats/etc. specifically; people should more generally embrace the idea of differentiable programming as just a way to program and solve a range of problems.

JAX means that the average python programmer just needs to understand the basics of derivatives and their use (not how to compute them, just what they are and why they're useful) and suddenly has an amazing amount of power they can add to normal code.

The real power of JAX, for me at least, is that you can write the solution to your problem, whatever that problem may be, and use derivatives and gradient descent to find an answer. Sometimes this solution might be essentially a neural network, other times a generalized linear model, but sometimes it might not fit obviously into either of these paradigms.
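Here is a toy version of "write the problem, then differentiate it": find the launch speed that makes a projectile land 100 m away, using gradient descent on a squared error. The physics (flat ground, no drag, fixed 45 degree angle), the step size, and the iteration count are all assumptions chosen for the sketch. A closed-form answer exists here (v = sqrt(g * d)), which makes the result easy to check.

```python
import jax
import jax.numpy as jnp

G = 9.81            # gravitational acceleration, m/s^2
ANGLE = jnp.pi / 4  # assumed fixed 45 degree launch angle

def landing_distance(v):
    # Range of a projectile on flat ground with no air resistance
    return v ** 2 * jnp.sin(2 * ANGLE) / G

def loss(v, target):
    return (landing_distance(v) - target) ** 2

# Derivative of the loss with respect to the launch speed v
grad_loss = jax.jit(jax.grad(loss))

v = 10.0        # initial guess, m/s
target = 100.0  # desired landing distance, m
for _ in range(500):
    v = v - 0.005 * grad_loss(v, target)
```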

This isn't quite true. Jax works well for "quasistatic" code, but can't handle more dynamic types of problems (see https://www.stochasticlifestyle.com/useful-algorithms-that-a... for a more detailed explanation).
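Here is a minimal sketch of what "quasistatic" means in practice under `jax.jit`. A Python `if` on a traced value fails, so data-dependent branches have to go through `lax.cond`, and loops whose length depends on the data have to go through `lax.while_loop`. Reverse-mode `jax.grad` through `lax.while_loop` is also not supported. The Collatz counter is just an arbitrary data-dependent loop chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def branchy(x):
    # A Python `if` on a traced value can't be resolved at trace time
    if x > 0:
        return x
    return -x

try:
    branchy(1.0)
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    python_if_ok = False

@jax.jit
def structured(x):
    # Data-dependent branching expressed with lax.cond instead
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

@jax.jit
def collatz_steps(n):
    # Data-dependent trip count expressed with lax.while_loop
    def cond(state):
        n, _ = state
        return n != 1

    def body(state):
        n, steps = state
        n = jnp.where(n % 2 == 0, n // 2, 3 * n + 1)
        return n, steps + 1

    init = (jnp.asarray(n, jnp.int32), jnp.int32(0))
    _, steps = lax.while_loop(cond, body, init)
    return steps
```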

Jax is definitely the right direction for the python ecosystem, but it can't solve all your problems. At some point you still need a fast language.

Do any JAX experts know if there is an equivalent to https://captum.ai/ - a model interpretability library for pytorch?

In particular, I want to be able to measure feature importance on both inputs and internal layers on a sample-by-sample basis. This is the only thing currently holding me back from using JAX right now.

Alternatively, a simple-to-read/understand/port implementation of DeepLIFT would work too.

Thanks!
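This is not a Captum equivalent, but gradient-based attributions are short to write by hand in JAX. Below is a minimal sketch of integrated gradients. Note that this is a different and simpler technique than DeepLIFT, which needs its own propagation rules. The toy model, the zero baseline, and the `steps` value are hypothetical. For an internal layer, the same function could be applied to the sub-network that maps that layer's activation to the output. For per-sample attributions over a batch, it could be wrapped in `jax.vmap`.

```python
import jax
import jax.numpy as jnp

def integrated_gradients(f, x, baseline, steps=64):
    """Integrated-gradients attribution for a scalar-valued f.

    Approximates the path integral of df/dx along the straight line from
    `baseline` to `x` with a midpoint Riemann sum over `steps` points.
    """
    alphas = (jnp.arange(steps) + 0.5) / steps
    path = baseline + alphas[:, None] * (x - baseline)  # (steps, features)
    grads = jax.vmap(jax.grad(f))(path)                  # (steps, features)
    return (x - baseline) * grads.mean(axis=0)

# Hypothetical toy model standing in for a real network
w = jnp.array([1.0, -2.0, 0.5])

def model(x):
    return jnp.tanh(jnp.dot(w, x))

x = jnp.array([0.3, 0.1, -0.4])
baseline = jnp.zeros_like(x)
attributions = integrated_gradients(model, x, baseline)
```

As a sanity check, the attributions should approximately sum to `model(x) - model(baseline)` (the "completeness" property of integrated gradients).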

Most? Last I tried JAX, it had no real documentation to speak of, and all the tutorials you could find on the net were woefully out of date. Even simple toy examples broke with weird error messages. Maybe the situation is better now. I'd rather wait for JAX 2.0 though. :)
Give it another try. I found the docs pretty good: you need to get your head around XLA tracing and read "the sharp bits" section, and then you should be pretty set!
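Here are two of the tracing-related sharp bits in a minimal sketch. First, Python side effects inside a `jax.jit` function run only when JAX traces it, which happens once per new input shape/dtype, not on every call. Second, JAX arrays are immutable, so updates go through `.at[...]`. The function and variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def scaled_sum(x):
    # This Python side effect runs during tracing, not on every call
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(3))
scaled_sum(jnp.ones(3))  # same shape/dtype: cached, no retrace
scaled_sum(jnp.ones(5))  # new shape: traced again

# Arrays are immutable; an "in-place" update returns a new array
a = jnp.zeros(4)
b = a.at[1].set(7.0)
```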
Yep, and this approach also allows languages like Julia and Elixir to compile their expressions into valid compute graphs that target JAX/XLA. That polyglot capability opens cutting-edge machine learning up to quite a few more ecosystems, some of which bring another level of capability in distribution and fault tolerance, as is the case with Elixir + Nx.
Could you or someone elaborate more on how other languages can hook into JAX/XLA?
Julia has XLA.jl [0], which interoperates with their deep-learning stack, and Elixir has Nx [1], which is higher level (basically JAX but in Elixir). I would love to see someone do something like that in Rust...

[0]: https://github.com/FluxML/XLA.jl

[1]: https://github.com/elixir-nx/nx/tree/main/nx
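The shared piece is XLA. JAX traces Python into its own IR (a jaxpr) and then lowers that to a module for the XLA compiler. Front ends such as XLA.jl, and Nx through its EXLA backend, target XLA without going through Python. Here is a small sketch of inspecting both stages on the JAX side. The function is arbitrary, and the exact text format of the lowered module depends on the JAX version.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.arange(3.0)

# Stage 1: JAX's own intermediate representation of the traced program
jaxpr = jax.make_jaxpr(f)(x)

# Stage 2: the lowered module that JAX hands to the XLA compiler, as text
hlo_text = jax.jit(f).lower(x).as_text()

print(jaxpr)
print(hlo_text)
```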

Is there a straightforward way to move models/pipelines created in JAX to EdgeTPU (TFLite)?
How does it compare with Julia + Flux.jl [1]?

[1] https://fluxml.ai/

The article mentions JAX many times and literally has a flow chart for RL and when to use JAX.
> As a researcher in RL & ML in a big industry lab

Is that big industry lab Google or DeepMind? haha

The article highlights JAX as a framework to watch in several places.
JAX is really exciting! JAX is mentioned in the Research subsection of the "Which should I pick" section. Do you think that the fundamental under-the-hood differences of JAX compared to TensorFlow and PyTorch will affect its adoption?

Haiku is really cool - I haven't used Flax. It'll be really interesting to see the development of JAX as time goes on. I also saw some benchmarks that show it's neck-and-neck with PyTorch as the fastest of the three, but I think with more optimization its ceiling is higher than PyTorch's.

> Do you think that the fundamental under-the-hood differences of JAX compared to TensorFlow and PyTorch will affect its adoption?

Of course. It's the only library that can be explained from first principles: https://jax.readthedocs.io/en/latest/autodidax.html
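This sketch is in the spirit of autodidax, but much smaller and not its actual code: a toy forward-mode autodiff built from dual numbers. Each value carries a primal and a tangent, and each operation propagates both via the product and chain rules. `jax.jvp` computes the same quantity for real JAX programs. Only `+`, `*`, and a hypothetical `sin` helper are supported here.

```python
import math

class Dual:
    """A value paired with its tangent (derivative along one direction)."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule: tangents add
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__

def sin(x):
    # Chain rule for sin
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def jvp(f, primal, tangent):
    # Push a (primal, tangent) pair through f in a single forward pass
    out = f(Dual(primal, tangent))
    return out.primal, out.tangent

def g(x):
    return 3.0 * x * x + sin(x)

value, deriv = jvp(g, 2.0, 1.0)  # g(2) and g'(2) = 6*2 + cos(2)
```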

Wow that's a really cool resource. Thanks for linking!

Even so, do you think researchers will want to take the time to learn all of that when PyTorch gives them no real reason to switch? Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.

It depends on what you want to do obviously.

pytorch historically hasn't really focused on forward-mode automatic differentiation: https://github.com/pytorch/pytorch/issues/10223

this definitely limits its generality relative to jax, which makes it less than ideal for anything other than 'typical' deep neural networks

this is especially true when the research in question is related to things like physics or combining physical models and machine learning, which imho is very interesting. those are use cases that pytorch just isn't good at.
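Here is what forward mode looks like in JAX. `jax.jvp` pushes a tangent through a function in one forward pass, and `jax.jacfwd` builds a full Jacobian that way. That approach tends to suit "tall" Jacobians (few inputs, many outputs), while `jax.jacrev` suits "wide" ones. The damped-oscillator right-hand side below is a made-up stand-in for a physical model.

```python
import jax
import jax.numpy as jnp

def oscillator_rhs(state, k=4.0, c=0.5):
    # Hypothetical damped spring; state = [position, velocity]
    pos, vel = state[0], state[1]
    return jnp.array([vel, -k * pos - c * vel])

state = jnp.array([1.0, 0.0])
direction = jnp.array([1.0, 0.0])

# One forward-mode pass: the value and the directional derivative
value, tangent = jax.jvp(oscillator_rhs, (state,), (direction,))

# The full Jacobian via forward mode and via reverse mode
J_fwd = jax.jacfwd(oscillator_rhs)(state)
J_rev = jax.jacrev(oscillator_rhs)(state)
```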

Interesting - I didn't realize that it was that important for computational physics. Very cool, I'll have to read up!
> Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.

Every day spent learning JAX is also another day spent not trying to fit a round peg into a square hole of other libraries. I made the leap when I was doing things that were painful in pytorch. In terms of time, I think I came out ahead.

Not everything is a nail: pytorch is better for some things, and jax is better for others. "Every day spent learning the screwdriver is a day spent not using your hammer."

Totally agree! Always a cost/benefit analysis to consider, so it's nice to hear that it was worth it for someone who made the jump.
> Every day spent learning JAX

Getting started with JAX is just knowing Python and adding `grad`, `jit` and `vmap` to the mix; it takes about 5 minutes to get going.

To me this is the real power of JAX: it can be viewed as a few functions that make it easy to take any Python code you've written and work with its derivatives. This gives it tremendous flexibility in helping you solve problems.
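The three transformations compose, which is most of the trick. A common example is per-example gradients: `grad` differentiates a single-example loss, `vmap` maps it over a batch, and `jit` compiles the result. The linear model and data are placeholders.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example
    return (jnp.dot(w, x) - y) ** 2

# d(loss)/dw for one example, mapped over the batch (w shared), then compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([0.5, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])
ys = jnp.array([1.0, 0.0, -1.0])

g = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient per example
```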

As an example, I mostly do statistical work with it, rather than NN-focused work. It took probably a few minutes to implement a GLM with custom priors over all the parameters, and then use the Hessian for the Laplace approximation of parameter uncertainty. The proper way to solve this would have been using PyMC, but this worked well enough for me, and building the model from scratch in JAX took less time than refreshing my memory of the PyMC API.
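Here is a minimal sketch of that kind of workflow, with the assumptions stated up front. It uses a logistic-regression GLM with independent Gaussian priors on the coefficients (standing in for whatever custom priors you would actually use), made-up data, and plain Newton iterations for the MAP estimate. `jax.hessian` supplies both the Newton step and the Laplace covariance.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: intercept column plus one feature
X = jnp.array([[1.0, -2.0], [1.0, -1.0], [1.0, 0.0],
               [1.0, 0.5], [1.0, 1.0], [1.0, 2.0]])
y = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0])
PRIOR_SD = 2.0  # assumed Normal(0, PRIOR_SD^2) prior on every coefficient

def neg_log_posterior(beta):
    eta = X @ beta
    # Bernoulli-logit log-likelihood: y * eta - log(1 + exp(eta))
    log_lik = jnp.sum(y * eta - jax.nn.softplus(eta))
    log_prior = -0.5 * jnp.sum((beta / PRIOR_SD) ** 2)
    return -(log_lik + log_prior)

grad_fn = jax.jit(jax.grad(neg_log_posterior))
hess_fn = jax.jit(jax.hessian(neg_log_posterior))

# Newton's method for the MAP estimate
beta = jnp.zeros(2)
for _ in range(25):
    beta = beta - jnp.linalg.solve(hess_fn(beta), grad_fn(beta))

# Laplace approximation: posterior ~ Normal(beta_map, inverse Hessian)
cov = jnp.linalg.inv(hess_fn(beta))
posterior_sd = jnp.sqrt(jnp.diag(cov))
```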

The autodidax section of the jax docs is such a wonderful thing. I wish every library had that.
So cool! I'm a bit surprised they took the time to put it together, but I'm definitely not complaining! LOL
Having read some Jax high-performance code, I do like Jax, but it does feel a bit too abstract and low level sometimes. Maybe there aren’t good coding conventions or performance trumped them? Definitely needs improvement on error messages, as well.

For example, when a long chain of pmaps, each with some sort of device-partitioning logic, fails to JIT compile, it's extremely hard to understand why. I basically had to binary search the code until the compile errors disappeared.

How performant is JAX vs Numba in terms of non-ML applications?
I'm also going from TF2 to JAX. It's like TF3, and I hope that Google will just put the Keras team to work on a JAX version.