by 6d65
2012 days ago
The description says it's autograd + XLA, so I assumed it always compiles to GPUs via XLA. But I had a look at the code, and JAX has cuBLAS and ROCm BLAS, so it looks like there is a flow where it uses the GPU directly, unless I'm missing something. Definitely worth a closer look. Autograd via function reflection should be faster than backprop. And if it runs on AMD GPUs, that's quite intriguing.
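
For anyone who wants to poke at the "autograd + XLA" part, here is a minimal sketch, assuming a working `jax` install. The function `f` and the variable names are made up for illustration, and which backend you actually get depends on how JAX was installed on your machine.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy function: sum of squares.
    return jnp.sum(x ** 2)


# jax.grad builds a new Python function that computes df/dx.
grad_f = jax.grad(f)

# jax.jit traces that function and compiles it through XLA.
fast_grad_f = jax.jit(grad_f)

x = jnp.arange(3.0)
g = fast_grad_f(x)  # analytically, the gradient of sum(x^2) is 2 * x

# Report what JAX is running on; the value depends on the install.
backend = jax.default_backend()  # a platform name string
devices = jax.devices()          # list of devices JAX can see

print(g, backend, devices)
```

Printing `backend` and `devices` is probably the quickest way to check whether a given install is really using a GPU or quietly falling back to CPU.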