by yklcs
785 days ago
I like JAX, and I find most of its core functionality as an "accelerated NumPy" great.
Ecosystem fragmentation and poor interop make adopting JAX hard, though. There's too much fragmentation in the JAX NN library space, and penzai isn't helping with that. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading. PyTorch code can't be called from JAX, so extending and iterating on prior work (which is most of research) means reimplementing a lot of it in JAX. Custom CUDA kernels are a bit fiddly too; I haven't managed to bring Gaussian Splatting to JAX yet.
(Interop with PyTorch seems more difficult, of course!)