Wait, what? JAX, and PyTorch too, are used in a lot more areas than NNs.
JAX is even considered to do better in that department than Julia in terms of performance, so what are you talking about?
GP makes a fair point about JAX still requiring a limited subset of Python, though (mostly control-flow stuff). Also, there's really no in-library way to add new kernels. That doesn't matter for most ML people, but it's absolutely important in other domains. So Numba/Julia/Fortran are "better in that department" performance-wise than JAX, because JAX doesn't even support that functionality.
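To make the control-flow point concrete, here's a minimal sketch (assuming a recent JAX release; the function names are made up for illustration). Inside `jax.jit`, a plain Python `if`/`while` on a traced value fails because the value isn't known at trace time, so you have to rewrite it with structured primitives like `jax.lax.cond` and `jax.lax.while_loop`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_python_if(x):
    # Plain Python `if` on a traced value: raises at trace time under jit,
    # because the tracer has no concrete truth value.
    if x > 0:
        return x
    return jnp.zeros_like(x)

@jax.jit
def relu_lax_cond(x):
    # Same logic with structured control flow: both branches are traced
    # and the one to run is selected at runtime.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

@jax.jit
def halvings_until_below_one(x):
    # Data-dependent loop: `while x >= 1` would fail under jit, so the loop
    # is written with lax.while_loop over a (value, count) carry.
    def cond_fun(carry):
        v, _ = carry
        return v >= 1.0

    def body_fun(carry):
        v, n = carry
        return v / 2.0, n + 1

    _, n = jax.lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return n

try:
    relu_python_if(jnp.float32(3.0))
except jax.errors.ConcretizationTypeError as err:
    print("Python `if` failed under jit:", type(err).__name__)

print(relu_lax_cond(jnp.float32(3.0)))              # 3.0
print(relu_lax_cond(jnp.float32(-2.0)))             # 0.0
print(halvings_until_below_one(jnp.float32(10.0)))  # 4
```

In recent JAX versions the failing case raises `TracerBoolConversionError`, which subclasses `ConcretizationTypeError`, so the sketch catches the base class to cover older releases too. The rewrite is mechanical for small cases like these, but it's exactly the "limited subset of Python" friction GP was talking about.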