by learndeeply
1330 days ago
JAX is a DSL on top of XLA: you write in that DSL instead of writing plain Python. For example, a for loop in JAX looks like this:

  import jax

  def summ(i, v):
      return i + v

  x = jax.lax.fori_loop(0, 100, summ, 5)
The same loop in TinyGrad or PyTorch is just regular Python:

  x = 5
  for i in range(0, 100):
      x += i
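To check that the two snippets compute the same thing, here is a small sketch of the semantics of jax.lax.fori_loop written as a plain Python loop (the JAX docs describe it the same way). The name fori_loop_reference is hypothetical and not part of JAX, and the sketch deliberately avoids importing JAX: it models only what the loop computes, not the tracing and compilation to XLA.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Hypothetical pure-Python model of lax.fori_loop:
    # thread the carried value `val` through body_fun(i, val) for each i.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def summ(i, v):
    # Loop body from the JAX example: add the loop index to the carried value.
    return i + v


jax_style = fori_loop_reference(0, 100, summ, 5)

# Equivalent imperative loop, as written in TinyGrad or PyTorch.
x = 5
for i in range(0, 100):
    x += i

print(jax_style, x)  # 4955 4955
```

Both give 5 + (0 + 1 + ... + 99) = 4955. The real fori_loop additionally requires the carried value to keep the same shape and dtype on every iteration, which is part of why it feels like a separate DSL rather than ordinary Python.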
By the way, PyTorch also has a JIT compiler (TorchScript, via torch.jit).