by koningrobot
1318 days ago

It definitely works; JAX only sees the unrolled loop:

  x = 0
  x += y
  x += y
  x += y
  x += y
  x += y
  return x

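One way to check this is to look at the traced program with `jax.make_jaxpr`. Below is a rough sketch. The helpers `make_unrolled` and `num_eqns` are made up for illustration, and exact equation counts vary between JAX versions. Because the Python `for` runs while JAX traces the function, the number of recorded operations grows with the loop count:

```python
import jax
import jax.numpy as jnp

def make_unrolled(n):
    # Hypothetical helper: a plain Python loop runs at trace time,
    # so JAX records one add per iteration in the graph.
    def f(y):
        x = 0.0
        for _ in range(n):
            x += y
        return x
    return f

def num_eqns(fn, arg):
    # Hypothetical helper: number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(arg).jaxpr.eqns)

y = jnp.float32(2.0)
print(num_eqns(make_unrolled(5), y))   # grows with n...
print(num_eqns(make_unrolled(50), y))  # ...so this is much larger
print(jax.jit(make_unrolled(5))(y))    # 2.0 added five times
```
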
You might need `jax.lax.fori_loop` or something similar when you have a long loop with a complex body. Replicating a complex body many times gives you a huge computation graph and slow compilation.
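For comparison, here is a sketch of the same sum written with `jax.lax.fori_loop`. `make_rolled` and `num_eqns` are again hypothetical helpers. The body is traced once and wrapped in a single loop primitive. With static Python-int bounds, `fori_loop` is typically implemented via `scan`, so the graph stays the same size whether you loop 5 or 50 times:

```python
import jax
import jax.numpy as jnp

def make_rolled(n):
    # Hypothetical helper: fori_loop traces `body` once, so the
    # graph size does not depend on the trip count n.
    def f(y):
        def body(i, x):
            return x + y
        return jax.lax.fori_loop(0, n, body, jnp.zeros_like(y))
    return f

def num_eqns(fn, arg):
    # Hypothetical helper: number of top-level equations in the traced jaxpr.
    return len(jax.make_jaxpr(fn)(arg).jaxpr.eqns)

y = jnp.float32(2.0)
print(num_eqns(make_rolled(5), y), num_eqns(make_rolled(50), y))  # same count
print(jax.jit(make_rolled(5))(y))  # 2.0 added five times
```
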