by refibrillator 483 days ago
vLLM supports MLA for Deepseek models as of 3 weeks ago. 3x higher generation throughput and 10x token memory capacity.

https://github.com/vllm-project/vllm/releases/tag/v0.7.1

MHA is still faster in the low-QPS regime, apparently.

https://neuralmagic.com/blog/enhancing-deepseek-models-with-...

Also published this month was a theoretical proof showing that, for the same KV cache overhead, MLA consistently offers greater expressive power than GQA. Furthermore, widely used GQA-based pre-trained models (e.g. LLaMA, Qwen, Mixtral) can be converted into MLA-based models.

https://arxiv.org/pdf/2502.07864

4 comments

For future readers, note that those 3x and 10x figures are compared to vLLM's own previous release, and NOT compared to Deepseek's implementation.

I am very curious to see how well-optimized Deepseek's code is compared to leading LLM serving software like vLLM or SGLang.

It's great to see vLLM getting faster/better for DeepSeek. I tested vLLM vs SGLang a couple of weeks ago and SGLang's DeepSeek support was much better/faster (on 2 x p5 H100 nodes). It's great that no one's standing still: I saw this recent AMD article reporting that SGLang perf on MI300X has increased by 4X over the past couple of weeks: https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR...

(w/ the extra memory V3/R1 fits on a single MI300X or H200 node)

It'll be interesting to see if either project can take advantage/get any benefits from this FlashMLA implementation.

Pretty significant improvements. However, my back-of-the-napkin math suggests that MLA, FlashAttention and similar optimizations will provide the benefits only when memory access time dominates the compute in attention implementation? Those would be the prefill-phase (or TTFT) and training (when batch_size >> 1) but not the decode phase (inference)?

You have it backwards.

Training and prefill are compute bound. Decode is memory bound. FlashAttention massively increases the arithmetic intensity of naive MHA, such that you can remain compute bound at lower batch sizes during decode.

> Decode is memory bound.

> FlashAttention ... such that you can remain compute bound at lower batch sizes during decode.

So, which one is it then?

It depends on the batch size and the accelerator you're running on! Decode is *typically* memory bound unless you can hit high batch sizes (in the hundreds), which is hard during serving due to the contention between batch size and low TTFT.

https://jax-ml.github.io/scaling-book/inference/ - good read!
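
To put rough numbers on the batch-size point, here is a minimal roofline sketch for the weight matmuls in one decode step. The peak-FLOPs and bandwidth constants are approximate, assumed figures for an H100-class GPU (they are not from this thread), the function names are made up for illustration, and the KV-cache reads that attention adds are ignored, so treat the crossover as an order-of-magnitude estimate only.

```python
# Rough roofline estimate for the weight matmuls in a decode step.
# Hardware figures are approximate, ASSUMED values for an H100-class GPU.
PEAK_FLOPS = 989e12      # assumed dense BF16 peak, FLOP/s
HBM_BANDWIDTH = 3.35e12  # assumed HBM bandwidth, bytes/s
BYTES_PER_PARAM = 2      # FP16/BF16 weights


def weight_matmul_intensity(batch_size: int) -> float:
    """FLOPs per byte of weights read when one step serves `batch_size` tokens.

    Each parameter contributes one multiply-add (2 FLOPs) per token in the
    batch, while the weight itself is read from memory once per step.
    """
    return 2 * batch_size / BYTES_PER_PARAM


def ridge_point() -> float:
    """Arithmetic intensity above which the chip is compute bound."""
    return PEAK_FLOPS / HBM_BANDWIDTH


def is_compute_bound(batch_size: int) -> bool:
    return weight_matmul_intensity(batch_size) >= ridge_point()


print(f"ridge point: {ridge_point():.0f} FLOPs/byte")
for b in (1, 32, 256, 512):
    kind = "compute" if is_compute_bound(b) else "memory"
    print(f"batch {b:>3}: {weight_matmul_intensity(b):.0f} FLOPs/byte -> {kind} bound")
```

Under these assumptions the crossover lands at a batch of roughly 300 tokens per step, which is in line with "in the hundreds".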

You've got it backwards. After FlashAttention, it's the decoding part that is bound mainly by memory access. With FA, as long as you have enough batch size, you can push training/prefill to be compute-bound.

I don't think I got it backwards; I believe what I said is correct: FA does not improve inference time.

From the authors of FlashAttention:

> This [decoding] operation has been optimized with FlashAttention (v1 and v2 recently) in the training case, where the bottleneck is the memory bandwidth to read and write the intermediate results

And then they continue with:

> However, these optimizations don’t apply directly to the inference case, because the bottlenecks are different. For training, FlashAttention parallelizes across the batch size and query length dimensions. During inference, the query length is typically 1 ... With a batch size of 1, FlashAttention will use less than 1% of the GPU!

And then they come up with a different proposal, FlashDecoding, that optimizes for inference time:

> Our new approach Flash-Decoding is based on FlashAttention, and adds a new parallelization dimension: the keys/values sequence length. It combines the benefits of the 2 approaches from above. Like FlashAttention, it stores very little extra data to global memory, however it fully utilizes the GPU even when the batch size is small, as long as the context length is large enough.

Link: https://crfm.stanford.edu/2023/10/12/flashdecoding.html
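
The split-over-KV idea in that last quote can be sketched in a few lines of plain Python. This is only an illustration under my own assumptions: the function names are hypothetical, the chunks are processed one after another here (a real kernel would hand them to different thread blocks), and it handles a single query vector.

```python
import math


def attend_chunk(q, keys, values):
    """Attention of one query over one KV chunk.

    Returns (unnormalized output, chunk max score, chunk sum of exponentials).
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return out, m, sum(weights)


def split_kv_decode(q, keys, values, chunk_size):
    """Handle each KV chunk independently, then merge the partial results."""
    partials = [
        attend_chunk(q, keys[i:i + chunk_size], values[i:i + chunk_size])
        for i in range(0, len(keys), chunk_size)
    ]
    global_max = max(m for _, m, _ in partials)
    dim = len(values[0])
    numer = [0.0] * dim
    denom = 0.0
    for out, m, s in partials:
        factor = math.exp(m - global_max)  # bring every chunk onto a common scale
        denom += factor * s
        for j in range(dim):
            numer[j] += factor * out[j]
    return [x / denom for x in numer]
```

Calling `split_kv_decode` with `chunk_size=len(keys)` degenerates to ordinary single-pass attention, which makes it easy to check that the split version returns the same vector up to floating-point rounding.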

You're confusing two things.

Classic softmax attention aka Softmax(Q K^T/sqrt(d_k))V consists of two matrix multiplications.

This means QK^T=O and then softmax(O/sqrt(d_k))V.

The matrix O is quadratic with respect to the number of input tokens. Writing the O matrix to main memory is bound by the maximum bandwidth of your memory.

Then it has to be read out again to be multiplied against V.

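What flash attention does is change the algorithm. Flash attention is mathematically equivalent to softmax attention (it is exact, not an approximation), though the results are not bit-identical because the floating-point operations are reordered. The changed algorithm allows you to fuse the independent kernels.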

Instead of writing out the O matrix to main memory, its softmax is calculated against V immediately. The double memory roundtrip is now gone. This in itself does not change the fact that both softmax attention and flash attention are quadratic with respect to the input, but it sure as hell improves the speed of "prefill".
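
A toy sketch of that fusion, in plain Python rather than a GPU kernel: the first function materializes the whole score matrix the way the unfused version does, the second streams over K/V blocks and keeps only a running max, a running sum and an accumulator per query. Function names and the block size are my own choices for illustration, and none of the SRAM tiling of a real kernel is modeled.

```python
import math


def naive_attention(Q, K, V):
    """Builds the full score matrix first, then applies softmax and V."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    scores = [[scale * sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    out = []
    for row in scores:
        m = max(row)
        w = [math.exp(s - m) for s in row]
        total = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / total
                    for j in range(len(V[0]))])
    return out


def fused_attention(Q, K, V, block=2):
    """Streams over K/V blocks; the full score matrix is never stored."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    dim = len(V[0])
    out = []
    for q in Q:
        running_max = float("-inf")
        running_sum = 0.0
        acc = [0.0] * dim
        for start in range(0, len(K), block):
            k_blk = K[start:start + block]
            v_blk = V[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale what was accumulated under the previous running max.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * v[j] for w, v in zip(weights, v_blk))
                   for j, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out
```

The two functions agree up to floating-point rounding; the difference is only in how much intermediate state has to be written out and read back.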

If you tile the Q, K, V matrices into n blocks each, you will still have to load O(n^2) blocks.

But here is the thing. Matrix multiplication is an operation with a significant amount of shared data. This means the multipliers calculating the dot products are being fed from the same flip flops, or the data is shifted around via a systolic array. You end up in a situation with an insignificant memory load, but a massive amount of arithmetic.

In addition to that, you have all the tokens already, so the MLPs at the end of the layer can be processed as GEMM instead of GEMV.

This is why "prefill" is compute intensive instead of memory intensive.

During token generation, you need to perform attention for the next token, with all the tokens already in the KV cache. You load n entries from the KV cache, then do GEMV on the MLP and you have to do this over and over again in a sequential fashion. This means that memory bandwidth is the deciding factor for token generation.

Now here is a caveat: if SRAM is limited relative to your TOPS, then it is possible that even flash attention is memory bound, but for a different reason. It's memory bound because the largest tile that can be held in SRAM is processed faster than it can be loaded from system memory or VRAM, and you are performing a quadratic number of tile loads. This is only noticeable near the extreme top end of context lengths, between 32k and 128k tokens.

Let's just summarize FlashAttention as follows. In terms of accesses to main GPU memory (HBM), the Att(i) computation without FA needs

   O(seq_len*dk + seq_len^2)

whereas the Att(i) computation with FA needs

   O(seq_len^2*dk^2/SRAM_size)

The Q, K, V computation remains the same, and ATTN(0,n)*Wo also remains the same.

In a smaller model, with N=12, D=768, dk=64, seq_len=1k, SRAM=32KB, ..., the FA optimization would roughly translate to 0.5M vs 4.5M per head (att(i)). So ~10x improvement, but in the grand scheme of things, per attention layer it becomes ~45M vs ~91M, so ~2x net improvement.
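
For anyone who wants to plug their own shapes into the two expressions above, a tiny sketch follows. It evaluates the asymptotic terms literally with all constant factors dropped, and it assumes SRAM is counted in elements rather than bytes, so the absolute values are not exact access counts; only the ratio is indicative. The function names are hypothetical.

```python
def hbm_accesses_standard(seq_len: int, d_k: int) -> int:
    """Leading terms without FA: seq_len*d_k + seq_len^2 (constants dropped)."""
    return seq_len * d_k + seq_len ** 2


def hbm_accesses_flash(seq_len: int, d_k: int, sram_elems: int) -> float:
    """Leading term with FA: seq_len^2 * d_k^2 / SRAM_size (constants dropped)."""
    return seq_len ** 2 * d_k ** 2 / sram_elems


seq_len, d_k = 1024, 64
sram_elems = 32 * 1024  # ASSUMPTION: "32KB" treated as 32K elements

standard = hbm_accesses_standard(seq_len, d_k)
flash = hbm_accesses_flash(seq_len, d_k, sram_elems)
print(f"standard: {standard:,}  flash: {flash:,.0f}  ratio: {standard / flash:.1f}x")
```

With these inputs the ratio comes out a bit under 10x per head, the same ballpark as the estimate above.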

> This is why "prefill" is compute intensive instead of memory intensive.

Yes, I think I agree and I have corrected myself elsewhere in the thread. The original thought that I actually wanted to convey in my initial comment, which was somehow lost throughout the discussion, is that prefill/training will benefit from FlashAttention/MLA but inference will not. I can agree that the formulation "only when memory access time dominates the compute in attention implementation" was wrong.

> During token generation ... memory bandwidth is the deciding factor for token generation.

Llama3-70B MLP layer roughly takes 1 TFLOP and 0.6 GB of memory traffic for 1024 tokens. Assuming that 1023 entries are taken from a KV-cache, attention layer computation for a single token will take ~0.6 GFLOP and ~0.2 GB of memory traffic. To load the rest of the values from KV-cache at FP16 precision, it will take us 1023*0.1MB or ~1 GB.

So, ~1 TFLOP and ~1 GB of memory traffic per Transformer layer. On hardware such as H100, this still looks like a compute-bound problem to me. OTOH on a CPU with 15 TFLOPS of compute but only <1TB/s of memory bandwidth, it becomes a memory-bound problem. Or no?

For Llama 3 70B, batch size = 1, each MLP layer roughly takes 1x8192x28672x2 + 1x8192x28672x2 + 1x28672x8192x2 FLOPs ~= 1.41 GFLOPs, instead of ~1 TFLOP.

Since the numbers differ by roughly 1024x, maybe you forgot that you only need to run the MLP on the last decoded token, too? You no longer need the hidden states of previous tokens for attention, because their keys and values are already in the KV cache.
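
The arithmetic is easy to reproduce. A minimal sketch, assuming a hidden size of 8192, an intermediate size of 28672 (the value in the published Llama 3 70B config, as far as I know), three projection matrices per MLP (gate, up, down) and FP16 weights; the helper names are made up:

```python
HIDDEN = 8192         # hidden size
INTERMEDIATE = 28672  # ASSUMED intermediate size (published Llama 3 70B config)


def mlp_flops(num_tokens: int) -> int:
    """FLOPs for gate, up and down projections; 2 FLOPs per multiply-add."""
    return 3 * 2 * num_tokens * HIDDEN * INTERMEDIATE


def mlp_weight_bytes(bytes_per_param: int = 2) -> int:
    """Bytes of MLP weights that have to be read for one forward pass."""
    return 3 * HIDDEN * INTERMEDIATE * bytes_per_param


decode = mlp_flops(1)      # one new token per step
prefill = mlp_flops(1024)  # 1024 prompt tokens at once
print(f"decode : {decode / 1e9:.2f} GFLOP")
print(f"prefill: {prefill / 1e12:.2f} TFLOP")
print(f"decode intensity: {decode / mlp_weight_bytes():.1f} FLOP per weight byte")
```

At batch size 1 the MLP does about one FLOP per byte of weights it reads, which is why a single decode step sits so far on the memory-bound side.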

That's correct, because FA can't turn inference time from memory-access-bound into compute-bound. But your claim that decoding is compute-bound is plainly wrong.

FA, compared to a naive implementation, made training/prefill (i.e. when multiple tokens of the same sequence are visible at once) compute-bound instead of memory-access-bound.

So, currently, on MHA/GQA, with Flash Attention, training/prefill is compute-bound, whereas decoding is memory-access-bound.

Before FA, both prefill and decode were bound by memory access. FA solved the problem for training/prefill. But because the KV cache is large, decoding is inherently bound by memory access.

Our goal is always to make everything compute-bound.

> But your claim that decoding is compute-bound is plainly wrong.

I did not say anything like that? What I said is that FlashAttention and arguably MLA will not make any significant gains in the inference time. And this is true.

Also, FWIW there are certainly model shapes that are compute-bound in the decode phase, so saying that decoding is universally, inherently bound by memory access is what is plainly wrong, to use your vocabulary.

Apologies if I got it wrong, but:

> MLA, FlashAttention and similar optimizations will provide the benefits only when memory access time dominates

> Those would be [...] not the decode phase

This does sound like you are saying that memory access time does NOT dominate during the decode phase. But it does.

Reading your quotes, it looks like maybe you are talking about GPU utilization issues? (i.e. not launching enough threads). Due to the parallelization strategy of the original FA it indeed does not even keep the GPU busy if q*bs is too small. But this is not an inherent limitation of FA-style kernels and can be solved and people did solve it. Or you simply batch more. Now you can keep the GPUs busy at 100% waiting for memory access, but memory access time still dominates, hence "memory-access-bound". And here comes MLA.

> FWIW there are certainly model shapes that are compute-bound in the decode phase

Yeah. But so far none of the ones I've read about really work ("work" meaning at most slightly worse than the alternatives) under the same wall-clock compute budget. Do you have a pointer to a working example, even on smaller 3B-ish models?

... and batching does not help, you batch more requests and get more kvcache to load, still memory-access bound.

MLA made it possible to cache a smaller form of k/v, mitigating the problem (but not completely solving it: on shorter contexts and smaller batches it's still memory-access bound).
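
A back-of-the-envelope sketch of both points: the KV cache grows linearly with batch size, and caching a compressed latent shrinks the per-token footprint. All shapes below are illustrative assumptions loosely modeled on a DeepSeek-V2-style config (128 heads of dimension 128, a 512-wide latent plus a 64-wide RoPE key, 60 layers), and the function names are hypothetical.

```python
def kv_bytes_per_token(n_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Per-layer K and V cache for one token (MHA, or GQA with fewer KV heads)."""
    return 2 * n_kv_heads * head_dim * bytes_per_elem


def mla_bytes_per_token(latent_dim: int, rope_dim: int, bytes_per_elem: int = 2) -> int:
    """Per-layer cache for one token when only the latent and RoPE key are stored."""
    return (latent_dim + rope_dim) * bytes_per_elem


def cache_bytes(per_token: int, n_layers: int, context_len: int, batch_size: int) -> int:
    """Total cache that decode has to read on every step."""
    return per_token * n_layers * context_len * batch_size


# ASSUMED illustrative shapes, not measured from any deployment.
mha = kv_bytes_per_token(n_kv_heads=128, head_dim=128)
gqa = kv_bytes_per_token(n_kv_heads=8, head_dim=128)
mla = mla_bytes_per_token(latent_dim=512, rope_dim=64)

for batch in (1, 8, 64):
    for name, per_token in (("MHA", mha), ("GQA", gqa), ("MLA", mla)):
        total = cache_bytes(per_token, n_layers=60, context_len=4096, batch_size=batch)
        print(f"batch {batch:>2} {name}: {total / 2**30:.2f} GiB")
```

Doubling the batch doubles the bytes to read in every variant, so batching alone doesn't change the attention part of the picture; shrinking the per-token entry does.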

I also just read that paper. But I wonder, even though MLA is strictly more powerful, do you really gain from that in experiments? The paper doesn't really include many experimental comparisons. GQA, on the other hand, should still be faster (no need for an extra linear transformation).