by mattjjatgoogle
1212 days ago
If you have any particular examples in mind, and time to share them on https://github.com/google/jax/issues, we'd love to try to improve them. Improving error messages is a priority.
As for introspection tools, at least for runtime value debugging, there is to some extent a fundamental challenge: since jax.jit stages computation out of Python (though jax.grad and jax.vmap don't), standard Python runtime value inspection tools, like printing and pdb, can't work under a jax.jit, because the values aren't available while the Python code is executing. You can always remove the jax.jit while debugging (or use `with jax.disable_jit(): ...`), but that's not always convenient, and we need jax.jit for good performance.
We recently added some runtime value debugging tools which work even with jax.jit-staged-out code (even in automatically parallelized code!), though they're not the standard introspection tools: see `jax.debug.print` and `jax.debug.breakpoint` on https://jax.readthedocs.io/en/latest/debugging/index.html and https://jax.readthedocs.io/en/latest/debugging/print_breakpo.... If you were thinking about other kinds of introspection tooling, I'd love to hear about it!
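To make the two options above concrete, here is a minimal sketch (the function `f` and its input are made up for illustration, not taken from any real issue). It uses `jax.debug.print` inside a jitted function, then reruns the same function under `jax.disable_jit()` so ordinary Python tools would see concrete values.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical computation, purely for illustration.
    y = jnp.sin(x) * 2.0
    # A plain print(y) here would only show a tracer, because the values
    # do not exist yet while Python is tracing the function.
    # jax.debug.print is staged out with the computation and prints
    # the runtime values instead.
    jax.debug.print("y = {y}", y=y)
    return y.sum()


x = jnp.arange(3.0)

# Staged-out execution: the print fires when the compiled code runs.
result = f(x)

# Same function with jit disabled: it runs op by op, so pdb and regular
# print would also see concrete values inside f.
with jax.disable_jit():
    eager_result = f(x)
```

The `disable_jit` path is usually the easier one for stepping through code, at the cost of speed; `jax.debug.print` keeps the compiled path intact.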
That's handy, and I hadn't seen it before, thanks.
It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928). I'm not sure of the exact solution, but the axis juggling and axis specifications were where I remember a lot of pain, and the docs (though extensive) were unclear. At times it feels like improvements are punted on in the hope that xmap eventually fixes everything (and xmap has been in experimental for far longer than I expected).
The other pain point was the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out. When it came to doing some complex jax.lax.ppermute, debugging felt more difficult than it needed to be.
Next time I encounter something particularly opaque, I'll share it on the GitHub issue tracker.
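For the pmap and ppermute case described above, one possible approach is to put `jax.debug.print` directly inside the pmapped function. This is only a sketch under assumptions: the function `shift`, the axis name `"i"`, and the ring permutation are all hypothetical, and the example sizes itself to however many local devices are present, so on a single device it degenerates to an identity permutation.

```python
import jax
import jax.numpy as jnp

# Size everything to the devices actually available.
n = jax.local_device_count()

# Hypothetical ring permutation: device i sends its value to device i + 1.
perm = [(i, (i + 1) % n) for i in range(n)]


def shift(x):
    # axis_index tells each device its position along the mapped axis.
    idx = jax.lax.axis_index("i")
    jax.debug.print("device {idx}: before = {x}", idx=idx, x=x)
    y = jax.lax.ppermute(x, axis_name="i", perm=perm)
    jax.debug.print("device {idx}: after = {y}", idx=idx, y=y)
    return y


# One scalar per device.
x = jnp.arange(n, dtype=jnp.float32)
out = jax.pmap(shift, axis_name="i")(x)
```

The order in which different devices print is not guaranteed, so including the device index in each message helps when reading the output. On a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=4` before importing jax should give several host devices to experiment with.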