by patrickkidger
1212 days ago
It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations in their own versions (e.g. `haiku.grad`)? If so, allow me to make my usual advert for Equinox: https://github.com/patrick-kidger/equinox

Equinox works with JAX's native transformations directly; there's no `equinox.vmap`, for example.

On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. For example, a neural network is both a callable, in that it has a forward pass, and a pytree, in that it stores its parameters as the leaves of its tree structure.
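To make the "callable that is also a pytree" idea concrete, here is a minimal hand-rolled sketch in plain JAX. It is not Equinox's implementation: the `Linear` class and `loss` function are hypothetical, and the pytree registration is done manually with `jax.tree_util.register_pytree_node_class` (subclassing `equinox.Module` takes care of that for you). Because the model is a pytree, it can be passed straight to the stock `jax.grad` and `jax.vmap`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer: a callable (forward pass) that is also a pytree (parameters)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    # Callable: the forward pass.
    def __call__(self, x):
        return self.weight @ x + self.bias

    # Pytree: the parameters are the leaves; there is no static auxiliary data.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # Mean squared error of the model's prediction.
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 1.0])

# Native jax.grad: the gradient w.r.t. `model` is another Linear whose
# leaves hold d(loss)/d(weight) and d(loss)/d(bias).
grads = jax.grad(loss)(model, x, y)

# Native jax.vmap over a batch of inputs; the model is broadcast (in_axes=None).
batch = jnp.stack([x, 2 * x])
outputs = jax.vmap(lambda m, xi: m(xi), in_axes=(None, 0))(model, batch)
```

No library-specific `grad` or `vmap` wrapper is involved: the same object is called for the forward pass and differentiated as a tree of arrays.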