by cygaril 2014 days ago
Seems to have missed the existence of jax.jit, which basically constructs an XLA program (call it a graph if you like) from your Python function which can then be optimized.
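For anyone who has not used it, here is a minimal sketch of what that looks like. The function `normalize` and the input are made up for illustration; `jax.jit` and `jax.make_jaxpr` are the real entry points.

```python
import jax
import jax.numpy as jnp

def normalize(x):
    # Ordinary NumPy-style code; the only JAX-specific part is using jnp.
    return (x - jnp.mean(x)) / jnp.std(x)

# jax.jit traces the function for a given input shape/dtype and
# hands the traced program to XLA for compilation.
fast_normalize = jax.jit(normalize)

x = jnp.arange(8.0)
y = fast_normalize(x)

# jax.make_jaxpr exposes the traced intermediate program as text,
# which is the "graph" being discussed here.
jaxpr_text = str(jax.make_jaxpr(normalize)(x))
print(jaxpr_text)
```

The printed jaxpr is a small functional program over array primitives; it is built from the Python function rather than written out by hand.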

In the section titled JAX:

> But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels...

The author gives that quote (from the JAX documentation) but does not seem to have internalized it, as his conclusion says:

> This is the niche that Theano (or rather, Theano-PyMC/Aesara) fills that other contemporary tensor computation libraries do not: the promise is that if you take the time to specify your computation up front and all at once, Theano can optimize the living daylight out of your computation - whether by graph manipulation, efficient compilation or something else entirely - and that this is something you would only need to do once.

That is exactly what JAX does. There is a computational graph in JAX (it's encoded in XLA and specified with their NumPy-like syntax); it is built once, optimized, and then runs on the GPU.

TorchScript JIT (torch.jit.script) is similar for PyTorch.
Not even close. jax.jit allows you to compute almost anything using lax.fori_loop, lax.cond and other lax and jax constructs. PyTorch JIT does not allow that; it's just extra optimization for static PyTorch functions.
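A rough sketch of the kind of control flow meant here, assuming `lax.fori_loop` and `lax.cond` are the constructs in question. The Collatz function is a hypothetical example chosen only because it needs both a loop and a data-dependent branch.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def collatz_steps(n, max_steps):
    # Apply the Collatz rule max_steps times, all inside one compiled program.
    def body(i, val):
        # lax.cond picks a branch from a traced (runtime) predicate.
        return lax.cond(val % 2 == 0,
                        lambda v: v // 2,      # even: halve
                        lambda v: 3 * v + 1,   # odd: 3n + 1
                        val)
    # The trip count max_steps is a traced value here, not a Python constant.
    return lax.fori_loop(0, max_steps, body, n)

result = collatz_steps(jnp.int32(6), 8)
```

Both the branch and the loop bound depend on runtime values, yet the whole thing compiles to a single XLA program.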
No autodiff for most of these though.
JAX autograd will work on almost any jitted fn. The control-flow limitation is that reverse-mode autodiff is not supported through lax.while_loop (or lax.fori_loop with a non-static trip count), since the trip count through the loop body is statically unknowable. Much looping code can be handled differentiably using lax.scan, though.
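As a small illustration of the scan route, here is a sketch that differentiates through a fixed-length loop. The function `compound` and its arguments are invented for the example; the loop computes (1 + rate) ** steps, so the gradient can be checked by hand against steps * (1 + rate) ** (steps - 1).

```python
import jax
import jax.numpy as jnp
from jax import lax

def compound(rate, steps=5):
    # Repeatedly multiply a running value by (1 + rate).
    def step(carry, _):
        new = carry * (1.0 + rate)
        return new, new  # (next carry, per-step output)
    # length is a static Python int, so the trip count is known at trace time.
    final, _ = lax.scan(step, jnp.ones_like(rate), None, length=steps)
    return final

# Reverse-mode gradient of the scanned loop, then jit-compiled.
grad_compound = jax.jit(jax.grad(compound))
g = grad_compound(jnp.float32(0.1))
```

Because the number of iterations is fixed up front, JAX can store what it needs from the forward pass and run the loop backwards for the gradient.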