by anthonix1 629 days ago
Yeah I would suggest taking a look at PyTorch on AMD before saying stuff like "scaled_dot_product_attention is an NVIDIA CUDA kernel exposed as a PyTorch function", because that is demonstrably false.

Also, FWIW, I would suggest getting a small Llama 3.1 model training fast before trying to do a big 405B model -- faster to iterate and almost everything you'll learn on the small models will scale to the 405B.


Thanks for the feedback! I appreciate you pointing that out. My understanding was based on the PyTorch documentation for scaled_dot_product_attention (https://pytorch.org/docs/stable/generated/torch.nn.functiona...), which says: "The function may call optimized kernels for improved performance when using the CUDA backend. For all other backends, the PyTorch implementation will be used."

And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.
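
To make the quoted documentation sentence concrete, here is a minimal pure-Python sketch of the "fast kernel if the backend has one, otherwise a portable implementation" behaviour it describes. This is not PyTorch's dispatcher: `FAST_KERNELS`, `attention`, and `reference_attention` are hypothetical names, the "fused" kernel is just a stand-in for the reference code, and the sketch makes no claim about which backends PyTorch actually accelerates.

```python
import math

def reference_attention(q, k, v):
    """Portable fallback: softmax(q @ k^T / sqrt(d)) @ v on nested lists."""
    d = len(q[0])
    out = []
    for qi in q:
        # Scaled dot product of this query row with every key row.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        # Numerically stable softmax over the scores.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out

# Hypothetical registry of backend -> optimized kernel. The "cuda" entry is a
# stand-in that reuses the reference code; a real fused kernel would differ.
FAST_KERNELS = {"cuda": reference_attention}

def attention(q, k, v, backend):
    """Return (result, path) so the caller can see which path actually ran."""
    kernel = FAST_KERNELS.get(backend)
    if kernel is not None:
        return kernel(q, k, v), "fused"
    return reference_attention(q, k, v), "fallback"

q = [[1.0, 0.0]]
k = [[1.0, 0.0], [1.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(attention(q, k, v, "cuda")[1])   # which path ran on "cuda"
print(attention(q, k, v, "other")[1])  # which path ran elsewhere
```

Returning the path alongside the result is one way to surface the information the comment says is hidden: the numerical answer is the same on both paths, so only the label (or a profiler) reveals which one was taken.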

> And was trying to make a broader point about the lack of transparency (in performance, lower-level impl) in PyTorch when running on NVIDIA vs. non-NVIDIA hardware.

I don't quite understand this argument. PyTorch lacks transparency, so instead we're gonna leave it all to XLA? How does this solve the "transparency" issue?

Having a common library function that is either lightning fast or dog slow, depending on the hardware, is not a great position to be in.

Moreover, this will get worse as more CUDA-specific features are added to PyTorch with ad-hoc fallback functions.

I guess OP is saying that XLA is more transparent in this regard, because it wouldn’t use functions like these and the generated comparable code would be on par performance-wise?

> it wouldn’t use functions like these and the generated comparable code would be on par performance-wise

Perhaps if XLA generated all functions from scratch, this would be more compelling. But XLA relies very heavily on pattern-matching to common library functions (e.g. cuDNN), and these patterns will certainly work better on Nvidia GPUs than AMD GPUs.

In this way, I think explicitly calling the common library functions is actually much more transparent.
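
The pattern-matching point above can be illustrated with a toy lowering pass. This is a sketch under stated assumptions, not XLA's actual implementation: the tuple-based expression tree, `match_attention`, `lower`, the `vendor_a`/`vendor_b` backends, and the `fused_attention` routine name are all hypothetical, and the attention scale factor is omitted for brevity.

```python
# Toy expression tree: nodes are tuples (op, *args), leaves are strings.
def attention_expr(q, k, v):
    return ("matmul", ("softmax", ("matmul", q, ("transpose", k))), v)

# Hypothetical table: does the backend's vendor library ship a fused routine?
LIBRARY_HAS_FUSED_ATTENTION = {"vendor_a": True, "vendor_b": False}

def match_attention(expr):
    """Return (q, k, v) if expr is matmul(softmax(matmul(q, transpose(k))), v)."""
    if not (isinstance(expr, tuple) and len(expr) == 3 and expr[0] == "matmul"):
        return None
    _, lhs, v = expr
    if not (isinstance(lhs, tuple) and len(lhs) == 2 and lhs[0] == "softmax"):
        return None
    inner = lhs[1]
    if not (isinstance(inner, tuple) and len(inner) == 3 and inner[0] == "matmul"):
        return None
    _, q, kt = inner
    if not (isinstance(kt, tuple) and len(kt) == 2 and kt[0] == "transpose"):
        return None
    return q, kt[1], v

def lower(expr, backend):
    """Rewrite to a library call only when the pattern matches AND the
    backend has the routine; otherwise leave the generic ops in place."""
    operands = match_attention(expr)
    if operands is not None and LIBRARY_HAS_FUSED_ATTENTION.get(backend, False):
        return ("library_call", "fused_attention", *operands)
    return expr

expr = attention_expr("q", "k", "v")
print(lower(expr, "vendor_a"))
print(lower(expr, "vendor_b"))
```

The same source expression lowers to a single library call on one backend and stays as generic ops on the other, and nothing in the user's code shows which happened. That is the sense in which an explicit library call can be the more transparent option.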