by yklcs 785 days ago
I like JAX, and find most of the core functionality as an "accelerated NumPy" great. Ecosystem fragmentation and difficulties in interop make adopting JAX hard though.

There's too much fragmentation within the JAX NN library space, which penzai isn't helping with. I wish everyone using JAX could agree on a single set of libraries for NN, optimization, and data loading.

PyTorch code can't be called from JAX, so extending and iterating on prior work, which is what most research consists of, requires a lot of reimplementation in JAX. Custom CUDA kernels are a bit fiddly too; I haven't been able to bring Gaussian Splatting to JAX yet.


I'm curious what interop difficulties you've run into in JAX? In my experience, the JAX ecosystem is quite modular and most JAX libraries work pretty well together. Penzai's core visualization tooling should work for most JAX NN libraries out of the box, and Penzai's neural net components are compatible with existing JAX optimization libraries (like Optax) and data loaders (like tfds/seqio or grain).

(Interop with PyTorch seems more difficult, of course!)

It's mostly an ecosystem thing, being unable to use existing methods. In my experience, research goes something like

1. Milestone paper introducing novel method is published with green-field implementation

2. Bunch of papers extend milestone paper with brown-field implementation

3. Goto 1

Most things in 1 are written in PyTorch, meaning 2 also has to be in PyTorch. I know this isn't JAX's fault, but I don't think JAX's philosophy of staying unopinionated and low-level is helping. It seems like the community agreeing on a single set of DL libraries around JAX would help it gain some momentum.

That's my experience as well. PyTorch dominates the ecosystem.

Which is a shame, because JAX's approach is superior.[a]

---

[a] In my experience, anytime I've had to do anything in PyTorch that isn't well supported out-of-the-box, I've quickly found myself tinkering with Triton, which usually becomes... very frustrating. Meanwhile, JAX offers decent parallelization of anything I write in plain Python, plus really nice primitives like jax.lax.while_loop, jax.lax.associative_scan, jax.lax.select, etc. And yet, I keep using PyTorch... because of the ecosystem.
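For anyone who hasn't used those primitives, a rough sketch of each follows. The examples (a Newton iteration for sqrt(2), a prefix sum, a ReLU) and the function name `newton_sqrt2` are picked only for illustration; real uses are typically more involved.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_sqrt2(x0):
    # lax.while_loop: a data-dependent loop that still compiles under jit.
    # State is (current estimate, previous estimate).
    def cond(state):
        x, prev = state
        return jnp.abs(x - prev) > 1e-6

    def body(state):
        x, _ = state
        return (0.5 * (x + 2.0 / x), x)

    x, _ = lax.while_loop(cond, body, (x0, x0 + 1.0))
    return x


root = jax.jit(newton_sqrt2)(jnp.float32(1.0))

# lax.associative_scan: a parallel prefix scan over any associative operator.
# With addition it computes a cumulative sum.
prefix = lax.associative_scan(jnp.add, jnp.arange(1, 6))

# lax.select: elementwise choice between two arrays of the same shape and dtype.
x = jnp.array([-1.0, 2.0, -3.0, 4.0])
relu = lax.select(x > 0, x, jnp.zeros_like(x))
```

The loop state passed to `lax.while_loop` has to keep the same shapes and dtypes across iterations, which is why the initial value is an explicit `jnp.float32` here.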

The best is not always popular. JAX's approach is a lot like the Erlang programming language.

> The best is not always popular.

I agree. Network effects routinely overpower better technology.

Another issue I've personally faced is debugging, although this is based on my experience from more than a year ago, and things may be better now. I used JAX mostly for optimization, and the error messages weren't helpful.