Like all really good ideas, this one seems "obvious" in hindsight. I mean that as a compliment: It would never have occurred to me that transforming code into continuation-passing style would allow for automatic differentiation through all dynamic control-flow structures, by leveraging the function-call stack, thus eliminating the need for some kind of "tape" data structure, e.g., as in PyTorch.
My question is about the ongoing work to provide a JIT compiler for Python code. Do you expect it will provide full support for the entire PyTorch and/or TensorFlow APIs?
Thanks! Yes, it wasn't obvious at all when we started looking at AD either.
Lantern supports a good deal of PyTorch (via Snek, our Python front-end similar to AutoGraph) and can also read ONNX. Full feature parity is not our main goal--so far, supported features have been driven mostly by what is required for certain interesting models.
What you call "delimited continuations" sounds a bit like how AD works in Julia's Flux package (and maybe elsewhere): During the forward calculation, a chain of functions is constructed, whose evaluation is the backward pass. This is done by overloading the original * to return both x * y and a closure Δ -> (Δ * y, x * Δ).
Does that sound right? Are these indeed similar, or have I misunderstood something? I have never read Scala, but if I squint I can make out that your overloading of * is doing something similar.
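For concreteness, here is a minimal Python sketch of the closure-returning style described in this comment. The names mul and pullback are made up for illustration and are not Flux's actual API; the only thing taken from the comment is the shape of the closure, Δ -> (Δ * y, x * Δ).

```python
# Closure-returning reverse-mode AD, sketched for scalars.
# `mul` and `pullback` are hypothetical names, not any library's API.
def mul(x, y):
    """Return x * y plus a closure mapping an output gradient
    to the gradients with respect to x and y."""
    def pullback(delta):
        return (delta * y, x * delta)
    return x * y, pullback

# Forward pass: z = (a * b) * a, keeping each returned closure.
a, b = 3.0, 4.0
t, back_t = mul(a, b)   # t = a * b
z, back_z = mul(t, a)   # z = t * a

# Backward pass: call the closures in reverse order and sum the
# contributions that reach `a` along both paths.
dt, da_direct = back_z(1.0)
da_via_t, db = back_t(dt)
da = da_direct + da_via_t
```

The closures back_t and back_z outlive the calls that created them, so whatever they capture has to stay reachable until the backward pass runs.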
It's different -- instead of returning a closure, we take a closure as an additional parameter (check the paper for details). This means that the call stack stays intact and all intermediate values can be stack-allocated.
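A rough Python sketch of that difference, with the continuation written out by hand. This is not Lantern's code: Num, mul, grad and cube are illustrative names, and as I understand the paper, the shift/reset operators are what spare you from writing the continuations manually in Scala.

```python
# Continuation-taking reverse-mode AD, sketched for scalars.
# All names here are hypothetical, chosen for illustration only.
class Num:
    def __init__(self, x):
        self.x = x     # primal value
        self.d = 0.0   # accumulated gradient

def mul(a, b, k):
    y = Num(a.x * b.x)   # primal computation
    k(y)                 # run the rest of the forward pass
    # k has returned, so y.d now holds the gradient of the final
    # result with respect to y; propagate it to the inputs.
    a.d += b.x * y.d
    b.d += a.x * y.d

def grad(f, x0):
    x = Num(x0)
    def seed(y):
        y.d = 1.0        # d(result)/d(result) = 1
    f(x, seed)
    return x.d

def cube(x, k):
    # x * x * x, with the "rest of the program" passed explicitly
    mul(x, x, lambda t: mul(t, x, k))

g = grad(cube, 2.0)      # derivative of x**3 at x = 2
```

Each call to mul is still on the call stack while its continuation runs, so y, a and b are ordinary locals when the adjoint lines execute. The backward pass is simply the calls returning in reverse order.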
It's specific to CPS form: every basic statement is a full function call, which moves to the next statement by calling another function (i.e. the continuation). Normally this would just be a very quick way to get yourself a stack overflow, so typical CPS-form compilers optimise (or require) the tail-call case, where you can pop the stack frame for the last instruction before moving on to the next. But for AD, keeping those frames around isn't wasted: you just re-use their values in the backward pass.
I doubt it's better than just allocating your own stack on the heap, especially when you have something like a differential equation solver with millions of statements, but it's still a neat trick.
I agree that the performance benefits of such stack allocation (over heap allocation) aren't quite clear in practice.
I feel the bigger win of delimited-continuation and closure-based AD approaches is that they can model the control flow of reverse-mode AD without AD-specific code transformations. Delimited continuations are especially great at making things modular: each differentiable function performs its primal computation, calls the callback with the primal result, then performs its adjoint computation.
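To illustrate the control-flow point, here is a small Python sketch in which the number of multiplications depends on a runtime value, yet no tape is recorded. Every name in it (Num, mul, power, grad_power) is made up for the example and none of it is taken from Lantern.

```python
# Primal, callback, adjoint: the pattern applied to a recursion
# whose depth is only known at run time. Hypothetical names throughout.
class Num:
    def __init__(self, x):
        self.x = x     # primal value
        self.d = 0.0   # accumulated gradient

def mul(a, b, k):
    y = Num(a.x * b.x)   # 1. primal computation
    k(y)                 # 2. call the callback with the primal result
    a.d += b.x * y.d     # 3. adjoint computation
    b.d += a.x * y.d

def power(x, n, k):
    # x ** n for n >= 1, written as a data-dependent recursion
    if n == 1:
        k(x)
    else:
        power(x, n - 1, lambda p: mul(p, x, k))

def grad_power(x0, n):
    x = Num(x0)
    def seed(y):
        y.d = 1.0
    power(x, n, seed)
    return x.d

g5 = grad_power(2.0, 5)   # derivative of x**5 at x = 2
```

The function power contains no AD-specific code beyond calling mul with a continuation. The flip side is the stack depth raised earlier in the thread: in plain Python, a large n would run into the interpreter's recursion limit.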
"Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator" https://www.cs.purdue.edu/homes/rompf/papers/wang-preprint20...
The Lantern framework is available here: https://github.com/feiwang3311/Lantern