Hacker News
by cjv 1787 days ago
...doesn't the JAX example just need that argument marked as static via static_argnums, and then it will work?
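A minimal sketch of what the question proposes, assuming a recent JAX install. `branchy` is a hypothetical toy function, not the post's example. Marking the branch-deciding argument static lets the Python `if` run at trace time, and `jax.jit` caches one trace per distinct static value.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function (not the post's example): a plain Python `if`
# on the second argument.
def branchy(x, flag):
    if flag > 0:  # needs a concrete bool while tracing
        return x * 2.0
    return x - 1.0

# Without static_argnums, `flag` is a tracer and the Python `if` fails.
try:
    jax.jit(branchy)(3.0, 1)
    plain_jit_failed = False
except jax.errors.ConcretizationTypeError:
    plain_jit_failed = True

# Marking `flag` (position 1) static turns it into a trace-time constant:
# the `if` runs in Python and only the selected branch is traced.
branchy_static = jax.jit(branchy, static_argnums=1)

print(branchy_static(3.0, 1))   # traces and compiles the `x * 2.0` branch
print(branchy_static(3.0, -1))  # new static value: retraces, `x - 1.0` branch
```

Static arguments must be hashable Python values known before the call, so this only helps when the branch-deciding value is not itself computed from traced arrays.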

static_argnums is really just a way to supply a few extra assumptions so that JAX can build quasi-static code even when the function uses dynamic constructs. In this example it would force the tracer down only one of the two branches (whichever one the static argument's value selects). That generates incorrect code for input values that should have taken the other branch, which is why the real solution, `lax.cond`, always traces and always computes both branches, as mentioned in the post. And if the computation is genuinely not quasi-static, there's no good choice for a static argnum at all; see the factorial example.
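A minimal sketch of the traced alternatives, assuming a recent JAX install. `branchy_cond` and `factorial` are hypothetical names; the factorial is written in the spirit of the post's example, not copied from it. `lax.cond` traces both branch functions and accepts a traced predicate, and `lax.while_loop` expresses a loop whose trip count depends on the runtime value of `n`, so one compiled function covers every `n`, including values computed from traced data under `vmap`, where marking `n` static is not an option.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical two-way branch written with lax.cond: both branch functions
# are traced, and the predicate may be a traced (runtime) value.
@jax.jit
def branchy_cond(x, flag):
    return lax.cond(flag > 0,
                    lambda v: v * 2.0,  # taken when flag > 0
                    lambda v: v - 1.0,  # taken otherwise
                    x)

# Hypothetical factorial: the loop's trip count depends on the value of n,
# so the loop is expressed with lax.while_loop instead of Python control flow.
@jax.jit
def factorial(n):
    def cond_fun(state):
        i, _ = state
        return i <= n
    def body_fun(state):
        i, acc = state
        return i + 1, acc * i
    init = (jnp.int32(1), jnp.int32(1))  # (counter, running product)
    _, result = lax.while_loop(cond_fun, body_fun, init)
    return result

# One compiled function serves every n, even n computed from traced data.
print(factorial(5))
print(jax.vmap(factorial)(jnp.arange(6)))

# Under vmap with a batched predicate, lax.cond is evaluated as a select,
# so both branches are computed for every element.
print(jax.vmap(branchy_cond, in_axes=(None, 0))(3.0, jnp.array([1, -1])))
```

The design trade-off is that the traced versions compile once but give up Python control flow, while the static_argnums version keeps Python control flow but needs a fresh trace for each static value.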
Ah, thanks for the explanation.