by 6gvONxR4sf7o
1216 days ago
Thanks! One last thing, since I have your ear. The function transformation aspects of JAX seem to make their way into downstream libraries like Haiku, resulting in a lot of "magic" that can be difficult to examine and debug. Are there any utils you made to make JAX's own transformations more transparent, which you think might be helpful for third-party transformations? Higher-order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.
If so, then allow me to make my usual advert here for Equinox:
https://github.com/patrick-kidger/equinox
This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)
On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
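To make the "callable that is also a pytree" idea concrete, here is a minimal sketch in plain JAX. It is not Equinox's actual implementation: the `Linear` class and `loss` function are hypothetical names made up for illustration, and the pytree registration is done by hand with `jax.tree_util.register_pytree_node_class`. The point is only that once the parameters are the leaves of the object's tree structure, JAX's native `jax.grad` and `jax.vmap` can take the model as an ordinary argument.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer: callable, with its parameters as pytree leaves."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # The forward pass: this is what makes the object a callable.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Leaves (children) are the parameters; no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from its leaves.
        return cls(*children)


def loss(model, x, y):
    # The model is passed in like any other pytree argument.
    return jnp.sum((model(x) - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.arange(3.0)
y = jnp.array([1.0, 2.0])

# Differentiate with respect to the model: the result is another Linear
# whose leaves hold the gradients of the corresponding parameters.
grads = jax.grad(loss)(model, x, y)

# Batch over inputs only, with the model held fixed (in_axes=None).
batched = jax.vmap(lambda m, v: m(v), in_axes=(None, 0))(model, jnp.ones((4, 3)))
```

Because the gradients come back with the same structure as the model, you can inspect them with `jax.tree_util.tree_leaves(grads)` or look at `grads.weight` directly, with no hidden state involved.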