JAX handles higher-order differentiation much more nicely. PyTorch has functions to compute Hessians, and there are libraries that keep optimizer steps differentiable, but things get tricky very fast once you go beyond their standard use cases. In contrast, JAX can compute nth-order derivatives very easily, because transformations like `jax.grad` simply compose.
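A minimal sketch of what that composability looks like, assuming a recent JAX install. `nth_derivative` is a hypothetical helper that just nests `jax.grad`, `jax.hessian` is JAX's built-in Hessian transform, and the last part differentiates a loss through a few unrolled gradient-descent steps with respect to the learning rate. The toy functions are made up for illustration.

```python
import jax
import jax.numpy as jnp

def nth_derivative(f, n):
    """Hypothetical helper: n-th derivative of a scalar function by nesting jax.grad."""
    for _ in range(n):
        f = jax.grad(f)
    return f

# The fourth derivative of sin is sin itself.
d4_sin = nth_derivative(jnp.sin, 4)
print(d4_sin(1.0), jnp.sin(1.0))  # the two values should agree

# Hessian of a multivariate function: for sum(x**3) it is diag(6 * x).
def g(x):
    return jnp.sum(x ** 3)

H = jax.hessian(g)(jnp.array([1.0, 2.0]))

# Differentiating "through the optimizer": a toy quadratic loss minimized
# by 5 unrolled gradient-descent steps, differentiated w.r.t. the learning rate.
def inner_loss(w):
    return (w - 3.0) ** 2

def loss_after_training(lr, w0=0.0, steps=5):
    w = w0
    for _ in range(steps):
        w = w - lr * jax.grad(inner_loss)(w)  # grad nested inside grad
    return inner_loss(w)

meta_grad = jax.grad(loss_after_training)(0.1)
print(H)
print(meta_grad)
```

For this toy setup the loss after training is 9(1 - 2·lr)^10, so the learning-rate gradient is -180(1 - 2·lr)^9, which makes the result easy to check by hand.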
In my experience, JAX's main benefit is that distributed computation is much easier, thanks to a much nicer API for it. For single-device computing, neither framework has a clear advantage.
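The comment doesn't say which distributed API it has in mind. One possibility is sharding with `jax.jit` plus `jax.sharding` (the older `jax.pmap` is another). Below is a minimal data-parallel sketch under that assumption: the toy linear-regression loss and all variable names are hypothetical. The idea is that you describe how arrays are laid out across devices, and the same jitted function runs unchanged while XLA inserts the cross-device communication.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available (a single CPU also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Hypothetical toy model: linear regression with a mean-squared-error loss.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

n_dev = len(devices)
batch = 8 * n_dev  # the batch axis must split evenly across devices
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (batch, 3))
y = jax.random.normal(key_y, (batch,))
w = jnp.ones(3)

# Shard the batch axis over the "data" mesh axis; replicate the weights.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
y_sharded = jax.device_put(y, NamedSharding(mesh, P("data")))
w_replicated = jax.device_put(w, NamedSharding(mesh, P()))

# Same jitted gradient function for sharded and unsharded inputs;
# the compiler adds the reduction across devices needed for the mean.
grad_fn = jax.jit(jax.grad(loss))
g_sharded = grad_fn(w_replicated, x_sharded, y_sharded)
g_local = grad_fn(w, x, y)
print(g_sharded, g_local)
```

On a single device this runs as-is; to try it with several simulated CPU devices, one option is setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` in the environment before JAX is imported.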