by 6gvONxR4sf7o
1216 days ago
> with jax.disable_jit(): ...
That's handy, and I hadn't seen it before, thanks. It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928). I'm not sure of the exact solution, but the axis juggling and specifications were where I remember a lot of pain, and the docs (though extensive) were unclear. At times it feels like improvements are punted on in the hopes that xmap eventually fixes everything (and xmap has been in experimental for far longer than I expected). Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out. When it came to doing some complex jax.lax.ppermute, it felt harder to debug than it needed to be. Next time I encounter something particularly opaque, I'll share on the github issue tracker.
> It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).
We've improved some of these pytree error messages, but it seems the vmap one is still not great. Thanks for the ping on it.
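The thread doesn't show the exact failing case from the linked issue, so this is only a generic sketch of how `jax.vmap`'s `in_axes` is written for pytree arguments. The names `predict`, `params`, and `xs` are hypothetical. The spec for each argument must be a tree prefix of that argument: `None` can stand for a whole subtree, or you can give one entry per leaf, but a per-leaf spec whose structure doesn't match the value raises a `ValueError`.

```python
import jax
import jax.numpy as jnp

# Hypothetical model parameters stored as a pytree (a dict of arrays).
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
# A batch of 5 inputs, each of shape (3,).
xs = jnp.arange(15.0).reshape(5, 3)

def predict(params, x):
    # Single-example function: x has shape (3,), the result has shape (2,).
    return x @ params["w"] + params["b"]

# None is a tree prefix covering every leaf of `params` (broadcast, not mapped);
# 0 maps over the leading axis of `xs`.
out = jax.vmap(predict, in_axes=(None, 0))(params, xs)  # shape (5, 2)

# Equivalent per-leaf spec: its dict structure mirrors `params` exactly.
per_leaf = ({"w": None, "b": None}, 0)
out_per_leaf = jax.vmap(predict, in_axes=per_leaf)(params, xs)

# A per-leaf spec missing the "b" key is not a tree prefix of `params`,
# so vmap rejects it with a ValueError when the function is called.
bad_in_axes = ({"w": None}, 0)
try:
    jax.vmap(predict, in_axes=bad_in_axes)(params, xs)
    mismatch_error = None
except ValueError as e:
    mismatch_error = e

print(out.shape)
print(type(mismatch_error).__name__)
```

Matching the spec to the pytree's structure (or collapsing it to a single `None`/integer prefix) is usually the first thing to check when these errors appear.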
> Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.
That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.
However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)
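A minimal sketch of the behaviour described above, assuming a JAX version that includes the linked change; the function `f` and the `body_runs` list are hypothetical. A Python-side counter stands in for a print or pdb breakpoint: under `jax.disable_jit()` the body of the pmapped function runs on every call, whereas a normal pmap traces it once and then reuses the compiled executable.

```python
import jax
import jax.numpy as jnp
import numpy as np

body_runs = []  # hypothetical counter: one entry each time f's Python body runs

@jax.pmap
def f(x):
    body_runs.append(1)  # stands in for a print() or a pdb breakpoint
    return jnp.sin(x) * 2.0

# pmap maps over the leading axis, whose size must not exceed the device count.
n = jax.local_device_count()
xs = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

# Eager mode: with jit disabled, pmap executes f op by op instead of
# compiling it, so the Python body runs on every call.
with jax.disable_jit():
    out_eager = f(xs)
    f(xs)
traces_eager = len(body_runs)

# Default mode: pmap stages f out and compiles it, so the Python body runs
# only while tracing the first call; the second call reuses the executable.
out_staged = f(xs)
f(xs)
traces_staged = len(body_runs) - traces_eager

print(traces_eager, traces_staged)
print(np.allclose(out_eager, out_staged))
```

Both modes compute the same values; what changes is whether your Python code (and any debugging hooks in it) runs on each call or only once at trace time.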
> Next time I encounter something particularly opaque, I'll share on the github issue tracker.
Thank you for the constructive feedback!