Supercharged high-resolution ocean simulation with Jax (dionhaefner.github.io)
90 points by dionhaefner 1660 days ago
5 comments

Awesome! One question immediately comes to mind. Any interest in doing this stuff with Julia? You're basically the epitome of their target audience: a scientific computing expert who does HPC with differentiable programs.
A bit late to the party, but here are some reasons:

- When we started Veros (~4 years ago) Julia was very new on our radar and we didn't know whether it would stick. And to be frank, I'm still not convinced whether it will stick. Yes it seems like a fantastic language, but we all know how long it took Python to gain traction.

- Climate scientists and students already do their post-processing in Python. Having the whole stack in the same language makes things a lot easier for domain experts whose first priority is physics, not coding.

- Python skills translate better to other jobs, which I think is important for young academics.

- The Python library ecosystem is so good. Need to use PETSc? `import petsc4py`. Simplify postprocessing? Export your model state as `xarray` dataset. Julia is great for bleeding edge autodiff through everything stuff, but the bread and butter libraries are just so polished and battle tested in Python.

- I don't know Julia :)

Those are very good reasons!
There's an earlier blog post by the same author where they discuss three possible ways of moving away from the Fortran/C status quo towards higher-level models. They mention Julia as one of the routes, but not the one they decided to take: https://dionhaefner.github.io/2021/04/higher-level-geophysic...
"On the other hand, Julia’s focus on scientific applications is both blessing and curse. In this day and age, a lot of the progress in computing is driven by applications outside academia (mostly through machine learning)." This seems like a crazy mis-read to me. Julia is probably the language that has the best integration of differential equations and machine learning. Jax closes the gap a little, but is still way behind.

For example, https://gist.github.com/ChrisRackauckas/62a063f23cccf3a55a4a... shows a pretty simple case where DifferentialEquations.jl is 6x faster at gradient calculations than Jax.

I was mostly referring to the millions (billions?) of dollars getting poured into Python library development by tech companies, with the effect that Python stays relevant and has a thriving library ecosystem. Maybe I'm wrong and Julia is just so good that it doesn't matter - I guess time will tell.
Jax is just a tool to generate XLA, which produces extremely high performance computational graphs which can map to arbitrarily fast hardware, so I'm very skeptical of the utility of the conclusions of the link you provided (which seems to be comparing single-process CPU linear algebra?)
Single-threaded CPU linear algebra is the bottleneck of most small systems, so if you can't do that right, you are going to have problems. If you don't believe the benchmark, feel free to run it yourself.

That said, Jax also has bigger issues in its handling of higher derivatives. Currently, it only supports a few types of Jacobians, and the ones it is missing include all the sparse methods that can make your code orders of magnitude faster. https://jax.readthedocs.io/en/latest/notebooks/autodiff_cook.... DifferentialEquations, on the other hand, can do automatic sparsity detection https://diffeq.sciml.ai/stable/tutorials/advanced_ode_exampl....
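
To make the Jacobian point concrete, here is a minimal sketch of the dense Jacobian transforms in JAX (`jax.jacfwd`, `jax.jacrev`). The function `f` is a hypothetical stand-in chosen because its Jacobian is diagonal. The single-JVP shortcut at the end is a hand-rolled trick that only works because that sparsity pattern is known in advance; it is not automatic sparsity detection.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical elementwise function: its Jacobian is diagonal.
    return x ** 3


x = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]

# Dense Jacobians: both materialize all n*n entries, zeros included.
J_fwd = jax.jacfwd(f)(x)  # forward mode, one pass per input
J_rev = jax.jacrev(f)(x)  # reverse mode, one pass per output

# Exploiting the known diagonal structure by hand: a single
# Jacobian-vector product with a vector of ones recovers the diagonal.
_, diag_only = jax.jvp(f, (x,), (jnp.ones_like(x),))
```

For a diagonal Jacobian the dense version does n passes where one suffices, which is the kind of saving sparse methods aim for.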

This post is about Big Simulations, not small systems. Like, hundreds to thousands of cores with parameters that don't fit in RAM on a single machine.

I am sure the benchmark produces the numbers the author says, but it's not measuring something useful to the posters of this simulation.

XLA only optimizes quasi-static code, which does not include adaptive numerical solvers like those for ODEs. It's a generally good assumption for ML though, but there are ways to break it. I wrote a piece showcasing some ideas around that: https://www.stochasticlifestyle.com/useful-algorithms-that-a...
IIUC people have already run MD (which is the field I used to work in) on XLA, https://twitter.com/sschoenholz/status/1334997741185814530 In these cases it's almost always better (unless you are a numerical genius) to port to the engine, than to try to make a better algorithm that runs on a smaller engine.
I'm also surprised that XLA.jl doesn't seem to have had continued development: https://github.com/FluxML/XLA.jl

When in doubt, piggybacking on (or at least interoperating with) what the large technology companies are investing in is probably savvy, sort of what the OP did.

XLA.jl was kind of a solution looking for a problem. If you want fast code in Julia, you can just write Julia.
I read that as being about what language industry uses to write ML applications, not about technical feasibility of integrating machine learning methods into a codebase. Put differently: industry most often uses Python (especially in ML), therefore the author wants to target Python in order to maximize uptake outside of academia. They even admit that doing it in Python is technically harder than doing it in Julia ("Unfortunately, this type [Type III] is also the hardest to get right"), but consider it worth the trouble for the broader accessibility.

(That's more or less the direction I've been going with research code lately too, so I can sympathize, although I'm not entirely happy with the situation and definitely also sympathize with the Julia folks being unhappy about it.)

Somehow I don't think an ocean simulation needs to be in Python so some startup can use it to... what, sell ads or something?

Anyone interesting enough to be looking at your ocean simulation code can probably handle it being in Julia, and may even prefer it, since the language is so much better designed for this kind of thing than Python.

Hopefully, if enough people are unhappy about it and see a future in the alternative (i.e. critical mass), we can collectively have a "phase transition".
It may take a while, however. 15-20 years ago, you kind of had to use Python on the sly in the scientific setting vs the incumbents (MATLAB, C++, Fortran). Julia seems to be in a similar phase.

That being said, Python does have some structural advantages since it positions itself as a universal glue. It's much easier to gain a critical mass in that regard vs a niche area like scientific or numerical computing. Still, Julia is probably underrated in general purpose usage.

That's an old example. It will now default to Enzyme and should be quite a bit faster. I should update that.
In a shameless plug, I want to note that running these sorts of workloads on CPU using Pytorch got much faster (some results on a benchmark from this post’s author’s suite in [0]) in the most recent torch release thanks to the addition of a JIT compiler. Obviously there’s much to recommend Jax (the XLA compiler is quite excellent), but it’s nice to have some choice in the space.

[0] https://www.linkedin.com/feed/update/activity:68640106214579...

True, but unfortunately Pytorch is not quite there yet when it comes to more complex benchmarks:

https://github.com/dionhaefner/pyhpc-benchmarks#example-resu...

JAX really is the only library that comes close to low-level code on CPU, almost always (that I've tried).

Interesting, I thought pytorch was a bit more competitive on those other benchmarks (but admittedly it’s been a while since I looked). Slicing shouldn’t be a fundamental problem, but perhaps there are some important details that have been overlooked. Thanks for pointing it out!
What made you choose JAX over Julia? I'm interested in this question, because I have been thinking about transitioning to Julia but have always hesitated to make the move, since overall the Python ecosystem still seems way ahead in terms of visualization and toolchain.

Also, would you expect JAX acceleration to work well with other types of discretization, such as spectral methods?

If numpy is a good fit for it, JAX is a good fit for accelerating it, basically. I think of it as numpy plus program transformations, such as differentiation, JIT, parallelization, compiling to XLA, compiling to TPU/GPU, etc.

The magic of JAX is that it keeps all that stuff about as simple as writing numpy code.
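
A minimal sketch of the "numpy plus program transformations" idea, assuming a hypothetical `kinetic_energy` function and a made-up 1-D velocity field `u` (neither comes from Veros):

```python
import jax
import jax.numpy as jnp


def kinetic_energy(u):
    # Plain NumPy-style code; `u` is a hypothetical 1-D velocity field.
    return 0.5 * jnp.sum(u ** 2)


# The same function, transformed three ways without rewriting it:
fast_energy = jax.jit(kinetic_energy)      # compiled through XLA
grad_energy = jax.grad(kinetic_energy)     # dE/du, which equals u here
batched_energy = jax.vmap(kinetic_energy)  # maps over a leading batch axis

u = jnp.linspace(0.0, 1.0, 5)
batch = jnp.stack([u, 2.0 * u])

energy = fast_energy(u)
gradient = grad_energy(u)
energies = batched_energy(batch)
```

The transformations also compose, e.g. `jax.jit(jax.vmap(jax.grad(kinetic_energy)))`.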

From the post:

> JAX on GPU outperforms everything

I've only skimmed through the blog post, but it feels that the GPU acceleration without a need to write any custom code, was the primary reason to choose JAX.

>without a need to write any custom code, was the primary reason to choose JAX.

This is even more "free" in Julia; JAX at least needs to worry when a foreign call happens (a library not derived from the NumPy/JAX ecosystem, or an outright C/C++ binding without JAX rules).

For someone like me who is familiar with JAX, but has only recently started paying attention to Julia (it does have momentum), would it be possible to provide any good examples of using GPU / multiple GPUs from Julia?

I've tried to search on my own, but only found a way to write CUDA-dependent code: https://juliagpu.gitlab.io/CUDA.jl/usage/multigpu/

I'm not familiar with multi-GPU setups in general. GPU programming in Julia has the advantage that naive operations don't even need to be GPU-aware (for writers), since GPU arrays (of any vendor backend) conform to the AbstractArray interface.

If you're an advanced library writer, you can leverage: https://juliagpu.github.io/KernelAbstractions.jl/stable/#Wri... which allows you to write kernels, in Julia, that compile efficiently with the rest of your native Julia code, and that work cross-vendor!

Back to multi GPU, it seems there's: https://clima.github.io/OceananigansDocumentation/stable/app... which is MPI based?

https://www.juliapackages.com/p/pencilarrays is a really good tool to do this type of stuff automatically for some applications.
My question is how many of the operations in JAX here can be done with reduced precision and can utilize training accelerators, i.e. TPUs. I've noticed a lot of research coming out in physics, where everything is simulated in at least double precision, being augmented with ML approaches where precision is traded for dynamic range.
The thing with reduced precision is that things may look fine at first, but then you eventually notice unphysical features in your solution (like additional wave modes after very long simulation times, or energy conservation issues). So we really don't know as a community yet how far we can venture from float64, but it looks like float32 may be viable.

Veros works OK on TPUs (about the same speed as a high-end GPU), but since you can't buy TPUs that's an immediate no for most academic users of climate models. Renting hardware doesn't really make sense when you keep it busy for months at a time and the HPC infrastructure is already in place.
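
The "looks fine at first" effect of reduced precision can be illustrated with a toy accumulation. This is a sketch with made-up numbers (a constant increment of 0.1 over 100,000 steps), not an ocean model: each float32 addition is only slightly off, but the rounding errors build up over a long run.

```python
import numpy as np


def accumulate(dtype, n_steps=100_000, increment=0.1):
    # Repeatedly add a small increment, rounding to `dtype` at every step,
    # the way a long time-stepping loop would.
    total = dtype(0.0)
    step = dtype(increment)
    for _ in range(n_steps):
        total = total + step
    return float(total)


exact = 100_000 * 0.1
err32 = abs(accumulate(np.float32) - exact)
err64 = abs(accumulate(np.float64) - exact)
print(f"float32 drift: {err32:.3e}, float64 drift: {err64:.3e}")
```

The float32 drift comes out orders of magnitude larger than the float64 drift, even though no single step looks wrong.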

Can't you fix a lot of the nonphysical issues by using better integration schemes? That might be hard in Jax though. From what I know, its options for better numerical stability are pretty limited.
No, in fact, you want to go lower order with lower precision. The real answer is that if the solution is in the chaotic regime then maybe Float16 is fine because you'll be dominated by other numerical errors anyways (if you're also making sure you have adequate conservation so the solution doesn't explode in some way), but if you're not in the chaotic regime then even Float32 is pushing it in many cases (i.e. it better be non-stiff as stiffness pretty much guarantees operations which span beyond Float32 relative epsilon). So it's a case-dependent topic and not something that has an easy answer, though the case for Float16 is rather small.
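
As a rough illustration of what "beyond Float32 relative epsilon" means, this NumPy sketch prints the machine epsilon of each format and shows a small term being absorbed entirely when added to a value eight orders of magnitude larger. The value `1e-8` is an arbitrary choice for the demonstration.

```python
import numpy as np

# Relative spacing between 1.0 and the next representable number.
for dtype in (np.float16, np.float32, np.float64):
    print(dtype.__name__, np.finfo(dtype).eps)

small = 1e-8  # arbitrary term, far smaller than the O(1) state it is added to

# In float32 the small term is lost completely; float64 still resolves it.
absorbed_32 = bool(np.float32(1.0) + np.float32(small) == np.float32(1.0))
absorbed_64 = bool(np.float64(1.0) + np.float64(small) == np.float64(1.0))
```

A stiff system mixes scales like this routinely, which is why the small contributions can vanish in float32.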

(We had some small tests generating TPU ODE solver code from Julia and showcased some rather bizarre stuff back when Keno was working on it, but never wrote a post summarizing all of it)

I would recommend checking out https://www.youtube.com/watch?v=GiSsoA1udUk. It shows that you can do climate models with 16 bit numbers.
like some other commenters here, https://github.com/CliMA/Oceananigans.jl immediately comes to mind, maybe it would be fun to compare projects on this scale between JAX/Julia.

> JAX offers more than just a JIT compiler: JAX functions are also differentiable

Only if the downstream library is implemented entirely in the JAX (Numba) ecosystem. The same goes for Julia, except that implementing fast code in Julia is natural and doesn't involve debugging three compilers (CPython, Numba, JAX). Many Python libraries are only differentiable because 100x more effort was put into writing a C/C++ backend, binding it to Python, and writing chain rules for foreign functions.
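
A sketch of what "writing chain rules for foreign functions" looks like on the JAX side, using `jax.custom_jvp` together with `jax.pure_callback`. Here `host_sin` is a hypothetical stand-in for a foreign routine (it just calls NumPy); a real C/C++ binding would take its place, and its derivative would have to be supplied by hand in the same way.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_sin(x):
    # Stand-in for a foreign (non-JAX) routine, e.g. a C/C++ binding.
    x = np.asarray(x)
    return np.asarray(np.sin(x), dtype=x.dtype)


@jax.custom_jvp
def foreign_sin(x):
    x = jnp.asarray(x)
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # JAX cannot trace into host_sin, so it is invoked as an opaque callback.
    return jax.pure_callback(host_sin, out_spec, x)


@foreign_sin.defjvp
def foreign_sin_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Hand-written chain rule: d/dx sin(x) = cos(x).
    return foreign_sin(x), jnp.cos(x) * x_dot


value = foreign_sin(1.0)
dydx = jax.grad(foreign_sin)(1.0)
```

Without the `defjvp` rule, `jax.grad` would have no way to differentiate through the callback.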

I would imagine Julia to be a good fit for this direction in the future!

Oh another thing, in the chaotic regime which is of interest, standard automatic differentiation schemes don't even apply as you require shadow adjoints given the shadow trajectory leads to inaccurate calculations for the gradient. Julia's system is the only one that I know of that has shadow adjoints for differentiation of ergodic properties.

https://frankschae.github.io/post/shadowing/

So unless the purpose is to only differentiate the simulator for short time periods or in the absence of chaos, I cannot see differentiation as a good justification because AD will not give a stable algorithm on that type of problem.

> https://github.com/CliMA/Oceananigans.jl

Off-topic, but I have to say that is one of my favourite package names in Julia. (A more recent one is [MATDaemon](https://github.com/jondeuce/MATDaemon.jl))

The real problem with the Jax code is that the non-composable programming language setup put it into a corner where it's using an extremely inefficient time stepping method that it has "optimized", but how is it optimized if you're doing 100 times more function calls than you have to? Algorithms matter, and "optimizing Adams-Bashforth 2" is a pretty silly idea.
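
For readers unfamiliar with the scheme named above, here is a minimal sketch of second-order Adams-Bashforth (AB2) on a toy decay problem. The function name `ab2_integrate`, the test equation du/dt = -u, and the step size are illustrative assumptions and are not taken from the Veros code.

```python
import math


def ab2_integrate(f, u0, dt, n_steps):
    """Integrate du/dt = f(u) with second-order Adams-Bashforth."""
    u = u0
    f_prev = f(u)
    # AB2 needs two time levels, so bootstrap the first step with forward Euler.
    u = u + dt * f_prev
    for _ in range(n_steps - 1):
        f_curr = f(u)
        # u_{n+1} = u_n + dt * (3/2 f(u_n) - 1/2 f(u_{n-1}))
        u = u + dt * (1.5 * f_curr - 0.5 * f_prev)
        f_prev = f_curr  # reuse the evaluation: one new call to f per step
    return u


# Toy problem: du/dt = -u, u(0) = 1, integrated to t = 1.
u_end = ab2_integrate(lambda u: -u, 1.0, 0.01, 100)
error = abs(u_end - math.exp(-1.0))
```

The step size is fixed for the whole run, which is what separates this from the adaptive solvers discussed elsewhere in the thread.
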
I agree with your point regarding non-composability and "algorithm lock in" (which may or may not be solvable with better abstractions), but explicit time stepping schemes are still the main workhorse of global ocean modelling, so I’m not sure whether "silly" is the right label here.
Why are explicit time stepping schemes the main tool used? Is it because the languages that these models are written in aren't flexible enough, or is there a math reason why dynamic time-stepping isn't better?
Climate models are vastly complex, and you need to bring together many experts from many disciplines to write and maintain one, and analyze the output. This seems to lead to the simplest methods coming out on top. Perhaps it could be solved with better abstractions (a lot of very smart people are trying).
That's precisely what composability solves. We're seeing in CLIMA that using more general highly optimized solvers can decrease the `f` cost count far more than focusing on really low level optimizations. Especially in things like the land model where you can have many stability issues (such as large complex eigenvalues which happen to work very poorly with multistep methods, even BDF), the ability to split the development of the time stepping across a huge community of hundreds of developers without losing performance gives something where more optimal methods for a domain arise and are found. Yes, the standard is to use something simpler. No, it's not even close to optimal and that is something that is being made very clear.