Hacker News
by krasin 1657 days ago
From the post:

> JAX on GPU outperforms everything

I've only skimmed the blog post, but it seems that GPU acceleration without a need to write any custom code was the primary reason to choose JAX.
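As a rough illustration of what "no custom code" looks like on the JAX side: the function below is ordinary `jax.numpy` code, and `jax.jit` compiles it through XLA for whatever default device JAX finds (a GPU if one is available, otherwise the CPU). `rms_tanh` and its input are made up for this sketch; they are not taken from the blog post.

```python
import jax
import jax.numpy as jnp


@jax.jit  # traced once per input shape/dtype, then compiled by XLA for the default backend
def rms_tanh(x):
    # Plain NumPy-style array code; nothing here is GPU-specific.
    return jnp.sqrt(jnp.mean(jnp.tanh(x) ** 2))


x = jnp.linspace(-3.0, 3.0, 1_000_000)
print(jax.devices())       # the devices JAX picked up (GPU or CPU)
print(float(rms_tanh(x)))  # executes on the default device

# The same code also composes with other transformations, e.g. differentiation.
grad_x = jax.grad(rms_tanh)(x)
```

The point is that device placement is a property of the runtime, not of the function: the same source runs unchanged on CPU or GPU.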


> without a need to write any custom code was the primary reason to choose JAX.

This is even more "free" in Julia. With JAX you at least need to worry about foreign calls: a library not built on the NumPy/JAX ecosystem, or an outright C/C++ binding with no JAX rules defined for it.
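A small sketch of the foreign-call problem in JAX. NumPy's `np.sinc` is used here only as a stand-in for any function JAX cannot trace: inside `jax.jit` it fails, because NumPy tries to turn the abstract tracer into a concrete array. `jax.pure_callback` is one escape hatch; it runs the function on the host and tells JAX only the output shape and dtype. The helper names (`foreign_sinc`, `traced_directly`, `via_callback`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def foreign_sinc(x):
    # Stand-in for an opaque library call (NumPy here, but it could be a C/C++ binding).
    return np.sinc(x)


@jax.jit
def traced_directly(x):
    return foreign_sinc(x)  # NumPy receives a tracer, not a concrete array


@jax.jit
def via_callback(x):
    # Declare the result's shape/dtype so JAX can keep tracing around the opaque call.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

    def host_fn(v):
        # Runs on the host with a concrete NumPy array; keep the declared dtype.
        return np.asarray(foreign_sinc(v), dtype=v.dtype)

    return jax.pure_callback(host_fn, out_type, x)


x = jnp.linspace(-2.0, 2.0, 5)
try:
    traced_directly(x)
except jax.errors.TracerArrayConversionError:
    print("NumPy could not consume the JAX tracer")
print(via_callback(x))
```

The callback stays opaque to transformations: `jax.grad` through it fails unless you attach a rule yourself (for example with `jax.custom_jvp`), which is presumably the kind of "JAX rules" meant above.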

For someone like me who is familiar with JAX but has only recently started paying attention to Julia (it does have momentum), could you point to any good examples of using a GPU, or multiple GPUs, from Julia?

I've tried searching on my own, but I only found a way to write CUDA-dependent code: https://juliagpu.gitlab.io/CUDA.jl/usage/multigpu/

I'm not familiar with multi-GPU setups in general. GPU programming in Julia has the advantage that code written with plain array operations doesn't even need to be GPU-aware (from the library writer's point of view), since GPU arrays from any vendor backend conform to the AbstractArray interface.
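A loose Python analogy of the same idea (not Julia's actual mechanism, which is multiple dispatch on `AbstractArray` subtypes): a function written only against the methods both array types share works unchanged on a NumPy array in host memory and on a JAX array on JAX's default device. `standardize` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def standardize(x):
    # Uses only operations that NumPy and JAX arrays both provide (.mean, .std, arithmetic),
    # so the caller's array type decides where the work runs.
    return (x - x.mean()) / x.std()


host = np.array([1.0, 2.0, 3.0, 4.0])
device = jnp.array([1.0, 2.0, 3.0, 4.0])  # lives on JAX's default device (GPU if available)

print(isinstance(standardize(host), np.ndarray))  # result stays a NumPy array
print(isinstance(standardize(device), jax.Array))  # result stays a JAX array
```

In Julia the equivalent is sharper, since generic functions dispatch on the array type, which is what lets GPU arrays flow through code that never mentions the GPU.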

If you're an advanced library writer, you can use KernelAbstractions.jl (https://juliagpu.github.io/KernelAbstractions.jl/stable/#Wri...), which lets you write kernels in Julia that compile efficiently together with the rest of your native Julia code and that work across vendors!

Back to multi-GPU: there seems to be https://clima.github.io/OceananigansDocumentation/stable/app... (from the Oceananigans docs), which is MPI-based?

PencilArrays (https://www.juliapackages.com/p/pencilarrays) is a really good tool for doing this kind of thing automatically in some applications.
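As a rough idea of what such a tool automates: in a pencil decomposition, a 3-D array is split over a 2-D grid of processes (for example one per GPU) along two of its dimensions, so each process holds a "pencil" that is complete along the third. The sketch below only computes which index ranges each process owns; it is not PencilArrays' API, and `chunk` / `x_pencil` are hypothetical names.

```python
def chunk(n, parts, i):
    """Half-open range [start, stop) of the i-th of `parts` nearly equal pieces of range(n)."""
    base, extra = divmod(n, parts)
    # The first `extra` pieces get one extra element each.
    start = i * base + min(i, extra)
    stop = start + base + (1 if i < extra else 0)
    return start, stop


def x_pencil(global_shape, proc_grid, coords):
    """Local index ranges of an x-pencil: dim 0 stays whole, dims 1 and 2 are split."""
    nx, ny, nz = global_shape
    py, pz = proc_grid  # number of processes along y and z
    cy, cz = coords     # this process's position in the process grid
    return (0, nx), chunk(ny, py, cy), chunk(nz, pz, cz)


# A 16 x 10 x 8 grid on a 2 x 2 process grid (e.g. 4 GPUs):
for cy in range(2):
    for cz in range(2):
        print((cy, cz), x_pencil((16, 10, 8), (2, 2), (cy, cz)))
```

Operations along y or z (an FFT, say) need the data re-split into y- or z-pencils, which requires an all-to-all exchange between processes; handling those transposes over MPI is the part a library like this saves you from writing by hand.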