Apple supports JAX[0], along with PyTorch[1] and TensorFlow[2], on macOS with both Apple Silicon and AMD GPUs (on x86 Macs). That said, performance isn't great. I write most of my experimental ML code in JAX on an M2 MacBook Air and then move to a proper multi-GPU Linux box for full training runs.
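Part of what makes that laptop-to-Linux-box workflow painless is that JAX code is mostly backend-agnostic. Here's a minimal sketch (a toy linear regression of my own, not anything from Apple's docs) that runs unchanged on CPU, CUDA, or a Metal plugin backend. It assumes only that `jax` is installed. On a Mac you'd additionally need Apple's Metal plugin for the GPU to show up, and what `jax.default_backend()` reports there may differ.

```python
import jax
import jax.numpy as jnp

# Which backend JAX picked up ("cpu", "gpu", or a plugin backend such as Metal).
print("backend:", jax.default_backend())
print("devices:", jax.devices())

LR = 0.1  # learning rate, chosen arbitrarily for this toy problem

def loss(w, x, y):
    # Mean squared error of a linear model y ~ x @ w.
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def sgd_step(w, x, y):
    # One gradient-descent step; jit compiles it for whatever backend is active.
    return w - LR * jax.grad(loss)(w, x, y)

# Synthetic data from known weights, so convergence is easy to check.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

w = jnp.zeros(3)
for _ in range(200):
    w = sgd_step(w, x, y)
print("learned w:", w)
```

Nothing in there names a device, so the same script moves to the multi-GPU box as-is; only the installed JAX build/plugin changes.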
It's still significantly faster than the CPU, though: something like 100x using sheet
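The ~100x figure will vary a lot with the chip, dtype, and problem size, so it's worth measuring on your own machine. A rough sketch, assuming a plain square matmul is a reasonable proxy for your workload (it may not be). It compares the CPU device against whatever JAX treats as the default device. On a CPU-only install both are the CPU, so the ratio will be about 1.

```python
import time
import jax
import jax.numpy as jnp

def time_matmul(device, n=1024, repeats=3):
    # Put the operand on `device` explicitly so the jitted matmul runs there.
    key = jax.random.PRNGKey(0)
    a = jax.device_put(jax.random.normal(key, (n, n), dtype=jnp.float32), device)
    f = jax.jit(lambda m: m @ m)
    f(a).block_until_ready()  # warm-up call so compile time isn't measured
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so we time the actual compute.
        f(a).block_until_ready()
    return (time.perf_counter() - start) / repeats

cpu_time = time_matmul(jax.devices("cpu")[0])
accel_time = time_matmul(jax.devices()[0])  # default device: GPU/Metal if present
print(f"cpu: {cpu_time:.4f}s  default device: {accel_time:.4f}s  "
      f"speedup: {cpu_time / accel_time:.1f}x")
```

Warm-up and `block_until_ready()` matter here: without them you'd mostly be timing compilation and async dispatch rather than the matmul itself.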