I'm not talking about the so-called quadratic memory requirement of the attention step, there NEVER WAS ONE.
I'm talking about a simple fact: to run LLM inference efficiently (cost-wise), you need a KV "cache", and its size grows linearly with your expected batch size and your context window length. With a long enough context window, it becomes even bigger than the model weights.
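A back-of-the-envelope sketch of that scaling. The model shape below is an assumption for illustration only (a dense ~7B model with 32 layers, 32 KV heads, head_dim 128, fp16 weights and cache); models using grouped-query attention have fewer KV heads and a proportionally smaller cache.

```python
# Assumed model shape (illustrative, not any specific deployment).
N_PARAMS = 7_000_000_000
N_LAYERS = 32
N_KV_HEADS = 32
HEAD_DIM = 128
BYTES_PER_ELEM = 2  # fp16


def kv_cache_bytes(batch_size: int, context_len: int) -> int:
    # Factor of 2: each layer stores one K and one V vector per KV head per token.
    per_token = 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES_PER_ELEM
    # Linear in both the number of tokens held and the number of sequences.
    return per_token * context_len * batch_size


weights_bytes = N_PARAMS * BYTES_PER_ELEM
for batch, ctx in [(1, 4_096), (1, 32_768), (8, 32_768)]:
    kv = kv_cache_bytes(batch, ctx)
    print(f"batch={batch} ctx={ctx}: KV {kv / 2**30:.1f} GiB vs weights {weights_bytes / 2**30:.1f} GiB")
```

Under these assumptions each token costs 0.5 MiB of cache, so a single 32k-token sequence already needs more memory than the weights themselves.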
I don't want to be mean, but sorry:
Sorry, read up on PagedAttention. You clearly don't know what you are talking about; please be better.
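For readers unfamiliar with PagedAttention, here is a toy sketch of its core memory-management idea as described for vLLM: the KV cache is split into fixed-size physical blocks, and each sequence keeps a block table mapping its logical blocks to physical ones, allocated on demand instead of reserving a contiguous region for the maximum context length. The class and method names are hypothetical and not vLLM's API; the block size of 16 tokens is an assumption.

```python
import math


class BlockAllocator:
    """Toy pool of fixed-size physical KV blocks (hypothetical, not vLLM's API)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size  # tokens per block
        self.free = list(range(num_blocks))  # ids of unused physical blocks

    def alloc(self) -> int:
        if not self.free:
            raise MemoryError("out of KV blocks")
        return self.free.pop()

    def release(self, block_id: int) -> None:
        self.free.append(block_id)


class Sequence:
    def __init__(self, allocator: BlockAllocator):
        self.allocator = allocator
        self.block_table: list[int] = []  # logical block index -> physical block id
        self.num_tokens = 0

    def append_token(self) -> None:
        # A new physical block is needed only when the last one is full.
        if self.num_tokens % self.allocator.block_size == 0:
            self.block_table.append(self.allocator.alloc())
        self.num_tokens += 1

    def free(self) -> None:
        # Return every block to the shared pool when the sequence finishes.
        for block_id in self.block_table:
            self.allocator.release(block_id)
        self.block_table.clear()
        self.num_tokens = 0


pool = BlockAllocator(num_blocks=1024, block_size=16)
seq = Sequence(pool)
for _ in range(100):
    seq.append_token()
print(len(seq.block_table), "blocks in use; expected", math.ceil(100 / 16))
```

Memory in use still scales with the number of tokens actually stored; what the block scheme removes is the waste from reserving the maximum length up front and from fragmentation between sequences.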
I'm not sure you're actually doing the math for these long contexts. A naked transformer generating 1k tokens from a 1k prompt spends nearly all its time on the sequential forward passes that generate each token; that's what's driving your intuition. A naked transformer generating 1k tokens from a 1M prompt spends nearly all its time computing keys and values for the prompt (filling the KV cache), and the generation loop at the end is a tiny fraction of the compute even though it has to run 1k times.
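A rough FLOP count makes the comparison concrete. This is a sketch under stated assumptions: a dense model with about 7B parameters, 32 layers, and d_model 4096 (all assumed), ~2 FLOPs per parameter per token for the matmuls, and ~4 * n_layers * d_model * t FLOPs of attention for a token attending to t positions. It counts arithmetic only and ignores memory bandwidth.

```python
N_PARAMS = 7_000_000_000  # assumed dense model size
N_LAYERS = 32             # assumed
D_MODEL = 4096            # assumed


def prefill_flops(prompt_len: int) -> int:
    # Dense matmuls: ~2 FLOPs per parameter per prompt token.
    dense = 2 * N_PARAMS * prompt_len
    # Causal attention: token t attends to t positions (QK^T and AV, ~2*d*t each
    # per layer), summed over t = 1..P, giving the quadratic term.
    attn = 4 * N_LAYERS * D_MODEL * prompt_len * (prompt_len + 1) // 2
    return dense + attn


def decode_flops(prompt_len: int, gen_len: int) -> int:
    dense = 2 * N_PARAMS * gen_len
    # Generated tokens sit at positions P+1 .. P+G; sum those context lengths.
    pos_sum = gen_len * (2 * prompt_len + gen_len + 1) // 2
    attn = 4 * N_LAYERS * D_MODEL * pos_sum
    return dense + attn


def decode_fraction(prompt_len: int, gen_len: int) -> float:
    d = decode_flops(prompt_len, gen_len)
    return d / (prefill_flops(prompt_len) + d)


for prompt in (1_000, 1_000_000):
    print(f"prompt={prompt}: generation share of FLOPs = {decode_fraction(prompt, 1_000):.4f}")
```

With these assumptions, generation is roughly half the arithmetic for the 1k prompt but well under 1% for the 1M prompt. For short prompts, decode dominates wall-clock time mainly because its forward passes are sequential and memory-bound, which the FLOP count alone does not show.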