by anthonix1
638 days ago
Does JAX have its own implementations of matmul, flash attention, etc.? Or does it use the ROCm implementations like PyTorch does (e.g., hipBLASLt, Composable Kernel FA)? I'm not too familiar with JAX, but the abysmal PyTorch training perf on MI300X is in large part attributable to the slow ROCm libraries it uses under the hood.
1. https://jax.readthedocs.io/en/latest/pallas/index.html
2. https://github.com/jax-ml/jax/blob/main/jax/experimental/pal...