by geoalchimista 1660 days ago
What made you choose JAX over Julia? I'm interested in this question, because I have been thinking about transitioning to Julia but have always hesitated to make the move, since overall the Python ecosystem still seems way ahead in terms of visualization and tooling.

Also, would you expect JAX acceleration to work well with other types of discretization, such as spectral methods?


If NumPy is a good fit for it, JAX is a good fit for accelerating it, basically. I think of it as NumPy plus program transformations such as automatic differentiation, JIT compilation, parallelization, and compilation via XLA to GPU/TPU.
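To make the "NumPy plus program transformations" idea concrete, here's a minimal sketch: a plain NumPy-style function, then `jax.grad`, `jax.jit` and `jax.vmap` layered on top without touching its body. The loss function and inputs are made up for illustration; multi-device parallelization (`jax.pmap` or the sharding APIs) is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp

# Ordinary NumPy-style code: a sum-of-squares loss (hypothetical example).
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# Transformations compose on top of the unchanged function.
grad_loss = jax.grad(loss)                   # gradient w.r.t. the first argument, w
fast_grad = jax.jit(grad_loss)               # compiled via XLA for the default device (CPU/GPU/TPU)
batched = jax.vmap(loss, in_axes=(None, 0))  # map over a leading batch axis of x, sharing w

w = jnp.ones(3)
x = jnp.arange(6.0).reshape(2, 3)

print(fast_grad(w, x))                       # analytically 2 * x.T @ (x @ w)
print(batched(w, jnp.stack([x, 2 * x])))     # one loss value per batch element
```

The same script runs unchanged on CPU or GPU; JAX just picks the default backend that's installed.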

The magic of JAX is that it keeps all that stuff about as simple as writing numpy code.
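On the spectral-methods question: a Fourier spectral derivative is just FFT, multiply by ik, inverse FFT, and `jax.numpy.fft` provides all of those, so it can be jitted like any other NumPy-style code. A minimal sketch, assuming a periodic 1D grid; the grid size and test function are arbitrary choices.

```python
import jax
import jax.numpy as jnp

N = 64                       # number of grid points (arbitrary)
L = 2 * jnp.pi               # domain length; the grid is periodic on [0, L)
x = jnp.arange(N) * (L / N)

# Angular wavenumbers in the same ordering as jnp.fft.fft output.
k = 2 * jnp.pi * jnp.fft.fftfreq(N, d=L / N)

@jax.jit
def spectral_derivative(u):
    """d/dx of a real periodic signal: FFT, multiply by i*k, inverse FFT."""
    u_hat = jnp.fft.fft(u)
    return jnp.real(jnp.fft.ifft(1j * k * u_hat))

u = jnp.sin(3 * x)
du = spectral_derivative(u)
print(jnp.max(jnp.abs(du - 3 * jnp.cos(3 * x))))  # error vs. the exact derivative
```

Because the FFT ops are ordinary JAX primitives, the same function can also be passed through `jax.grad` or `jax.vmap`.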

From the post:

> JAX on GPU outperforms everything

I've only skimmed through the blog post, but it feels like GPU acceleration without a need to write any custom code was the primary reason to choose JAX.

> without a need to write any custom code was the primary reason to choose JAX.

This is even more "free" in Julia. JAX at least needs to worry when a foreign call happens (a library not derived from the NumPy/JAX ecosystem, or an outright C/C++ binding without JAX rules).
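To illustrate the foreign-call caveat, here's a minimal sketch, assuming a plain NumPy function stands in for the "foreign" library. Under `jax.jit` that code can't consume traced values directly, so it gets wrapped in `jax.pure_callback`, and you have to declare the output shape and dtype yourself. `foreign_cumsum` is a hypothetical stand-in.

```python
import numpy as np
import jax
import jax.numpy as jnp

def foreign_cumsum(a):
    # Hypothetical stand-in for a non-JAX routine: plain NumPy, opaque to tracing.
    return np.cumsum(np.asarray(a))

@jax.jit
def f(x):
    # JAX can't infer anything about the callback, so describe its output explicitly.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(foreign_cumsum, out_spec, x)
    return y * 2.0

x = jnp.arange(4.0)
print(f(x))  # values 0, 2, 6, 12
```

The callback runs on the host, and JAX has no derivative rule for it: taking `jax.grad` through `f` fails unless you supply one yourself (e.g. via `jax.custom_jvp`). That is the extra bookkeeping the comment above refers to.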

For someone like me who is familiar with JAX but has only recently started paying attention to Julia (it does have momentum), could you point to any good examples of using a GPU or multiple GPUs from Julia?

I've tried searching on my own, but only found a way to write CUDA-dependent code: https://juliagpu.gitlab.io/CUDA.jl/usage/multigpu/

I'm not familiar with multi-GPU setups in general. GPU programming in Julia has the advantage that code for naive operations doesn't even need to be GPU-aware, since GPU arrays (from any vendor backend) conform to the AbstractArray interface.

If you're an advanced library writer, you can leverage https://juliagpu.github.io/KernelAbstractions.jl/stable/#Wri..., which lets you write kernels in Julia that compile efficiently together with the rest of your native Julia code and work across vendors!

Back to multi-GPU: there seems to be https://clima.github.io/OceananigansDocumentation/stable/app..., which is MPI-based?

https://www.juliapackages.com/p/pencilarrays is a really good tool for doing this kind of thing automatically in some applications.