by mattjjatgoogle
2288 days ago
Rematerialization in autodiff is super interesting! XLA performs rematerialization optimizations, so you get those automatically under jax.jit. There's also the jax.checkpoint decorator (https://github.com/google/jax/pull/1749), which lets you control reverse-mode checkpointing yourself. You can apply it recursively to implement sophisticated checkpointing strategies: see Example 5 in that PR, which is the classic strategy for making memory cost scale like log(N) in the iteration count N, at the price of roughly log(N) times as much computational work. It'd be interesting to experiment with heuristics that deploy those strategies automatically (e.g. given a program in JAX's jaxpr IR), but one of JAX's core philosophies is to keep things explicit and give users control through composable APIs; automatic heuristics can be built on top of those. Another goal is to make JAX a great system for playing with things like this!
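To make the recursive idea concrete, here is an illustrative sketch (not the code from Example 5 in the PR). It assumes the program is a chain of N functions applied in order. The helper names `recursive_checkpoint` and `plain_chain` are hypothetical. The chain is split in half, the first half is wrapped in `jax.checkpoint`, and both halves recurse. On the forward pass only the input of each checkpointed half is stored, and its intermediates are recomputed when the backward pass reaches it. This keeps roughly O(log N) values live, but each function is re-run roughly O(log N) times.

```python
import jax
import jax.numpy as jnp


def recursive_checkpoint(funs):
    """Compose `funs` (applied left to right) using recursive bisection checkpointing.

    Hypothetical helper for illustration. At each level, the first half is wrapped
    in jax.checkpoint, so only its input is saved on the forward pass. Its
    intermediates are recomputed during the backward pass.
    """
    if len(funs) == 1:
        return funs[0]
    mid = len(funs) // 2
    first = recursive_checkpoint(funs[:mid])   # recursively checkpointed first half
    second = recursive_checkpoint(funs[mid:])  # recursively checkpointed second half
    first_ckpt = jax.checkpoint(first)         # don't store first-half residuals
    return lambda x: second(first_ckpt(x))


def plain_chain(funs):
    """Hypothetical baseline: the same composition with no checkpointing."""
    def run(x):
        for f in funs:
            x = f(x)
        return x
    return run


funs = [jnp.sin] * 16
x = 0.5
g_ckpt = jax.grad(recursive_checkpoint(funs))(x)
g_plain = jax.grad(plain_chain(funs))(x)
print(g_ckpt, g_plain)  # same gradient, different memory/compute trade-off
```

Checkpointing changes only what is stored versus recomputed, not the math, so both gradients should agree. The savings show up in peak memory for long chains of large arrays, not in this scalar toy.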
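An automatic heuristic would start from the program's jaxpr. The sketch below uses a hypothetical helper, `primitive_names`, to walk the jaxpr returned by `jax.make_jaxpr`. It lists each primitive and descends into sub-jaxprs, such as the body of a `jax.checkpoint`. Two caveats apply. The duck-typed check for sub-jaxprs is an assumption chosen to avoid version-specific class names. The checkpoint primitive's own name has also changed across JAX versions (e.g. `checkpoint` vs. `remat_call`).

```python
import jax
import jax.numpy as jnp


def primitive_names(jaxpr, depth=0):
    """Yield (depth, primitive name) for every equation, descending into sub-jaxprs.

    Hypothetical helper. Sub-programs appear as equation params, either as an open
    Jaxpr (has .eqns) or a ClosedJaxpr (has .jaxpr). Tuples of jaxprs (e.g. cond
    branches) are not handled in this sketch.
    """
    for eqn in jaxpr.eqns:
        yield depth, eqn.primitive.name
        for param in eqn.params.values():
            if hasattr(param, "eqns"):
                yield from primitive_names(param, depth + 1)
            elif hasattr(param, "jaxpr") and hasattr(param.jaxpr, "eqns"):
                yield from primitive_names(param.jaxpr, depth + 1)


def f(x):
    # Only the inner sin(sin(.)) block is rematerialized on the backward pass.
    y = jax.checkpoint(lambda z: jnp.sin(jnp.sin(z)))(x)
    return jnp.cos(y)


closed = jax.make_jaxpr(f)(1.0)
for depth, name in primitive_names(closed.jaxpr):
    print("  " * depth + name)
```

On recent JAX versions this should print a checkpoint equation with the two `sin` calls nested under it, followed by `cos`. A real heuristic would also need cost estimates and handling for more param shapes before deciding where to insert checkpoints.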