Seems to have missed the existence of jax.jit, which basically constructs an XLA program (call it a graph if you like) from your Python function which can then be optimized.
The author gives that quote (from the JAX documentation) but does not seem to have internalized it, as his conclusion says:
> This is the niche that Theano (or rather, Theano-PyMC/Aesara) fills that other contemporary tensor computation libraries do not: the promise is that if you take the time to specify your computation up front and all at once, Theano can optimize the living daylight out of your computation - whether by graph manipulation, efficient compilation or something else entirely - and that this is something you would only need to do once.
That is exactly what JAX does. There is a computational graph in JAX (it's encoded in XLA and specified with their NumPy-like syntax); it is built once, optimized, and then runs on the GPU.
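To make the "built once, optimized, then run" point concrete, here is a minimal sketch. The function `predict`, the counter `trace_count`, and the array shapes are made up for illustration; only `jax.jit` and `jax.make_jaxpr` are real JAX calls. The counter is a Python side effect, so it only ticks while JAX is tracing the function, not when the compiled program runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: how often Python executes the body


def predict(w, x):
    global trace_count
    trace_count += 1              # Python side effect, only runs during tracing
    return jnp.tanh(x @ w).sum()


fast_predict = jax.jit(predict)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))

first = fast_predict(w, x)        # traces, compiles via XLA, then runs
second = fast_predict(w, x)       # same shapes/dtypes: reuses the compiled program
traces_after_two_calls = trace_count

# The traced program (a jaxpr) is what gets lowered to XLA; printing it
# shows the "graph" that was recorded from the Python function.
graph = jax.make_jaxpr(predict)(w, x)
print(graph)
```

Calling `fast_predict` with different shapes or dtypes would trigger a fresh trace and compile, so "once" really means once per input signature.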
Not even close. jax.jit lets you compute almost anything using lax.fori_loop, lax.cond and other lax and jax constructs. PyTorch JIT does not allow that; it's just extra optimization for static PyTorch functions.
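As a rough illustration of data-dependent control flow inside a jitted function, the sketch below combines `lax.fori_loop` and `lax.cond`. The function `collatz_steps` and its arguments are invented for the example; both the starting value and the iteration count are traced runtime values, not Python constants baked into the program.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def collatz_steps(n, max_iters):
    """Apply `max_iters` Collatz updates to `n` (hypothetical example)."""

    def body(i, val):
        # lax.cond selects a branch from a value only known at run time.
        return lax.cond(
            val % 2 == 0,
            lambda v: v // 2,       # even: halve
            lambda v: 3 * v + 1,    # odd: 3n + 1
            val,
        )

    # The upper bound is a traced value, so the loop stays a loop in the
    # compiled program instead of being unrolled by Python.
    return lax.fori_loop(0, max_iters, body, n)


result = collatz_steps(jnp.int32(6), 3)   # 6 -> 3 -> 10 -> 5
print(result)
```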
JAX autodiff will work on almost any jitted fn. The control-flow limitation is that there is no reverse-mode differentiation through lax.while_loop (or lax.fori_loop with a non-static trip count), since there's a statically unknowable trip count through the loop body. Much looping code can be handled differentiably using a "scan" though.
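A small sketch of the "scan" route, assuming a toy compounding function (`final_value`, `rate`, and `steps` are names made up here). Because `lax.scan` runs for a fixed, statically known number of steps, `jax.grad` can differentiate through it, and the result can be jitted as usual.

```python
import jax
import jax.numpy as jnp
from jax import lax


def final_value(rate, steps=5):
    """Compound 1.0 by (1 + rate) for a fixed number of steps."""

    def step(carry, _):
        new_carry = carry * (1.0 + rate)
        return new_carry, new_carry     # (next carry, per-step output)

    # length is a static Python int, so the trip count is known up front.
    final, _history = lax.scan(step, jnp.ones(()), xs=None, length=steps)
    return final


# d/d(rate) of (1 + rate)**5 is 5 * (1 + rate)**4
grad_fn = jax.jit(jax.grad(final_value))
g = grad_fn(0.1)
print(g)
```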
> But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels...