Thanks for taking the time to explain these.

> It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).

We've improved some of these pytree error messages but it seems that vmap one is still not great. Thanks for the ping on it.

> Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.

That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise. However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)

> Next time I encounter something particularly opaque, I'll share on the github issue tracker.

Thank you for the constructive feedback!
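For anyone following along, here is a rough sketch of what "execute eagerly" buys you under `jax.disable_jit()`. It uses plain `jax.jit` rather than pmap so that it runs on a single device; the function name `scaled_sum` and the `seen` list are made up for illustration, and the pmap behavior described above is not exercised here.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping: records whether the argument was an abstract
# tracer each time the Python body of the function actually runs.
seen = []

@jax.jit
def scaled_sum(x):
    # Under jit, x is a tracer with no concrete values to inspect.
    # Under jax.disable_jit(), x is an ordinary concrete array, so
    # print(x) or pdb.set_trace() here would show real numbers.
    seen.append(isinstance(x, jax.core.Tracer))
    return jnp.sum(x * 2.0)

x = jnp.arange(4.0)

staged = scaled_sum(x)      # normal call: the body is traced and staged out

with jax.disable_jit():
    eager = scaled_sum(x)   # same function, now run op by op in Python

print(seen)                 # first entry from tracing, second from the eager run
print(float(staged), float(eager))
```

The point is only that the same decorated function can be stepped through like ordinary Python inside the context manager, without removing the decorator.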
Higher order functions are difficult in general, and it would be fantastic to have core patterns or tools for breaking them open.