by cl3misch
217 days ago
Not OP. I prefer JAX over PyTorch for non-AI tasks in scientific computing because of its different mental model. In JAX, you think about functions and gradients of functions. In PyTorch, you think about tensors that accumulate a gradient as they are passed through functions. JAX just suits my way of thinking much better. I also like that jax.jit forces you to write "functional" functions, free of side effects and in-place array updates. It might feel weird at first (and not every algorithm suits this style), but ultimately it leads to clearer and faster code. I am surprised that JIT in PyTorch gets so little attention. Maybe it's less impactful for PyTorch's usual use case of large networks than for general scientific computing?
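A minimal sketch of the "functions and gradients of functions" model described above, assuming a recent JAX install. `f` and `zero_first` are illustrative names, not from the comment. jax.grad turns a Python function into a new function that returns its gradient. JAX arrays are immutable, so an update is written with `.at[...].set(...)`, which returns a new array instead of mutating the old one.

```python
import jax
import jax.numpy as jnp

# A pure function of its input: f(x) = sum(sin(x)^2).
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# jax.grad returns a new function computing df/dx = 2 sin(x) cos(x) = sin(2x).
df = jax.grad(f)

# Under jax.jit the function is traced and compiled. Writing x[0] = 0.0 would
# raise a TypeError because JAX arrays are immutable; the functional form
# x.at[0].set(0.0) builds and returns a new array and leaves x untouched.
@jax.jit
def zero_first(x):
    return x.at[0].set(0.0)

x = jnp.linspace(-1.0, 1.0, 5)
print(df(x))          # matches jnp.sin(2 * x) up to float rounding
y = jnp.array([1.0, 2.0, 3.0])
print(zero_first(y))  # new array with the first entry set to 0.0
print(y)              # original array is unchanged
```

The gradient is itself just another function, so it can be jitted, composed, or differentiated again, rather than being state stored on a tensor.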
It's not weird. It's actually the most natural way of doing things for me. You just write down your math equations in JAX and you're done.
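To illustrate "write down your math equations in JAX", here is a hedged sketch. The harmonic chain model, the `energy` function, and its `k` and `rest` parameters are hypothetical examples, not from the thread. The potential energy U(x) = (k/2) Σᵢ (x_{i+1} − x_i − rest)² is typed in almost verbatim, and the force F = −∇U falls out of jax.grad with no hand-derived derivative.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: a 1D chain of masses joined by identical springs.
# U(x) = (k/2) * sum_i (x[i+1] - x[i] - rest)^2
def energy(x, k=1.0, rest=1.0):
    stretch = jnp.diff(x) - rest  # extension of each spring beyond its rest length
    return 0.5 * k * jnp.sum(stretch ** 2)

# Force is the negative gradient of the energy with respect to positions,
# compiled once with jax.jit and reused on later calls.
force = jax.jit(lambda x: -jax.grad(energy)(x))

x = jnp.array([0.0, 2.0, 3.0])
print(energy(x))  # first spring stretched by 1, second at rest -> 0.5
print(force(x))   # first two masses pulled toward each other: [1, -1, 0]
```

Swapping in a different physical model only means editing `energy`; the force, and any further derivatives, follow automatically.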