It feels like JAX doesn't have any of the high-level APIs that PT/TF/MXNet have, which are vital for fast prototyping of model architectures. Is that correct?
It seems that the JAX developers are focusing their time on making the core framework better and are leaving the task of building high-level APIs to the community for now.
I suspect we'll see a few high-level APIs emerge over the next few months that explore different approaches before the community settles on a particular one.
I hope not. That's part of what makes TF so miserable: the core library didn't provide the tooling people actually needed, so the community built a ton of different tools, and that just made TF confusing to use.