I've only skimmed the blog post, but it feels like GPU acceleration without the need to write any custom code was the primary reason to choose JAX.
>without a need to write any custom code, was the primary reason to choose JAX.
This is even more "free" in Julia. JAX at least needs to worry about what happens when a foreign call is made (a library not derived from the NumPy/JAX ecosystem, or an outright C/C++ binding without JAX rules).
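To make the "foreign call" point concrete on the JAX side, here is a minimal sketch assuming a recent JAX release. `foreign_sum_of_squares` is a hypothetical stand-in for any non-JAX library call, not something from the blog post. Under `jax.jit` the argument is an abstract tracer, so NumPy cannot turn it into a concrete array and tracing fails. `jax.pure_callback` is one way to route such a call back to the host, at the cost of a device-to-host round trip and no automatic differentiation unless you also supply custom rules (e.g. via `jax.custom_vjp`).

```python
import numpy as np
import jax
import jax.numpy as jnp


def foreign_sum_of_squares(x):
    # Hypothetical stand-in for a non-JAX library: plain NumPy needs a concrete array.
    return np.asarray(np.sum(np.asarray(x) ** 2), dtype=np.float32)


@jax.jit
def native(x):
    # Pure jax.numpy: traced, compiled, and run on the default backend (GPU if present).
    return jnp.sum(x ** 2)


@jax.jit
def naive_foreign(x):
    # Fails at trace time: np.asarray cannot convert a tracer to a NumPy array.
    return foreign_sum_of_squares(x)


@jax.jit
def wrapped_foreign(x):
    # Tell JAX the output shape/dtype up front, then call back into host Python.
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(foreign_sum_of_squares, out_spec, x)
```

Calling `naive_foreign` raises `jax.errors.TracerArrayConversionError`, while `native` and `wrapped_foreign` both return the same value; only `native` stays entirely inside JAX's compiled graph.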
For someone like me who is familiar with JAX but has only recently started paying attention to Julia (it does have momentum), could you point to some good examples of using a GPU / multiple GPUs from Julia?
I'm not familiar with multi-GPU setups in general. GPU programming in Julia has the advantage that naive operations don't even need to be GPU-aware on the part of whoever writes them, since GPU arrays (from any vendor backend) conform to the AbstractArray interface.
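For readers coming from JAX, the closest analogue of "naive code doesn't need to be GPU-aware" is that `jax.numpy` code never names a device: the computation runs wherever its inputs live. This is a rough JAX-side sketch, not Julia's AbstractArray dispatch; `normalize` and `run_on_every_device` are hypothetical helpers, and it assumes `jax.devices()` lists the devices of the default backend (all GPUs if a GPU backend is installed, otherwise the CPU).

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Device-agnostic: nothing here mentions CPU, GPU, or a vendor backend.
    return (x - jnp.mean(x)) / jnp.std(x)


normalize_jit = jax.jit(normalize)


def run_on_every_device(x_host):
    # Run the same unmodified function once per device of the default backend.
    results = {}
    for dev in jax.devices():
        x = jax.device_put(x_host, dev)       # commit the input to this device
        results[str(dev)] = normalize_jit(x)  # the computation follows the data
    return results
```

The design point is the same one made about Julia above: the numerical code stays generic, and placement is decided by where the arrays are, not by the function.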