by 6gvONxR4sf7o
1652 days ago
> but only provides low-level linear algebra abstractions.

Just to make sure people aren't scared off by this: JAX provides a lot more than just low-level linear algebra. It has some fundamental NN functions in its lax submodule, and the NumPy API itself goes way, way past linear algebra. You get NumPy plus autodiff, plus automatic vectorization, plus automatic parallelization, plus some core NN functions, plus a bunch more stuff. JAX plus Optax (for common optimizers and easily making new optimizers) is plenty for a lot of NN needs. After that, the other libraries are really just useful for initialization and state management (which is still very useful; I use Haiku myself).
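To make that concrete, here's a rough sketch of how far plain JAX gets you. The model, data, and learning rate are made up for illustration, and I've hand-rolled the gradient descent update instead of using Optax so the snippet only depends on JAX itself.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical toy model: a single linear layer acting on ONE example.
def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def loss(params, xs, ys):
    # Automatic vectorization: map predict over the batch axis of xs.
    preds = jax.vmap(predict, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)

@jax.jit  # compile the whole update step
def step(params, xs, ys):
    # Autodiff: gradient of the loss with respect to params.
    grads = jax.grad(loss)(params, xs, ys)
    # Plain gradient descent with an assumed learning rate of 0.1
    # (hand-rolled here; Optax would normally handle this part).
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Made-up regression data.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (64, 3))
ys = xs @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = (jnp.zeros(3), jnp.array(0.0))
initial_loss = loss(params, xs, ys)
for _ in range(200):
    params = step(params, xs, ys)
final_loss = loss(params, xs, ys)

# One of the NN primitives in lax: a 2D convolution (NCHW input, OIHW kernel).
image = jnp.ones((1, 1, 5, 5))
kernel = jnp.ones((1, 1, 3, 3))
conv_out = lax.conv(image, kernel, window_strides=(1, 1), padding="SAME")
```

Nothing here needs a higher-level NN library; what those add is mostly nicer handling of parameter initialization and state.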