by dekhn 1657 days ago
Jax is just a tool to generate XLA, which produces extremely high-performance computational graphs that can map to arbitrarily fast hardware, so I'm very skeptical of the utility of the conclusions of the link you provided (which seems to be comparing single-process CPU linear algebra?)
3 comments

Single-threaded CPU linear algebra is the bottleneck of most small systems, so if you can't do that right, you are going to have problems. If you don't believe the benchmarks, feel free to run them yourself.

That said, Jax also has bigger issues in its handling of higher derivatives. Currently, it only supports a few types of Jacobians, and the ones it is missing include all the sparse methods that can make your code orders of magnitude faster. https://jax.readthedocs.io/en/latest/notebooks/autodiff_cook.... DifferentialEquations, on the other hand, can do automatic sparsity detection https://diffeq.sciml.ai/stable/tutorials/advanced_ode_exampl....
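To make the "orders of magnitude" point concrete, here is a minimal sketch in plain Python (not JAX or DifferentialEquations.jl code) of why exploiting sparsity pays off. It assumes the sparsity pattern is already known to be tridiagonal rather than detected automatically, and the names `f`, `dense_jacobian`, and `colored_jacobian` are made up for illustration. Columns that never touch the same row are perturbed together, so the number of calls to `f` stops growing with the problem size.

```python
def f(u):
    """Hypothetical test function: 1D discrete Laplacian plus a cubic term.
    Row i depends only on u[i-1], u[i], u[i+1], so the Jacobian is tridiagonal."""
    n = len(u)
    out = []
    for i in range(n):
        left = u[i - 1] if i > 0 else 0.0
        right = u[i + 1] if i < n - 1 else 0.0
        out.append(left - 2.0 * u[i] + right - u[i] ** 3)
    return out


def dense_jacobian(f, u, eps=1e-6):
    """Forward differences, one perturbed call to f per column: n + 1 calls."""
    n = len(u)
    f0 = f(u)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up = list(u)
        up[j] += eps
        fj = f(up)
        for i in range(n):
            J[i][j] = (fj[i] - f0[i]) / eps
    return J, n + 1


def colored_jacobian(f, u, eps=1e-6, ncolors=3):
    """Same differences, but columns j, j+3, j+6, ... are perturbed at once.
    For a tridiagonal pattern those columns share no nonzero row, so each
    row's difference can be attributed to exactly one column: 3 + 1 calls."""
    n = len(u)
    f0 = f(u)
    J = [[0.0] * n for _ in range(n)]
    for c in range(ncolors):
        up = list(u)
        for j in range(c, n, ncolors):
            up[j] += eps
        fc = f(up)
        for j in range(c, n, ncolors):
            # Only rows j-1, j, j+1 can be nonzero in column j.
            for i in range(max(0, j - 1), min(n, j + 2)):
                J[i][j] = (fc[i] - f0[i]) / eps
    return J, ncolors + 1


u = [0.1 * i for i in range(12)]
J_dense, calls_dense = dense_jacobian(f, u)
J_sparse, calls_sparse = colored_jacobian(f, u)
print(calls_dense, calls_sparse)
```

For `n` unknowns the dense version costs `n + 1` evaluations of `f` while the colored version always costs 4, which is where the large speedups on big banded or sparse systems come from.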

This post is about Big Simulations, not small systems. Like, hundreds to thousands of cores with parameters that don't fit in RAM on a single machine.

I am sure the benchmark produces the numbers the author says, but it's not measuring something useful to the posters of this simulation.

XLA only optimizes quasi-static code, which does not include adaptive numerical solvers like those for ODEs. It's a generally good assumption for ML though, but there are ways to break it. I wrote a piece showcasing some ideas around that: https://www.stochasticlifestyle.com/useful-algorithms-that-a...
IIUC people have already run MD (which is the field I used to work in) on XLA: https://twitter.com/sschoenholz/status/1334997741185814530 In these cases it's almost always better (unless you are a numerical genius) to port to the big engine than to try to make a better algorithm that runs on a smaller engine.
Yes, that has nothing to do with what I just said though. Of course MD is fine because symplectic ODE solvers cannot generally have adaptivity (without tricky and very expensive handling of `t` inside of the Hamiltonian which nobody does because it's still an active research topic how to make it computationally viable). So MD gets a quasi-static code which XLA is fine with optimizing. I was explicitly talking about the non-quasi-static cases.
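For readers unfamiliar with the distinction, a minimal sketch of the quasi-static case, assuming a harmonic oscillator as a stand-in for an MD force field (the function name `velocity_verlet` and the test problem are illustrative, not taken from any MD package). The step size and step count are fixed before the loop starts, so the whole computation is known from the input sizes alone.

```python
def velocity_verlet(force, x, v, h, n_steps, mass=1.0):
    """Fixed-step symplectic integrator: exactly n_steps iterations,
    regardless of the values x and v take along the way."""
    a = force(x) / mass
    for _ in range(n_steps):
        x = x + h * v + 0.5 * h * h * a
        a_new = force(x) / mass
        v = v + 0.5 * h * (a + a_new)
        a = a_new
    return x, v


def energy(x, v):
    """Total energy of the unit-mass, unit-stiffness oscillator."""
    return 0.5 * v * v + 0.5 * x * x


x_end, v_end = velocity_verlet(lambda x: -x, 1.0, 0.0, h=0.01, n_steps=1000)
print(energy(1.0, 0.0), energy(x_end, v_end))
```

The loop has no data-dependent branching, which is the kind of code a static-graph compiler can unroll or fuse ahead of time.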
I've worked in ODEs for 20+ years and I don't think that non-quasi-static solvers have really ever come up. Are these commonly used? I.e., how much CPU/GPU/TPU time is spent on them globally and how useful are they?
Have you used almost any ODE solver? Almost every single one uses embedded methods to adapt time steps. ode23, ode45, ode23t, ode23tb, ode15s, LSODE, LSODA, radau, rodas, VODE, CVODE, ... even for DAEs you have DASSL, IDA, ... I can keep going but it's just listing every ODE solver code out there. Once you do that then the computation is dependent on values and thus the full compute is not determined by the input sizes, which is something known to be blocking the full usage in Jax because of XLA limitations (for example the implementation of dense output).
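A minimal sketch of what "the computation is dependent on values" means, assuming the simplest embedded pair (Heun with an Euler error estimate) rather than any of the production solvers listed above. The function name `adaptive_heun` and the controller constants are illustrative choices. The same scalar problem shape takes a different number of steps depending on the rate constant, so the amount of work cannot be fixed from input sizes.

```python
def adaptive_heun(f, t0, y0, t_end, tol=1e-6, h=0.1):
    """Integrate y' = f(t, y) with an embedded Heun/Euler pair.
    Returns the final value and the number of accepted steps."""
    t, y, steps = t0, y0, 0
    while t_end - t > 1e-12:
        h = min(h, t_end - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_low = y + h * k1                 # Euler, order 1
        y_high = y + 0.5 * h * (k1 + k2)   # Heun, order 2
        err = abs(y_high - y_low)          # embedded error estimate
        if err <= tol:
            # Accept the step; a rejected step is retried with a smaller h.
            t += h
            y = y_high
            steps += 1
        # Step-size controller: shrink or grow h based on the error,
        # clamped so h changes by at most a factor of 5 per attempt.
        h *= min(5.0, max(0.2, 0.9 * (tol / max(err, 1e-16)) ** 0.5))
    return y, steps


# Same input shape (one scalar state), different values, different work.
y_slow, steps_slow = adaptive_heun(lambda t, y: -1.0 * y, 0.0, 1.0, 1.0)
y_fast, steps_fast = adaptive_heun(lambda t, y: -50.0 * y, 0.0, 1.0, 1.0)
print(steps_slow, steps_fast)
```

The `while` loop and the accept/reject branch both depend on `err`, which is computed from the solution itself, so the trip count is only known at run time.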
I'm also surprised that XLA.jl doesn't seem to have had continued development: https://github.com/FluxML/XLA.jl

When in doubt, piggybacking on (or at least interoperating with) what the large technology companies are investing in is probably savvy, sort of what the OP did.

XLA.jl was kind of a solution looking for a problem. If you want fast code in Julia, you can just write Julia.
That's incorrect. If you work with mid-sized neural networks and MCMC sampling, allocations start to play a significant role (and Flux.jl is bad at preallocation). Prealloc.jl does not work properly. Zygote.jl adds even more allocations to the mix...

Jax/XLA completely solves this problem. Yes, it's annoying that you have to work with a static graph but if your problem fits the description... it's great.

There's work being done to solve this in Julia. See EscapeAnalysis.jl and the immutable array PR in Base.