Yep, I started off trying to get it to work with PyTorch (https://github.com/bkkaggle/lm-training-research-project/blo...), then with PyTorch Lightning, but the whole one-user-VM-per-TPU-board limitation in pytorch-xla 7-8 months ago made me switch over to TF.
Heh. I've been using JAX for a couple of months and it's been a pretty nice replacement for both PyTorch and TF. It feels like what an ML framework would look like if it were built around easy scaling and dev friendliness.