There are a bunch of frameworks built on top of PyTorch too (fastai, PyTorch Lightning, torchbearer, Ignite...), so I don't see why this should be a problem, or at least why it would be a problem for JAX but not for PyTorch.
IMO, this isn't a fair comparison, because PyTorch spans a wider range of abstraction levels than JAX does (I don't know how to put it better than "spans a wider range of abstraction levels").
You can do much of the JAX stuff in PyTorch, but you can't do the high-level nn.LSTM stuff in plain JAX; you have to use something like Flax or Objax.
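To make the nn.LSTM point concrete, here is a rough sketch of what "doing it yourself" in plain JAX looks like: you write the cell math and drive it over time with jax.lax.scan. The helper names (init_lstm_params, lstm_step, lstm), the gate ordering and the uniform initialization are all hypothetical choices made for illustration. This is not a drop-in equivalent of PyTorch's nn.LSTM or of Flax's LSTM cell.

```python
import jax
import jax.numpy as jnp

def init_lstm_params(key, input_size, hidden_size):
    # Hypothetical layout: the four gates [i, f, g, o] are stacked along the last axis.
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(hidden_size)
    return {
        "W_x": jax.random.uniform(k1, (input_size, 4 * hidden_size), minval=-scale, maxval=scale),
        "W_h": jax.random.uniform(k2, (hidden_size, 4 * hidden_size), minval=-scale, maxval=scale),
        "b": jnp.zeros((4 * hidden_size,)),
    }

def lstm_step(params, carry, x_t):
    # One time step of a standard LSTM cell; carry holds (hidden state, cell state).
    h, c = carry
    gates = x_t @ params["W_x"] + h @ params["W_h"] + params["b"]
    i, f, g, o = jnp.split(gates, 4, axis=-1)
    i, f, o = jax.nn.sigmoid(i), jax.nn.sigmoid(f), jax.nn.sigmoid(o)
    g = jnp.tanh(g)
    c = f * c + i * g
    h = o * jnp.tanh(c)
    return (h, c), h

def lstm(params, xs, hidden_size):
    # xs is time-major: (seq_len, batch, input_size).
    batch = xs.shape[1]
    h0 = jnp.zeros((batch, hidden_size))
    c0 = jnp.zeros((batch, hidden_size))
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    # scan threads the carry through time and stacks each step's h into hs.
    (h_n, c_n), hs = jax.lax.scan(step, (h0, c0), xs)
    return hs, (h_n, c_n)

params = init_lstm_params(jax.random.PRNGKey(0), input_size=3, hidden_size=5)
hs, (h_n, c_n) = lstm(params, jnp.ones((7, 2, 3)), hidden_size=5)
```

In PyTorch the equivalent is roughly one nn.LSTM(3, 5) call. In JAX, that convenience layer is exactly what libraries like Flax or Objax provide.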