by eterevsky
1327 days ago
I've just tried making a loop in a jit-compiled function and it just worked:

  >>> import jax
  >>> def a(y):
  ...     x = 0
  ...     for i in range(5):
  ...         x += y
  ...     return x
  ...
  >>> a(5)
  25
  >>> a_jit = jax.jit(a)
  >>> a_jit(5)
  DeviceArray(25, dtype=int32, weak_type=True)
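As far as I understand, this works because `range(5)` has a fixed Python bound: `jax.jit` traces `a` once, the loop runs during tracing, and the compiled program simply contains five additions. A loop whose trip count is itself a traced argument is a different case, because `range(n)` needs a concrete integer. A minimal sketch of both, assuming a recent JAX version (`a_fori` is my own hypothetical illustration, not something from the JAX docs), using `jax.lax.fori_loop` for the traced-bound case:

```python
import jax
import jax.numpy as jnp


def a(y):
    # The bound 5 is a plain Python int, so this loop executes at trace
    # time and is unrolled into five additions in the compiled program.
    x = 0
    for i in range(5):
        x += y
    return x


def a_fori(y, n):
    # Hypothetical variant: n is a traced argument under jit, so a Python
    # range(n) would fail during tracing. lax.fori_loop keeps the loop
    # inside the compiled program instead of unrolling it.
    return jax.lax.fori_loop(0, n, lambda i, x: x + y, jnp.zeros_like(y))


a_jit = jax.jit(a)
a_fori_jit = jax.jit(a_fori)

print(a_jit(5))          # 25
print(a_fori_jit(5, 5))  # 25

# Inspect what was traced: for `a` this should be a chain of adds with
# no loop primitive; for `a_fori` it should contain a while loop.
print(jax.make_jaxpr(a)(5))
print(jax.make_jaxpr(a_fori)(5, 5))
```

`jax.make_jaxpr` is a handy way to check whether a loop was unrolled or kept as a loop. The unrolled version recompiles if you change the constant, while the `fori_loop` version can take a different `n` without retracing. I believe newer JAX versions print results as `Array(...)` rather than `DeviceArray(...)`.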