Let's summarize FlashAttention (FA) as follows: without FA, computing Att(i) needs O(seq_len*dk + seq_len^2) accesses to GPU main memory (HBM),
whereas with FA it needs O(seq_len^2*dk^2/SRAM_size).
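To make the two expressions concrete, here is a rough sketch that plugs in the small-model values used in this comment (seq_len=1k, dk=64, SRAM=32KB). It drops all constant factors and treats SRAM_size as a count of elements, both of which are my assumptions, so only the ratio is meaningful, not the absolute values. The function names are made up for illustration.

```python
# Back-of-the-envelope memory-access counts for one attention head.
# Asymptotic forms only: constant factors are dropped, so compare the ratio,
# not the absolute numbers.

def hbm_accesses_standard(seq_len: int, d_k: int) -> int:
    """O(seq_len*dk + seq_len^2): read Q/K/V, plus the full score matrix."""
    return seq_len * d_k + seq_len ** 2


def hbm_accesses_flash(seq_len: int, d_k: int, sram_elems: int) -> float:
    """O(seq_len^2 * dk^2 / SRAM_size): tiled computation kept in on-chip SRAM."""
    return seq_len ** 2 * d_k ** 2 / sram_elems


SEQ_LEN = 1024
D_K = 64
SRAM_ELEMS = 32 * 1024  # assumption: "32KB" treated as 32K elements

standard = hbm_accesses_standard(SEQ_LEN, D_K)
flash = hbm_accesses_flash(SEQ_LEN, D_K, SRAM_ELEMS)

print(f"standard attention: {standard:,}")
print(f"FlashAttention:     {flash:,.0f}")
print(f"ratio:              {standard / flash:.1f}x")
```

The ratio shrinks as dk grows or SRAM shrinks, since the FA term scales with dk^2/SRAM_size while the standard term is dominated by seq_len^2.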
The Q, K, V computation remains the same, and so does ATTN(0,n)*Wo. In a smaller model with N=12, D=768, dk=64, seq_len=1k, SRAM=32KB, ..., the FA optimization would roughly translate to 0.5M (with FA) vs 4.5M (without) per head (Att(i)). That is a ~10x improvement, but in the grand scheme of things, per attention layer, it becomes ~91M without FA vs ~45M with FA, so a ~2x net improvement.

> This is why "prefill" is compute intensive instead of memory intensive.

Yes, I think I agree, and I have corrected myself elsewhere in the thread. The thought I originally wanted to convey in my initial comment, which somehow got lost in the discussion, is that prefill/training will benefit from FlashAttention/MLA but inference will not. I agree that the formulation "only when memory access time dominates the compute in attention implementation" was wrong.

> During token generation ... memory bandwidth is the deciding factor for token generation.

A Llama3-70B MLP layer takes roughly 1 TFLOP of compute and 0.6 GB of memory traffic for 1024 tokens. Assuming that 1023 entries are taken from the KV-cache, the attention-layer computation for a single token takes ~0.6 GFLOP and ~0.2 GB of memory traffic. Loading the rest of the values from the KV-cache at FP16 precision takes 1023*0.1MB, or ~0.1 GB. So that is ~1 TFLOP and ~1 GB of memory traffic per Transformer layer. On hardware such as the H100, this still looks like a compute-bound problem to me. OTOH, on a CPU with 15 TFLOPS of compute but <1 TB/s of memory bandwidth, it becomes a memory-bound problem. Or no?
Since the numbers differ by roughly 1024x, maybe you forgot that the MLP, too, only needs to process the last decoded token? Attention no longer needs the hidden states of previous tokens, since their keys and values come from the KV-cache.
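A rough sketch of that 1024x point, under my own assumptions: the published Llama 3 70B shapes (hidden size 8192, MLP intermediate size 28672, three MLP weight matrices), FP16 weights, batch size 1, and approximate H100 SXM peak figures (~989 TFLOPS dense BF16, ~3.35 TB/s). It counts only weight traffic and ignores activations, so treat it as an order-of-magnitude estimate, not a benchmark.

```python
# Roofline-style estimate for one Llama 3 70B MLP layer.
# Assumptions (not from the thread): model shapes, FP16 weights, H100 peaks.

HIDDEN = 8192          # assumed hidden size
INTERMEDIATE = 28672   # assumed MLP intermediate size
BYTES_PER_PARAM = 2    # FP16

H100_PEAK_FLOPS = 989e12       # approximate dense BF16 peak, FLOP/s
H100_PEAK_BANDWIDTH = 3.35e12  # approximate HBM bandwidth, bytes/s


def mlp_params() -> int:
    """Gate, up and down projections: three HIDDEN x INTERMEDIATE matrices."""
    return 3 * HIDDEN * INTERMEDIATE


def mlp_flops(n_tokens: int) -> int:
    """One multiply and one add per weight per token."""
    return 2 * n_tokens * mlp_params()


def mlp_weight_bytes() -> int:
    """Weights are read once per forward pass, however many tokens are batched."""
    return mlp_params() * BYTES_PER_PARAM


def arithmetic_intensity(n_tokens: int) -> float:
    """FLOP performed per byte of weights read."""
    return mlp_flops(n_tokens) / mlp_weight_bytes()


def is_memory_bound(n_tokens: int) -> bool:
    """Memory-bound when intensity is below the hardware's FLOP/byte ridge."""
    ridge = H100_PEAK_FLOPS / H100_PEAK_BANDWIDTH
    return arithmetic_intensity(n_tokens) < ridge


for n in (1024, 1):  # prefill over 1024 tokens vs. decoding one token
    print(
        f"tokens={n:5d}  flops={mlp_flops(n):.2e}  "
        f"bytes={mlp_weight_bytes():.2e}  "
        f"intensity={arithmetic_intensity(n):.0f} FLOP/byte  "
        f"memory_bound={is_memory_bound(n)}"
    )
```

Under these assumptions the weight traffic is identical in both cases while the FLOP count drops by 1024x for a single decoded token, which is what pushes decoding below the ridge point.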