by jeremiecoullon
1884 days ago
I haven't tried numba but I've heard good things about it! Nice linked tutorial. If I understand correctly, you pass jitted functions (using numba and jax) into iminuit, which does the optimisation? With Jax you can write native for loops that can also be jitted (I imagine you can also do this in numba?), and this can be really fast. Though in that case you would have to write the optimisation algorithm yourself, which is not always practical! Another big speedup in Jax comes from vmap/pmap, which let you vectorise/parallelise computation. For example, you can build a massive Gram matrix really quickly using vmap. Another point: Jax can also run on GPU (like numba :) ) without having to rewrite anything.
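To make the Gram-matrix point concrete, here is a minimal sketch, assuming a squared-exponential (RBF) kernel; the kernel choice and the names `rbf` and `lengthscale` are illustrative, not from the thread. The inner `vmap` maps one point against every row of `Y`, the outer one maps over every row of `X`, and `jit` compiles the whole thing:

```python
import jax
import jax.numpy as jnp

def rbf(x, y, lengthscale=1.0):
    # Hypothetical squared-exponential kernel between two feature vectors
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))

# Inner vmap: one x against every row of Y; outer vmap: every row of X
gram = jax.jit(jax.vmap(jax.vmap(rbf, in_axes=(None, 0)), in_axes=(0, None)))

X = jnp.linspace(0.0, 1.0, 1000).reshape(-1, 1)
K = gram(X, X)  # Gram matrix of shape (1000, 1000)
```

No explicit Python loops are written; the same code runs on GPU if a GPU-enabled jaxlib is installed.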
Yes, you just pass the function and it runs the optimization. In difficult cases, you can start with a grid search (`.scan`) or `.simplex` (Nelder-Mead simplex method), then run `.migrad` to minimize, and `.hesse` for the 1-sigma bounds.
You can also provide a gradient function that, well, computes the gradient, so it doesn't have to be computed numerically.
> I imagine you can also do this in numba
Yes! It compiles to native code.
> Another big speedup in Jax is due to vmap/pmap, which allow to vectorise/parallelise computation.
It's possible to compile some vmapped functions with numba too. If the link is any indication, you may see an even greater speedup than with jax.grad alone.
However, I do concur that jax's vmap is absolutely fantastic, and I have found it very useful on many occasions.