by JBits 1000 days ago
On the other hand, if I wanted some scientific NumPy code to run on the GPU, I think rewriting it in JAX would probably be a better choice than PyTorch.

In my experience, the answer comes down to "does your code use classes liberally?"

If not, and you're just passing things between functions, then go ahead with JAX! But converting a larger codebase that leans on classes is significantly easier with PyTorch, even though the method names and so on differ from NumPy's.

I'm going to disagree here! Classes and functional programming can go very well together; just don't expect to do in-place mutation (i.e. the usual OO style of updating attributes on an existing object).
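To make that concrete, here is a minimal sketch in plain JAX, not taken from any particular library. The `Spring` class, its fields, and `simulate` are hypothetical names made up for illustration. The idea is a frozen dataclass registered as a pytree, where "updating" means returning a new object rather than mutating the old one:

```python
from dataclasses import dataclass, replace

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class Spring:
    """Hypothetical example class: state of a simple spring."""
    position: jax.Array
    velocity: jax.Array
    stiffness: float = 1.0  # kept as static metadata, not traced

    def tree_flatten(self):
        # Children are the array leaves JAX traces; aux data stays static.
        return (self.position, self.velocity), (self.stiffness,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, *aux)

    def step(self, dt):
        # No in-place mutation: compute new values and return a new Spring.
        new_velocity = self.velocity - self.stiffness * self.position * dt
        new_position = self.position + new_velocity * dt
        return replace(self, position=new_position, velocity=new_velocity)


@jax.jit
def simulate(spring, dt):
    # Methods chain naturally because each call returns a fresh object.
    return spring.step(dt).step(dt)


start = Spring(position=jnp.array(1.0), velocity=jnp.array(0.0))
end = simulate(start, 0.1)
```

`start` is left untouched after the call, and trying `start.position = ...` raises an error because the dataclass is frozen. Libraries like Equinox wrap up this kind of boilerplate for you.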

You might like Equinox (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars) which deliberately offers a very PyTorch-like feel for JAX.

Regarding speed, I would strongly recommend JAX over PyTorch for SciComp. The XLA compiler seems to be much more effective for such use cases.
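For a sense of what porting NumPy code looks like, here is a rough sketch. The function names and the pairwise-distance example are my own, chosen purely for illustration, and I'm making no claim about timings here. Often the port is little more than swapping `np` for `jnp` and adding `jax.jit`:

```python
import numpy as np
import jax
import jax.numpy as jnp


def pairwise_sq_dists_np(x):
    # Plain NumPy: squared Euclidean distance between all pairs of rows.
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff ** 2, axis=-1)


@jax.jit  # compiled by XLA for whichever backend JAX finds (CPU, GPU, TPU)
def pairwise_sq_dists_jax(x):
    # Same body, with np swapped for jnp.
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


points = np.random.default_rng(0).normal(size=(5, 3)).astype(np.float32)
d_np = pairwise_sq_dists_np(points)
d_jax = pairwise_sq_dists_jax(jnp.asarray(points))
```

The two results should agree up to floating-point tolerance. The usual caveat applies: anything that relied on in-place updates (`a[i] = ...`) needs rewriting, e.g. with `a.at[i].set(...)`.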