by ChrisRackauckas 1656 days ago
XLA only optimizes quasi-static code, which does not include adaptive numerical solvers like those for ODEs. It's generally a good assumption for ML, though there are ways to break it. I wrote a piece showcasing some ideas around that: https://www.stochasticlifestyle.com/useful-algorithms-that-a...
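To make "quasi-static" concrete (roughly: the control flow and the amount of work are fixed by the input sizes rather than by the values being computed), here is a minimal sketch in plain Python, with no XLA involved. Both hypothetical functions run the same Newton update for sqrt(x). The first always runs exactly `n_iter` times. The second exits when a value-dependent convergence test passes, so how many iterations it runs depends on `x` itself.

```python
def newton_sqrt_fixed(x, n_iter):
    """Newton's method for sqrt(x), x > 0, with a trip count fixed by n_iter (quasi-static)."""
    y = max(x, 1.0)                  # initial guess >= sqrt(x) for any x > 0
    for _ in range(n_iter):          # always exactly n_iter iterations
        y = 0.5 * (y + x / y)
    return y, n_iter


def newton_sqrt_tol(x, rtol=1e-12, max_iter=100):
    """Same update, but stop once the relative change drops below rtol."""
    y = max(x, 1.0)
    for k in range(1, max_iter + 1):
        y_new = 0.5 * (y + x / y)
        if abs(y_new - y) <= rtol * y_new:   # exit decided by the values
            return y_new, k
        y = y_new
    return y, max_iter


root_small, iters_small = newton_sqrt_tol(4.0)
root_large, iters_large = newton_sqrt_tol(1e10)
# Same input size (one float), different iteration counts.
```

Only the exit condition differs between the two functions; that single change is what makes the total work data-dependent.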

IIUC, people have already run MD (molecular dynamics, the field I used to work in) on XLA; see https://twitter.com/sschoenholz/status/1334997741185814530 for an example. In these cases it's almost always better (unless you are a numerical genius) to port your code to the engine than to try to make a better algorithm that runs on a smaller engine.
Yes, but that has nothing to do with what I said. Of course MD is fine: symplectic ODE solvers generally cannot use adaptive time stepping (not without tricky and very expensive handling of `t` inside the Hamiltonian, which nobody does because how to make it computationally viable is still an active research topic). So MD gets quasi-static code, which XLA is fine with optimizing. I was explicitly talking about the non-quasi-static cases.
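A minimal sketch of why the MD case is quasi-static, using a 1D harmonic oscillator as a hypothetical stand-in for a real force field. Velocity Verlet, the symplectic integrator commonly used in MD, takes a fixed step `dt`, so the number of force evaluations is set by `n_steps` alone and never by the state values. The names `velocity_verlet` and `energy` are illustrative, not taken from any MD package.

```python
import math


def velocity_verlet(q, p, dt, n_steps, k=1.0, m=1.0):
    """Fixed-step velocity Verlet for H(q, p) = p**2 / (2 m) + k q**2 / 2.

    The loop runs exactly n_steps times regardless of the values of q and p,
    so the amount of work is fixed once dt and n_steps are chosen.
    """
    f = -k * q                      # force F = -dV/dq
    for _ in range(n_steps):
        p += 0.5 * dt * f           # half kick
        q += dt * p / m             # drift
        f = -k * q                  # force at the new position
        p += 0.5 * dt * f           # half kick
    return q, p


def energy(q, p, k=1.0, m=1.0):
    return p * p / (2.0 * m) + 0.5 * k * q * q


# Integrate from q=1, p=0 up to t = n_steps * dt = 100.
q, p = velocity_verlet(1.0, 0.0, dt=0.01, n_steps=10_000)
energy_drift = abs(energy(q, p) - 0.5)       # stays small for this step size
phase_err = abs(q - math.cos(100.0))         # exact solution is q(t) = cos(t)
```

Nothing in the loop inspects the state to decide what to do next, which is the property that lets a compiler plan the whole computation ahead of time.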
I've worked with ODEs for 20+ years, and I don't think non-quasi-static solvers have ever really come up. Are these commonly used? I.e., how much CPU/GPU/TPU time is spent on them globally, and how useful are they?
Have you used just about any ODE solver? Almost every single one uses embedded methods to adapt time steps: ode23, ode45, ode23t, ode23tb, ode15s, LSODE, LSODA, radau, rodas, VODE, CVODE, ... Even for DAEs you have DASSL, IDA, ... I could keep going, but it would just be listing every ODE solver code out there. Once you adapt the step size, the computation depends on the values, so the total compute is not determined by the input sizes alone. That is known to block full use of JAX because of XLA limitations (for example, in implementing dense output).
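A minimal sketch of the embedded-method idea, assuming a scalar ODE and using the Bogacki-Shampine 3(2) pair (the pair MATLAB's ode23 is based on). Each step computes a 3rd-order solution and an embedded 2nd-order one, uses their difference as an error estimate, accepts or rejects the step, and rescales `h`. The function name `bs23`, the tolerances, and the step-size constants (0.9, 0.2, 5) are illustrative choices, not taken from any of the codes listed above.

```python
import math


def bs23(f, t0, y0, t_end, rtol=1e-6, atol=1e-9, h0=1e-2, max_attempts=100_000):
    """Adaptive Bogacki-Shampine 3(2) integration of a scalar ODE y' = f(t, y).

    Returns (y_end, accepted, rejected). How many steps are taken is decided
    at run time by the error estimate, i.e. by the values of y, not by sizes.
    """
    t, y, h = t0, y0, h0
    k1 = f(t, y)
    accepted = rejected = attempts = 0
    while t < t_end:
        attempts += 1
        if attempts > max_attempts:
            raise RuntimeError("too many step attempts")
        h = min(h, t_end - t)                                   # never step past t_end
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + 3 * h / 4, y + 3 * h / 4 * k2)
        y3 = y + h * (2 / 9 * k1 + 1 / 3 * k2 + 4 / 9 * k3)     # 3rd-order solution
        k4 = f(t + h, y3)
        y2 = y + h * (7 / 24 * k1 + 1 / 4 * k2 + 1 / 3 * k3 + 1 / 8 * k4)  # embedded 2nd order
        err = abs(y3 - y2)                                      # local error estimate
        tol = atol + rtol * max(abs(y), abs(y3))
        if err <= tol:          # accept; k4 is reused as the next k1 (FSAL)
            t, y, k1 = t + h, y3, k4
            accepted += 1
        else:                   # reject; retry from the same t with a smaller h
            rejected += 1
        # Rescale h from the error ratio (exponent 1/3 for a 2nd-order estimate), clipped to [0.2, 5].
        factor = 5.0 if err == 0.0 else min(5.0, max(0.2, 0.9 * (tol / err) ** (1 / 3)))
        h *= factor
    return y, accepted, rejected


y_easy, n_easy, _ = bs23(lambda t, y: -y, 0.0, 1.0, 1.0)
y_stiff, n_stiff, _ = bs23(lambda t, y: -1000.0 * y, 0.0, 1.0, 1.0)
err_easy = abs(y_easy - math.exp(-1.0))     # exact solution is exp(-t)
print(n_easy, n_stiff)                      # different step counts, same problem size
```

Both calls integrate one scalar over [0, 1], yet the stiff one needs far more accepted steps because the error control keeps `h` near the explicit method's stability limit. If every accepted (t, y) pair were also recorded, as dense output needs, the length of that record would vary with the values as well.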