Can't you fix a lot of the nonphysical issues by using better integration schemes? That might be hard in JAX, though. From what I know, its options for better numerical stability are pretty limited.
No, in fact you want to go to lower order with lower precision, since the smaller truncation error of a higher-order method buys you nothing once roundoff dominates. The real answer is that if the solution is in the chaotic regime, then maybe Float16 is fine, because you'll be dominated by other numerical errors anyway (provided you also ensure adequate conservation so the solution doesn't blow up in some way). But if you're not in the chaotic regime, then even Float32 is pushing it in many cases: the problem had better be non-stiff, because stiffness pretty much guarantees operations whose magnitudes span beyond Float32's relative epsilon. So it's a case-dependent topic without an easy answer, though the case for Float16 is rather small.
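To make the epsilon argument concrete, here is a small sketch using NumPy (an assumption for illustration; nothing in the thread uses it). Once an update is smaller than about half the spacing between floats near the current value, it is simply rounded away. Explicit Euler on y' = -y with a small step then stalls completely in Float16, while Float32 still tracks exp(-t). This only illustrates the rounding mechanism, not stiffness itself; `update_survives` and `euler_decay` are hypothetical helper names.

```python
import numpy as np

def machine_eps(dtype):
    """Relative epsilon (spacing between 1.0 and the next float) for dtype."""
    return float(np.finfo(dtype).eps)

def update_survives(dtype, delta):
    """True if adding delta to 1.0 in dtype actually changes the value."""
    one = dtype(1.0)
    return bool(one + dtype(delta) != one)

def euler_decay(dtype, h=1e-4, steps=1000):
    """Explicit Euler for y' = -y, y(0) = 1, with all arithmetic in dtype.

    Returns the approximation to y(steps * h) = exp(-0.1) for the defaults.
    """
    y = dtype(1.0)
    h = dtype(h)
    for _ in range(steps):
        # Each update has magnitude ~h*|y|; if that is below half the float
        # spacing near y, the subtraction rounds back to y and nothing moves.
        y = y - h * y
    return y

if __name__ == "__main__":
    for dt in (np.float16, np.float32, np.float64):
        print(dt.__name__,
              "eps =", machine_eps(dt),
              "| 1 + 1e-4 changes 1?", update_survives(dt, 1e-4),
              "| Euler y(0.1) =", float(euler_decay(dt)))
    print("exact y(0.1) =", float(np.exp(-0.1)))
```

With these defaults the Float16 run never leaves 1.0, because a change of about 1e-4 is below half of Float16's spacing just under 1.0. A stiff problem hits the same wall in Float32 when it combines quantities whose ratio exceeds roughly 1/eps.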
(We had some small tests generating TPU ODE solver code from Julia, and they turned up some rather bizarre behavior back when Keno was working on it, but we never wrote a post summarizing all of it.)