by boywitharupee 927 days ago
JAX is a wrapper on top of XLA. Instead of writing pure Python, you're writing JAX abstractions.

For example, a simple loop in JAX:

  import jax
  def solve(i, v): return i + v  # body_fun(index, carry) -> new carry
  x = jax.lax.fori_loop(0, 5, solve, 10)  # 10 + 0 + 1 + 2 + 3 + 4 = 20
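For comparison, here is a minimal pure-Python sketch of what that fori_loop call computes, modeled on the reference semantics given in the JAX docs (start from init_val, then feed each index and the running value through the body). The name fori_loop_reference is hypothetical and not part of JAX:

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    """Hypothetical pure-Python model of jax.lax.fori_loop's semantics."""
    val = init_val
    for i in range(lower, upper):
        # body_fun gets the loop index and the current carry, returns the new carry
        val = body_fun(i, val)
    return val


def solve(i, v):
    return i + v


x = fori_loop_reference(0, 5, solve, 10)
print(x)  # 20
```

The arithmetic is the same; the difference is that the JAX version takes the loop body as a function so it can be traced and handed to XLA as a loop, rather than executed step by step by the Python interpreter. That is the sense in which you end up writing JAX abstractions instead of plain Python control flow.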