As a researcher in RL & ML in a big industry lab, I would say most of my colleagues are moving to JAX [https://github.com/google/jax], which this article kind of ignores. JAX is XLA-accelerated NumPy; it's cool beyond just machine learning, but only provides low-level linear algebra abstractions. However, you can put something like Haiku [https://github.com/deepmind/dm-haiku] or Flax [https://github.com/google/flax] on top of it and get what the cool kids are using :)
> but only provides low-level linear algebra abstractions.
Just to make sure people aren't scared off by this: jax provides a lot more than just low level linear algebra. It has some fundamental NN functions in its lax submodule, and the numpy API itself goes way way past linear algebra. Numpy plus autodiff, plus automatic vectorization, plus automatic parallelization, plus some core NN functions, plus a bunch more stuff.
Jax plus optax (for common optimizers and easily making new optimizers) is plenty sufficient for a lot of NN needs. After that, the other libraries are really just useful for initialization and state management (which is still very useful; I use haiku myself).
I haven't used flax, but it seems more like pytorch. I like haiku because it's relatively minimal. The simplest transform does init and that's all. I like that.
Indeed! Sorry, I was thinking more about the layers, but of course JAX is way more than numpy on steroids. (Although it is also that: https://dionhaefner.github.io/2021/12/supercharged-high-reso...). JAX has a very nice vmap for easy parallelization on SIMD accelerators, and pmap even allows cross-device vectorization with a single line, which is just beautiful!
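For anyone who hasn't seen vmap, here is a minimal sketch of the idea. The function `predict` and the toy arrays are made up for illustration; only `jax.vmap` and its `in_axes` argument come from JAX itself.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Written for ONE example: w and x both have shape (3,).
    return jnp.tanh(jnp.dot(w, x))

# in_axes=(None, 0): share w across the batch, map over the leading axis of xs.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 examples

ys = batched_predict(w, xs)  # one output per example, shape (4,)
```

`jax.pmap` is called in much the same way, except that the mapped axis is spread across devices, so its size must not exceed the number of available devices.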
What I love about JAX is that it essentially just makes Python into a performant, differentiable programming language.
I'm a pretty big fan of moving away from thinking about ML/stats/etc. specifically; people should more generally embrace the idea of differentiable programming as just a way to program and solve a range of problems.
JAX means that the average python programmer just needs to understand the basics of derivatives and their use (not how to compute them, just what they are and why they're useful) and suddenly has an amazing amount of power they can add to normal code.
The real power of JAX, for me at least, is that you can write the solution to your problem, whatever that problem may be, and use derivatives and gradient descent to find an answer. Sometimes this solution might be essentially a neural network, other times a generalized linear model, but sometimes it might not fit obviously into either of these paradigms.
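A toy sketch of that workflow, on a problem that is neither a neural network nor a GLM: finding the x where cos(x) = x by writing the mismatch as a loss and descending its gradient. The loss, step size, and iteration count are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Squared mismatch between the two sides of cos(x) = x.
    return (jnp.cos(x) - x) ** 2

grad_loss = jax.jit(jax.grad(loss))  # derivative of the loss, compiled

x = jnp.array(0.0)  # arbitrary starting guess
for _ in range(200):
    x = x - 0.1 * grad_loss(x)  # plain gradient descent step

residual = jnp.cos(x) - x  # should be close to zero after the loop
```

Nothing here is specific to this equation; any loss you can write with `jax.numpy` operations can be differentiated the same way.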
Do any JAX experts know if there is an equivalent to https://captum.ai/ - a model interpretability library for pytorch?
In particular, I want to be able to measure feature importance on both inputs and internal layers on a sample-by-sample basis. This is the only thing currently holding me back from using JAX right now.
Alternatively, a simple to read/understand/port implementation of DeepLIFT would work too.
Most? Last I tried JAX it had no real documentation to speak of and all the tutorials you could find on the net were woefully out of date. Even simple toy examples broke with weird error messages. Maybe the situation is better now. I'd rather wait for JAX 2.0 though. :)
Give it another try; I found the docs pretty good. You need to get your head around XLA tracing and read "the sharp bits" section, and you should be pretty set!
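A small sketch of what tracing means in practice, assuming default `jax.jit` caching behaviour. The names `trace_log` and `scaled_sum` are invented for the example. The Python body of a jitted function runs only when JAX traces it for a new input shape or dtype, so ordinary Python side effects do not happen on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time Python actually executes the function body

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)  # Python side effect: runs at trace time only
    return 2.0 * jnp.sum(x)

scaled_sum(jnp.ones(3))   # first call: traced and compiled for shape (3,)
scaled_sum(jnp.zeros(3))  # same shape and dtype: cached version, body not re-run
scaled_sum(jnp.ones(5))   # new shape: traced again

# trace_log now holds two entries, one per distinct input shape.
```

Most of the surprising error messages people hit come from writing code that assumes the body runs with concrete values on every call.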
Yep, and this approach also allows languages like Julia and Elixir to compile their expressions into valid compute graphs that target JAX/XLA. That polyglot capability opens up cutting-edge machine learning to quite a few more ecosystems, with another level of capabilities in distribution and fault tolerance, as is the case with Elixir + Nx.
Julia has XLA.jl [0], which interoperates with their deep-learning stack, and Elixir has Nx [1], which is higher level (basically JAX but in Elixir). I would love to see someone do something like that in Rust...
JAX is really exciting! JAX is mentioned in the Research subsection of the "Which should I pick" section. Do you think that the fundamental under-the-hood differences of JAX compared to TensorFlow and PyTorch will affect its adoption?
Haiku is really cool - I haven't used Flax. It'll be really interesting to see the development of JAX as time goes on. I also saw some benchmarks that show it's neck and neck with PyTorch as the fastest of the three, but I think with more optimization its ceiling is higher than PyTorch's.
Wow that's a really cool resource. Thanks for linking!
Even still, do you think researchers will want to take the time to learn all of that when PyTorch gives them no real reason to switch? Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.
this definitely limits its generality relative to jax, which makes it less than ideal for anything other than 'typical' deep neural networks
this is especially true when the research in question is related to things like physics or combining physical models and machine learning, which imho is very interesting. those are use cases that pytorch just isn't good at.
> Every day spent learning JAX is another day spent not reviewing literature, writing papers, or developing new models.
Every day spent learning JAX is also another day spent not trying to fit a round peg into a square hole of other libraries. I made the leap when I was doing things that were painful in pytorch. In terms of time, I think I came out ahead.
Not everything is a nail, and pytorch is better for some things, and jax is better for others. "Every day spent learning the screwdriver is a day spent not using your hammer."
Getting started with JAX is just knowing Python and adding `grad`, `jit` and `vmap` to the mix; it takes about 5 minutes to get going.
To me this is the real power of JAX: it can be viewed as a few functions that make it easy to take any Python code you've written and work with its derivatives. This gives it tremendous flexibility in helping you solve problems.
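Roughly what those five minutes look like, as a sketch. The function `f` is an arbitrary example; the three transforms are composed on top of plain Python.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Ordinary Python arithmetic on a scalar.
    return x ** 3 + 2.0 * x

df = jax.grad(f)                    # derivative of f for a scalar input
df_batched = jax.jit(jax.vmap(df))  # vectorize over a batch, then compile

xs = jnp.array([0.0, 1.0, 2.0])
slopes = df_batched(xs)  # 3*x**2 + 2 at each point: [2., 5., 14.]
```

The transforms nest freely, so `jax.grad(jax.grad(f))` would give the second derivative in the same way.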
As an example, I mostly do statistical work with it, rather than NN-focused work. It took probably a few minutes to implement a GLM with custom priors over all the parameters, and then use the Hessian for the Laplace approximation of parameter uncertainty. The proper way to solve this would have been using PyMC, but this worked well enough for me, and building the model from scratch in JAX took less time than refreshing the PyMC API for me.
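A rough sketch of that kind of workflow, not the commenter's actual code. It assumes a logistic-regression GLM, a standard normal prior on each weight, synthetic data, and plain gradient descent to find the MAP estimate; all of those are choices made here for illustration.

```python
import jax
import jax.numpy as jnp

# Synthetic data, purely illustrative.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (200, 2))
true_w = jnp.array([1.5, -0.5])
y = jax.random.bernoulli(k2, jax.nn.sigmoid(X @ true_w)).astype(jnp.float32)

def neg_log_posterior(w):
    logits = X @ w
    # Bernoulli log-likelihood with a logit link.
    log_lik = jnp.sum(y * logits - jnp.logaddexp(0.0, logits))
    # Assumed N(0, 1) prior on each weight, up to an additive constant.
    log_prior = -0.5 * jnp.sum(w ** 2)
    return -(log_lik + log_prior)

# MAP estimate by gradient descent (step size and iterations are arbitrary).
grad_fn = jax.jit(jax.grad(neg_log_posterior))
w_map = jnp.zeros(2)
for _ in range(500):
    w_map = w_map - 0.01 * grad_fn(w_map)

# Laplace approximation: Gaussian centred at the MAP estimate, with
# covariance equal to the inverse Hessian of the negative log posterior.
H = jax.hessian(neg_log_posterior)(w_map)
cov = jnp.linalg.inv(H)
std = jnp.sqrt(jnp.diag(cov))  # approximate posterior standard deviations
```

Swapping in a different prior or likelihood only means editing `neg_log_posterior`; the gradient and Hessian follow automatically.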
Having read some Jax high-performance code, I do like Jax, but it does feel a bit too abstract and low level sometimes. Maybe there aren’t good coding conventions or performance trumped them? Definitely needs improvement on error messages, as well.
For example, when a long chain of pmaps, each with some sort of device partitioning logic, fails to JIT compile, it is extremely hard to understand why. I basically had to binary search the code until the compile errors disappeared.