by edude03
708 days ago
How much is the flash attention algorithm tied to the hardware? For one, this announcement mentions taking advantage of the async capabilities of the H100 GPUs, which I assume means you won't get those speedups on non-H-series cards. For another, the actual flash attention library requires CUDA, although the algorithm has apparently been ported to Metal [0]. I would imagine that if the algorithm were literally just a pure function, it could be implemented for any GPU/ML framework?

[0]: https://github.com/philipturner/metal-flash-attention
> https://github.com/karpathy/nanoGPT/blob/master/model.py#L45
Karpathy's nanoGPT decides whether to use flash attention by checking whether torch.nn.functional.scaled_dot_product_attention exists, and falls back to a manual attention implementation if it doesn't.
> https://pytorch.org/docs/stable/generated/torch.nn.functiona...
Looking at the docs, most of the time you want this to dispatch to FA2, which optimizes the kernels on the device: it computes the softmax block by block, can skip the fully masked blocks of the triangular (causal) matrix, and cuts down on unnecessary movement of floating-point data between the GPU's high-bandwidth memory (HBM) and its on-chip SRAM.
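To the "is it just a pure function" question: the math is, and here's a rough sketch of the block-by-block (online) softmax in plain Python to show it. This is not the FA2 kernel and not how the library is written. The function names and the `block` parameter are mine, it only tiles over keys/values (real kernels tile the queries too), and it says nothing about speed. The point is only that the tiled version gives the same numbers as the naive one without ever holding a full row of scores.

```python
import math


def naive_causal_attention(q, k, v):
    """Reference version: build each full score row, softmax it, weight V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(q)):
        # Causal mask: query i only attends to keys 0..i.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                  for j in range(i + 1)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_causal_attention(q, k, v, block=2):
    """Same result, but keys/values are consumed one block at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i in range(len(q)):
        m = float("-inf")  # running max of the scores seen so far
        l = 0.0            # running softmax denominator
        acc = [0.0] * dv   # running unnormalized output row
        # Blocks that start past the diagonal are never visited at all.
        for start in range(0, i + 1, block):
            stop = min(start + block, i + 1)
            scores = [scale * sum(a * b for a, b in zip(q[i], k[j]))
                      for j in range(start, stop)]
            m_new = max(m, max(scores))
            # Rescale the old partial sums to the new max (0.0 on the first block).
            correction = math.exp(m - m_new)
            p = [math.exp(s - m_new) for s in scores]
            l = l * correction + sum(p)
            acc = [a * correction
                   + sum(pj * v[start + t][c] for t, pj in enumerate(p))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


if __name__ == "__main__":
    # Small made-up inputs: 5 tokens, head dim 4, value dim 3.
    q = [[0.1 * i + 0.05 * j for j in range(4)] for i in range(5)]
    k = [[0.2 * j - 0.1 * i for j in range(4)] for i in range(5)]
    v = [[float(i + j) for j in range(3)] for i in range(5)]
    ref = naive_causal_attention(q, k, v)
    tiled = tiled_causal_attention(q, k, v, block=2)
    print(max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t)))
```

So the algorithm ports anywhere. What doesn't port is the payoff: the win comes from sizing the blocks to fit on-chip memory and scheduling the work around it, which is the hardware-specific part.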
> https://arxiv.org/pdf/2307.08691
The FA2 paper frames its contributions almost entirely in terms of the hardware it runs on.