Hacker News
by subhrm 998 days ago
JAX GPU support is limited to Linux only. Even the WSL2 support is experimental. https://jax.readthedocs.io/en/latest/installation.html#suppo...

Apple supports JAX[0] along with PyTorch[1] and TensorFlow[2] on macOS, with both Apple Silicon and AMD GPUs (on x86 Macs). That said, the performance isn't great. I write most of my experimental ML code in JAX on an M2 MacBook Air and then move to a proper multi-GPU Linux box for full training runs.

[0]: https://developer.apple.com/metal/jax/

[1]: https://developer.apple.com/metal/pytorch/

[2]: https://developer.apple.com/metal/tensorflow-plugin/
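A minimal sketch of that laptop-to-Linux-box workflow, assuming JAX (and, on a Mac, Apple's Metal plugin) is installed: the same jitted function runs unchanged on whichever backend JAX selects. The helper name is hypothetical, and the backend string reported under Metal may differ from the usual `cpu`/`gpu` names.

```python
import importlib.util


def jax_backend_summary():
    """Report which backend JAX picked and run a tiny jitted computation.

    Hypothetical helper; returns a note instead of failing if JAX is missing.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax
    import jax.numpy as jnp

    @jax.jit
    def scaled_sum(x):
        # Compiled by XLA for whatever device JAX selected (CPU, CUDA, Metal, ...).
        return jnp.sum(x * 2.0)

    result = float(scaled_sum(jnp.arange(4.0)))  # 2 * (0 + 1 + 2 + 3)
    return f"backend={jax.default_backend()} devices={len(jax.devices())} result={result}"


print(jax_backend_summary())
```

Nothing in the function is device-specific, which is what lets experimental code written on a Mac move to a CUDA machine for full runs.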

PyTorch on my M2 Max using the MPS backend has pretty decent performance, to be honest.
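For reference, a sketch of selecting the MPS device in PyTorch with a CPU fallback. `pick_torch_device` is a hypothetical helper name; `torch.backends.mps` only exists in newer releases (1.12+), hence the `getattr` guard, and the function returns `None` when PyTorch isn't installed at all.

```python
import importlib.util


def pick_torch_device():
    """Return the best available torch.device, or None if PyTorch is missing.

    Hypothetical helper: prefers CUDA, then Apple's MPS backend, then CPU.
    """
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is absent in older PyTorch releases, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

Tensors are then created with `device=pick_torch_device()` so the same script runs on a Mac, a CUDA box, or a CPU-only machine.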

It's significantly faster than CPU, something like 100x.
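Speedup figures like this depend a lot on how they are measured: GPU backends, MPS included, queue work asynchronously, so a naive timer may capture only the kernel launch. A minimal stdlib-only timing sketch, assuming the caller passes a `sync` callable for the device (for example `torch.mps.synchronize` on PyTorch 2.x with MPS); `benchmark` and `speedup` are hypothetical helper names.

```python
import statistics
import time


def benchmark(fn, *, sync=None, warmup=1, repeats=5):
    """Return the median wall-clock time of fn() in seconds.

    sync: optional callable that blocks until queued device work is done.
    Without it, asynchronous backends may report only dispatch time.
    """
    # Warm-up runs absorb one-off costs such as compilation and allocation.
    for _ in range(warmup):
        fn()
        if sync is not None:
            sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def speedup(baseline_s, accelerated_s):
    """How many times faster the accelerated run is than the baseline."""
    return baseline_s / accelerated_s
```

For JAX, calling `.block_until_ready()` on the result inside `fn` plays the same role as `sync`.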

Is there a specific reason why Windows is not supported?
Presumably because Google's cloud doesn't run on Windows. Then again, nothing HPC-related runs Windows.
The life sciences industry uses plenty of Windows, including for HPC workloads.
We ship Windows CPU builds only at the moment.

We don't support Windows GPU because we haven't had the engineering bandwidth to support it well.

We recommend WSL2 for GPU on Windows at the moment because that is a compromise: it allows CUDA support, without us having to support another release variant.

But we welcome community contributions!

Because no one has done the work to add it... Could be you!
They don't build on Windows at all, either.
Not true!

We release Windows CPU wheels (https://pypi.org/project/jaxlib/#files). So JAX on CPU works great on Windows.

We don't release Windows GPU wheels at the moment, but that's because we're a small team and none of us use Windows personally. We welcome contributions!

(I verified that the Windows CUDA GPU support built as recently as two weeks ago, but I don't have the ability to test that it works.)

We recommend WSL2 because that's just using our existing Linux CUDA release.
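To summarize the support matrix described in this thread, here is a small sketch mapping the host OS to an install route. The route strings are hypothetical labels, not official JAX terminology, and WSL detection is a heuristic (WSL kernel release strings contain "microsoft") that cannot tell WSL1 from WSL2; the JAX installation page remains the authority on current wheels and extras.

```python
import platform


def jax_install_route(system=None, kernel_release=None):
    """Map the host OS to the JAX install route discussed in this thread.

    Hypothetical helper; arguments default to the current machine so the
    mapping can also be checked with made-up values.
    """
    system = system or platform.system()
    kernel_release = (kernel_release or platform.release()).lower()
    if system == "Linux":
        # Heuristic: WSL kernels (WSL1 and WSL2 alike) mention "microsoft".
        if "microsoft" in kernel_release:
            return "wsl2: use the Linux CUDA release"
        return "linux: CUDA GPU wheels"
    if system == "Windows":
        return "windows: CPU-only wheels (use WSL2 for CUDA)"
    if system == "Darwin":
        return "macos: Apple's Metal plugin"
    return "unknown: check the JAX installation docs"


print(jax_install_route())
```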

Oops, sorry! But this is recent, isn't it? I thought the lack of Windows support was due to XLA/Bazel not supporting it.
Yes, we made this more formally supported recently.

We felt that Windows CPU support was important so everyone can run JAX, even if it's not always the most-accelerated version of JAX. And we got some great PRs from the community that helped fix a few open issues.

Very nice! I just installed it, and I hope to contribute down the line, especially on custom operators. They weren't even documented until recently, and adding them still takes quite a bit of work.