No worries! I didn't mean it as a correction so much as just a discussion; I'm sure it's true that other autodiff systems have very sophisticated automatic rematerialization (like https://openreview.net/forum?id=BkYYXJ9i-). I'm hoping that as users push JAX into new applications, especially in simulation and scientific computing, we'll learn a lot and be able to improve!
There's also "cross-country optimization" (https://www-sop.inria.fr/tropics/slides/EdfCea05.pdf) for mixing some forward-mode into reverse-mode to improve memory efficiency. Analogously to jax.checkpoint, we've only experimented with exposing that manually (in jax.jarrett, named because of https://arxiv.org/abs/1810.08297), and even then only for a special case. There's a lot to learn about, experiment with, and build!
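To make the two manual knobs concrete, here's a rough sketch. The first half uses jax.checkpoint to recompute a layer's intermediates on the backward pass instead of storing them. The second half illustrates the forward-inside-reverse idea for the elementwise special case using jax.custom_vjp; it is not the jax.jarrett implementation, and the names `layer` and `elementwise_fwd_in_rev` are made up for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def layer(x):
    # Hypothetical elementwise "layer"; plain reverse-mode would keep
    # residuals for each of the three ops for the backward pass.
    return jnp.sin(jnp.cos(jnp.sin(x)))


def loss_plain(x):
    return jnp.sum(layer(layer(x)))


def loss_remat(x):
    # jax.checkpoint: keep only each layer's input and recompute the
    # layer's intermediates when the backward pass needs them.
    remat_layer = jax.checkpoint(layer)
    return jnp.sum(remat_layer(remat_layer(x)))


# Forward-inside-reverse sketch, assuming f is elementwise so that its
# Jacobian is diagonal. We compute that diagonal with one forward-mode
# pass and store it as the only residual.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def elementwise_fwd_in_rev(f, x):
    return f(x)


def _fwd(f, x):
    # A JVP with an all-ones tangent yields the diagonal of the Jacobian
    # (only valid because f is assumed elementwise).
    y, dy = jax.jvp(f, (x,), (jnp.ones_like(x),))
    return y, dy


def _bwd(f, dy, ct):
    # The VJP of an elementwise function is a pointwise multiply.
    return (ct * dy,)


elementwise_fwd_in_rev.defvjp(_fwd, _bwd)


def loss_mixed(x):
    y = elementwise_fwd_in_rev(layer, x)
    return jnp.sum(elementwise_fwd_in_rev(layer, y))


x = jnp.linspace(0.0, 1.0, 8)
g_plain = jax.grad(loss_plain)(x)
g_remat = jax.grad(loss_remat)(x)
g_mixed = jax.grad(loss_mixed)(x)

# All three strategies should agree up to floating-point error.
grads_match = bool(
    jnp.allclose(g_plain, g_remat, atol=1e-5)
    and jnp.allclose(g_plain, g_mixed, atol=1e-5)
)
```

The gradients are the same in all three cases; what changes is what gets saved between the forward and backward passes. In the mixed version each wrapped call stores a single array (the derivative) rather than one residual per primitive, which is where the memory saving would come from in this special case.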