Have you tried numba+numpy? In my experience, it is much faster than Jax and can compile to CUDA. It's not caveat free, but it also removes the hassle of labeling arrays as donated in Jax.
Have you been successful in implementing non-trivial computational code in numba/numpy? I've always found it starts to really break for anything which isn't really trivial, and the errors are mostly non-prescriptive and highly verbose.
I just implemented both a CSV parser and an address standardizer in numba (both CPU and GPU), running in parallel and fed through a message queue with a bunch of worker subprocs.
It takes a bit of getting used to, but the performance gains are impressive. Basically, my bottlenecks shift from compute to I/O.
I think you have to balance it against writing in C/C++. Mentally, it is basically the same work as writing in C (you manage memory, you write complicated for-loops), but you have good array support with numpy. The primary advantage for me is that everything stays in the Python runtime environment. You just run the code without any extra steps.
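To make the "C-style loops over numpy arrays" point concrete, here is a minimal sketch of the kind of function that tends to suit numba. The function name and the data are made up for illustration, and the `try/except` fallback is only there so the snippet still runs (slowly) without numba installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without numba, run as plain Python.
    def njit(func):
        return func


@njit
def max_drawdown(prices):
    """Largest peak-to-trough drop, written as an explicit C-style loop."""
    peak = prices[0]
    worst = 0.0
    for i in range(1, prices.shape[0]):
        if prices[i] > peak:
            peak = prices[i]
        drop = peak - prices[i]
        if drop > worst:
            worst = drop
    return worst


# Hypothetical data; the first call triggers compilation, later calls reuse it.
prices = np.array([10.0, 12.0, 9.0, 11.0, 15.0, 13.0])
result = max_drawdown(prices)
```

There is no separate build step: the decorated function is compiled the first time it is called with a given argument type.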
...
What is missing from the timing-type 'toy' benchmarks is an understanding that there is typically more than one bottleneck in a real problem, and it is easy to choose the wrong one to optimize and gain little.
After starting in C (30 years ago now), spending a long time in C#, then switching to Python a few years ago, I think the unappreciated advantage of Python is that I have to abandon all pretense of caring about speed and just get stuff working. It basically solves the premature optimization problem for me by being a fast interpreted language rather than a slow compiled language.
> I think the unappreciated advantage of Python is that I have to abandon all pretense of caring about speed and just get stuff working. It basically solves the premature optimization problem for me
I feel the same way. With Python I just write the simplest algorithm that first comes to my mind, even though I know that it is not the most optimized way of doing things. But most of the time I am surprised that it works so fast that I realize I actually don't need to optimize it.
And being able to create and easily manipulate dictionaries and tuples also allows me to create efficient data structures very quickly.
We have a very large portion of production code written in numba we’ve been running for about 3 years and I made a small contribution to the library. There are a lot of gotchas to numba when the codebase gets large but the benefits far outweigh the downsides. I highly recommend numba.
Edit: also note that a big part of the umap library is written in numba.
I still have some issues in nojit mode, but jit is most of the time fine. It's still a bit iffy with lists, but most of the time I can use numpy arrays.
TBF I am mainly using it for mostly pure math functions.
I have a 25,000+ LOC scientific codebase in Python, with a lot of use in my specialized domain. I gave up on numba because its speedups vanish when transitioning to real problems. And secondly, its bugs are not that intuitive.
I had the same experience about 2 years ago. Maybe it's changed since then? It was nice when you had a pure math function to write, but otherwise seemed to be unreliable, especially in multithreaded and multiprocess situations.
I haven't tried numba but I've heard good things about it! Nice linked tutorial. If I understand correctly, you pass in jitted functions (using numba, and jax) into iminuit which does the optimisation?
With Jax you can write native for loops that can also be jitted (I imagine you can also do this in numba?); this can then be really fast. Though in that case you would have to write the optimisation algorithm yourself which is not always practical!
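As a rough sketch of what a jitted loop can look like in Jax: a plain Python `for` loop inside `jax.jit` gets unrolled at trace time, while `jax.lax.fori_loop` keeps it as a single compiled loop. The power-iteration example, the matrix, and the fixed step count of 50 are just assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def power_iteration(A, v0):
    """Estimate the dominant eigenvector with a fixed number of steps."""

    def step(_, v):
        w = A @ v
        return w / jnp.linalg.norm(w)

    # fori_loop(lower, upper, body, init) compiles to one loop, not 50 copies.
    return lax.fori_loop(0, 50, step, v0)


# Hypothetical input: the dominant eigenvector of this matrix is [1, 0].
A = jnp.array([[2.0, 0.0], [0.0, 1.0]])
v = power_iteration(A, jnp.array([1.0, 1.0]))
```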
Another big speedup in Jax is due to vmap/pmap, which allow you to vectorise/parallelise computation. For example, you can build a massive Gram matrix really quickly using vmap.
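For anyone who hasn't seen it, a Gram matrix via nested `vmap` might look roughly like this. The RBF kernel, the `lengthscale` parameter, and the tiny input are assumptions picked for the example; any pairwise function would do:

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Kernel between two single points (no batch dimension)."""
    sq_dist = jnp.sum((x - y) ** 2)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


@jax.jit
def gram_matrix(X):
    # Inner vmap: one fixed x against every row of X.
    row = jax.vmap(rbf_kernel, in_axes=(None, 0))
    # Outer vmap: repeat that for every x, giving K[i, j] = k(X[i], X[j]).
    return jax.vmap(row, in_axes=(0, None))(X, X)


# Hypothetical data: three 2-D points.
X = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gram_matrix(X)
```

The kernel is written for a single pair of points and `vmap` adds the batch dimensions, so there is no manual broadcasting to get wrong.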
Another point: Jax can also run on GPU (like numba :) ) without having to rewrite anything.
> I haven't tried numba but I've heard good things about it! Nice linked tutorial. If I understand correctly, you pass in jitted functions (using numba, and jax) into iminuit which does the optimisation?
Yes, you can just pass the function and it runs the optimization. In severe cases, you can start by doing a grid `.scan` or `.simplex` (Nelder-Mead simplex method), then `migrad` to minimize, and `.hesse` for the 1 sigma bound.
You can also provide a gradient function that, well, computes the gradient instead of having it computed numerically.
> I imagine you can also do this in numba
Yes! It compiles to native code.
> Another big speedup in Jax is due to vmap/pmap, which allow you to vectorise/parallelise computation.
It's possible to compile some vmapped functions with numba too. If the link is any indication, you may see an even greater speedup than with just jax.grad.
However, I do concur that jax's vmap is absolutely fantastic and I found it very useful on many occasions.