On the other hand, if I wanted some scientific NumPy code to run on the GPU, I think rewriting it in JAX would probably be a better choice than PyTorch.
In my experience, the answer comes down to "does your code use classes liberally?"
If not, and you're just passing things between functions, then go ahead with JAX! But converting a larger codebase that uses classes is significantly better with PyTorch, even though its method names differ from NumPy's.
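For the function-only case, porting can be close to mechanical. Here is a minimal sketch, assuming a made-up `rms` helper (not from the thread): swap `np` for `jax.numpy` and wrap the function in `jax.jit`, which compiles it for whatever accelerator JAX finds.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical NumPy original: a plain function of arrays, no classes, no mutation.
def rms_numpy(x):
    return np.sqrt(np.mean(x ** 2))

# JAX port: same body with np -> jnp; jax.jit compiles it (on GPU if one is available).
@jax.jit
def rms_jax(x):
    return jnp.sqrt(jnp.mean(x ** 2))

x = np.array([3.0, 4.0])
print(rms_numpy(x), float(rms_jax(x)))  # both compute sqrt(12.5)
```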
I'm going to disagree here! Classes and functional programming can go very well together; just don't expect to do in-place mutation (i.e. OO-style programming).
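One way to keep classes without in-place mutation is to make them immutable containers whose methods return new instances. A minimal sketch, assuming a hypothetical `Particle` class built on `typing.NamedTuple` (JAX treats NamedTuples as pytrees, so `jax.jit` can trace through their fields):

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

# Hypothetical class for illustration: an immutable record with methods.
class Particle(NamedTuple):
    position: jax.Array
    velocity: jax.Array

    def step(self, dt):
        # Return a new Particle instead of mutating self.
        return self._replace(position=self.position + dt * self.velocity)

@jax.jit
def simulate(p, dt):
    # The Python loop is unrolled at trace time; each step rebinds p to a new object.
    for _ in range(3):
        p = p.step(dt)
    return p

p = Particle(position=jnp.zeros(2), velocity=jnp.ones(2))
out = simulate(p, 0.5)
print(out.position)  # p itself is unchanged
```

The style stays object-oriented in how data and methods are grouped; only the "update in place" habit has to go.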