Hacker News
by patrickkidger 1212 days ago
It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations in their own versions to support their own abstractions? (E.g. `haiku.grad`.)

If so, then allow me to make my usual advert here for Equinox:

https://github.com/patrick-kidger/equinox

Equinox works directly with JAX's native transformations. (There's no `equinox.vmap`, for example.)

On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable, in that it has a forward pass, and a pytree, in that it records its parameters in its tree structure.
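The "callable that is also a pytree" idea can be illustrated with plain JAX. The sketch below hand-rolls one with `jax.tree_util.register_pytree_node_class`; `Affine` and `apply_twice` are hypothetical names, and this is not how Equinox is implemented internally (`eqx.Module` takes care of the pytree registration for you). The point it shows: a higher-order function receives the callable, and `jax.grad` can still differentiate with respect to the parameters stored inside it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Affine:
    """Hypothetical layer: callable (has a forward pass) and a pytree (parameters are leaves)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # Forward pass.
        return jnp.tanh(self.weight @ x + self.bias)

    def tree_flatten(self):
        # The parameters are the pytree children; no static auxiliary data is needed.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def apply_twice(f, x):
    # A higher-order function: it takes a callable and calls it twice.
    return f(f(x))


layer = Affine(jnp.eye(2) * 0.5, jnp.zeros(2))
x = jnp.array([1.0, -1.0])

# Because `layer` is a pytree, jax.grad differentiates with respect to the callable itself.
grads = jax.grad(lambda f: jnp.sum(apply_twice(f, x)))(layer)

# The parameters are visible to ordinary tree utilities.
leaves = jax.tree_util.tree_leaves(layer)
```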

1 comment

As a matter of fact, you’re preaching to the choir! Equinox is my go-to library for neural-network work in JAX!