Hacker News
by komuher 2012 days ago
Not even close. jax.jit lets you compute almost anything using lax.fori_loop, lax.cond and the other lax and jax control-flow constructs. PyTorch's JIT doesn't allow that; it's just an extra optimization for static PyTorch functions.
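A minimal sketch of the kind of function jax.jit can compile. It assumes a recent JAX install, and `sum_even_squares` is a made-up example name. The loop bound `n` is a traced value rather than a Python constant, `lax.fori_loop` runs the loop inside the compiled function, and `lax.cond` chooses a branch from a traced boolean at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def sum_even_squares(n):
    # Hypothetical example: sum i*i for the even i in [0, n).
    # n is traced under jit, so the loop is compiled rather than unrolled.
    def body(i, acc):
        # lax.cond selects a branch based on a traced predicate.
        return lax.cond(i % 2 == 0,
                        lambda a: a + i * i,  # even i: add its square
                        lambda a: a,          # odd i: leave acc unchanged
                        acc)
    # Start from an int32 zero so both branches return the same dtype.
    return lax.fori_loop(0, n, body, jnp.int32(0))

print(sum_even_squares(10))  # 0 + 4 + 16 + 36 + 64 = 120
```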

No autodiff for most of these though.
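JAX's autodiff will work on almost any jitted function. The control-flow limitation is that reverse-mode autodiff (jax.grad) doesn't work through lax.while_loop, or through lax.fori_loop with non-static bounds, because the trip count through the loop body can't be known statically. Forward mode (jax.jvp) still works there. Much looping code can be made reverse-differentiable by writing it as a lax.scan, though.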
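A minimal sketch of that distinction, assuming a recent JAX install. The helpers `power_scan` and `power_while` are hypothetical and both compute x**5. The scan version has a fixed `length`, so jax.grad goes through it. The while_loop version supports forward mode but raises an error under jax.grad.

```python
import jax
import jax.numpy as jnp
from jax import lax

def power_scan(x, n_steps=5):
    # Hypothetical helper: multiply by x a fixed number of times.
    # `length` is a Python int, so the trip count is known at trace time.
    def step(acc, _):
        return acc * x, None  # (new carry, per-step output)
    out, _ = lax.scan(step, jnp.ones_like(x), xs=None, length=n_steps)
    return out

def power_while(x, n):
    # Hypothetical helper: same computation as a while_loop, whose trip
    # count JAX treats as data-dependent.
    def cond_fun(state):
        i, _ = state
        return i < n
    def body_fun(state):
        i, acc = state
        return i + 1, acc * x
    _, out = lax.while_loop(cond_fun, body_fun, (0, jnp.ones_like(x)))
    return out

# Reverse mode through scan: d/dx x**5 at x = 2 is 5 * 2**4 = 80.
print(jax.grad(power_scan)(2.0))  # 80.0

# Forward mode through while_loop works...
primal, tangent = jax.jvp(lambda x: power_while(x, 5), (2.0,), (1.0,))
print(primal, tangent)  # 32.0 80.0

# ...but reverse mode through while_loop raises an error.
try:
    jax.grad(lambda x: power_while(x, 5))(2.0)
except Exception as e:
    print("grad failed:", type(e).__name__)
```

Rewriting a loop as a scan works whenever the number of iterations is known when the function is traced, even if the values carried through the loop are not.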