by mattjjatgoogle 2288 days ago
Rematerialization in autodiff is super interesting! XLA does rematerialization optimizations, so you get those automatically under jax.jit. There's also the jax.checkpoint decorator (https://github.com/google/jax/pull/1749) which lets you control reverse-mode checkpointing yourself; you can use it recursively to implement sophisticated checkpointing strategies (see Example 5 in that PR, which is the classic strategy for getting memory cost to scale like log(N) for iteration count N but requiring log(N) times as much computational work). It'd be interesting to experiment with heuristics for deploying those strategies automatically (e.g. given a program in JAX's jaxpr IR) but one of JAX's core philosophies is to keep things explicit and give users control through composable APIs. Automatic heuristics can be built on top.

Another goal is to make JAX a great system for playing with things like this!
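A minimal sketch of recursive checkpointing with jax.checkpoint. It assumes the classic binary bisection scheme and is not a copy of Example 5 from the PR. The `step` function and its constants are hypothetical stand-ins for one iteration of a simulation, and `run_bisect` is an illustrative name; only `jax.checkpoint` and `jax.grad` are real JAX APIs here. Each half of the iteration range is wrapped in `jax.checkpoint`, so the backward pass keeps roughly log2(N) segment boundaries alive and recomputes each step up to roughly log2(N) times.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for one iteration of a simulation.
    return jnp.sin(x) * 1.01 + 0.1


def run(x, n):
    # Plain loop: reverse mode keeps residuals for all n steps alive.
    for _ in range(n):
        x = step(x)
    return x


def run_bisect(x, n):
    # Recursive bisection: checkpoint each half, so the backward pass stores
    # only segment boundaries along one path of the recursion and recomputes
    # the rest on demand.
    if n == 0:
        return x
    if n == 1:
        return step(x)
    half = n // 2
    first = jax.checkpoint(lambda y: run_bisect(y, half))
    second = jax.checkpoint(lambda y: run_bisect(y, n - half))
    return second(first(x))


x0 = jnp.array(0.3)
g_plain = jax.grad(lambda x: run(x, 16))(x0)
g_bisect = jax.grad(lambda x: run_bisect(x, 16))(x0)
print(g_plain, g_bisect)  # same gradient, different memory/compute trade-off
```

Both versions compute the same gradient; the checkpointed one trades extra forward recomputation for a smaller set of saved residuals.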


Thank you for the correction! I should have checked out the software before posting my incorrect surmises from the blog post. It sounds awesome!
No worries! I didn't mean it as a correction so much as just a discussion; I'm sure it's true that other autodiff systems have very sophisticated automatic remat (like https://openreview.net/forum?id=BkYYXJ9i-). I'm hoping as users push JAX on new applications, especially in simulation and scientific computing, we'll learn a lot and be able to improve!

There's also "cross-country optimization" (https://www-sop.inria.fr/tropics/slides/EdfCea05.pdf) for mixing some forward-mode into reverse-mode to improve memory efficiency. Analogously to jax.checkpoint, we've only experimented with exposing that manually (in jax.jarrett, named because of https://arxiv.org/abs/1810.08297), and even then only for a special case. There's a lot to learn about, experiment with, and build!
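A minimal sketch of that special case, assuming `f` is elementwise (so its Jacobian is diagonal). This is not the actual implementation of `jax.jarrett`; `elementwise_fwd_in_rev` is a hypothetical name built on the real `jax.custom_vjp` and `jax.jvp` APIs. One forward-mode pass during the primal computation yields the Jacobian diagonal, and that single array is the only residual kept for the backward pass, instead of every intermediate inside `f`.

```python
import jax
import jax.numpy as jnp


def elementwise_fwd_in_rev(f):
    """Hypothetical helper: wrap an elementwise f so reverse mode saves only
    the diagonal of its Jacobian, computed eagerly with forward mode."""

    @jax.custom_vjp
    def g(x):
        return f(x)

    def g_fwd(x):
        # For elementwise f, a JVP with an all-ones tangent returns exactly
        # the Jacobian diagonal alongside the primal output.
        y, diag = jax.jvp(f, (x,), (jnp.ones_like(x),))
        return y, diag  # residual: one array the size of x

    def g_bwd(diag, ct):
        # The VJP of a diagonal Jacobian is an elementwise product.
        return (diag * ct,)

    g.defvjp(g_fwd, g_bwd)
    return g


def f(x):
    # Elementwise, with several intermediates reverse mode would normally store.
    return jnp.tanh(jnp.sin(x) ** 2 + jnp.exp(-x))


x = jnp.linspace(-1.0, 1.0, 5)
loss_plain = lambda x: jnp.sum(f(x) * x)
loss_mixed = lambda x: jnp.sum(elementwise_fwd_in_rev(f)(x) * x)
print(jax.grad(loss_plain)(x))
print(jax.grad(loss_mixed)(x))
```

The trick only holds when the Jacobian really is diagonal; for a general `f`, a single all-ones JVP would sum Jacobian columns rather than recover the matrix.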