Hacker News
by 6gvONxR4sf7o 1652 days ago
> but only provides low-level linear algebra abstractions.

Just to make sure people aren't scared off by this: jax provides a lot more than just low-level linear algebra. It has some fundamental NN functions in its jax.nn and lax submodules, and the numpy API itself goes way, way past linear algebra. Numpy plus autodiff, plus automatic vectorization, plus automatic parallelization, plus some core NN functions, plus a bunch more stuff.
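To make the "numpy plus autodiff plus NN functions" point concrete, here is a minimal sketch. The tiny two-layer network, its shapes and the random data are illustrative assumptions; jax.numpy, jax.grad, jax.jit, jax.nn.relu and lax.conv_general_dilated are the actual JAX functions used.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical two-layer network written with the NumPy-style API,
# using an activation function from jax.nn.
def predict(params, x):
    w1, b1, w2, b2 = params
    h = jax.nn.relu(jnp.dot(x, w1) + b1)
    return jnp.dot(h, w2) + b2

def loss(params, x, y):
    # Mean squared error, written exactly as you would in NumPy.
    return jnp.mean((predict(params, x) - y) ** 2)

# Autodiff: grad differentiates w.r.t. the first argument (params);
# jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (3, 8)) * 0.1,
    jnp.zeros(8),
    jax.random.normal(k2, (8, 1)) * 0.1,
    jnp.zeros(1),
)
x = jax.random.normal(k3, (16, 3))
y = jnp.sum(x, axis=1, keepdims=True)

grads = grad_loss(params, x, y)  # same tuple structure as params

# lax exposes lower-level NN primitives, e.g. a 1-D convolution.
signal = jnp.arange(10.0).reshape(1, 1, 10)  # (batch, channels, width)
kernel = jnp.ones((1, 1, 3)) / 3.0           # (out_ch, in_ch, width)
smoothed = lax.conv_general_dilated(signal, kernel, window_strides=(1,), padding="VALID")
```

The convolution here is just a 3-wide moving average, so smoothed holds the values 1 through 8.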

Jax plus optax (for common optimizers and easily making new optimizers) is plenty sufficient for a lot of NN needs. After that, the other libraries are really just useful for initialization and state management (which is still very useful; I use haiku myself).


Would you mind commenting on Haiku vs Flax? I'm partial to Haiku because I'm a fan of Sonnet/DeepMind, but I've not looked into Flax much!
I haven't used Flax, but it seems more like PyTorch. I like Haiku because it's relatively minimal: the simplest transform just gives you init and apply functions, and that's all. I like that.
Got it. Awesome, thanks for the info. I'll check them both out this week.
I have a preference for Flax: you basically get PyTorch, but streamlined thanks to the JAX foundation.
Awesome, thanks for the suggestion. I'll check out both for sure!
Indeed! Sorry, I was thinking more about the layers, but of course JAX is way more than numpy on steroids. (Although it is also that: https://dionhaefner.github.io/2021/12/supercharged-high-reso...). JAX has a very nice vmap for easy vectorization on SIMD accelerators, and pmap even allows cross-device parallelization with a single line, which is just beautiful!
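A small sketch of vmap and pmap using only JAX itself; the functions being mapped are made up for illustration. On a single-CPU machine jax.local_device_count() is 1, so pmap runs on one device, but the same single line is what spreads work across several accelerators.

```python
import jax
import jax.numpy as jnp

# A function written for a single example: dot product of two vectors.
def single_dot(a, b):
    return jnp.dot(a, b)

# vmap adds a batch dimension without rewriting the function;
# by default it maps over axis 0 of every argument.
batched_dot = jax.vmap(single_dot)

a = jnp.arange(12.0).reshape(4, 3)
b = jnp.ones((4, 3))
print(batched_dot(a, b))  # one dot product per row: 3, 12, 21, 30

# pmap maps over the leading axis across devices; that axis must not be
# larger than the number of available devices.
n = jax.local_device_count()
xs = jnp.arange(n * 3.0).reshape(n, 3)
per_device_sum = jax.pmap(lambda v: jnp.sum(v ** 2))(xs)
```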