To add to what others have said here, this comes down to the memory hierarchy. GPUs have different kinds of memory: fast-but-small memory and slow-but-large memory. Conceptually, you can imagine LLM inference as transferring a portion of the weights from slow memory to fast memory, doing some calculations with those weights, discarding them from fast memory once the computation is done, loading in the next portion, and so on, until you're done.

You can do the calculations for multiple tokens in parallel, but to calculate what token n is, you need to already know all the previous tokens 1..(n-1). Therefore, without speculative decoding, you go one token at a time. With it, you assume that the next few tokens really are what the smaller draft model gave you, let the big model check all of them at once, and discard the results from the first wrong guess onward.

Because of that assumption about what the next tokens are, you can load the weights once and apply them to multiple tokens instead of just one. This decreases the amount of data that has to move between slow and fast memory. Since the decode stage[1] is bottlenecked by memory bandwidth rather than compute speed, using that bandwidth more efficiently increases your token generation speed.

As another poster said, this idea is closely related to batching. In batching, you reuse the same weights to serve multiple requests; in speculative decoding, you reuse them to accelerate a single one. If you have many users, care only about how many tokens per second your GPUs produce overall, and don't care at all about per-user speed, speculative decoding won't do anything for you.

[1] There are two stages in LLM inference: prefill and decode. In prefill, you do calculations on the tokens of the prompt, prefilling the KV cache to accelerate attention computations at decode time. Because you have access to all the tokens of the prompt, you can process them in parallel and use your weights very efficiently.
Here, your bottleneck is the compute units, not memory bandwidth. In decode, you don't know what your future tokens will be, so you can only go one at a time, as explained above. In a way, speculative decoding turns decode into a little prefill.
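To make the accept/reject step concrete, here is a minimal Python sketch of greedy speculative decoding. The two "models" are hypothetical toy functions (`toy_target`, `toy_draft`) standing in for a big and a small LLM; every name here is made up for illustration and is not any library's API. Real implementations also handle sampling (not just greedy argmax) and KV-cache rollback, which this skips. The thing to watch is the pass counter: each target pass stands for one full read of the big model's weights.

```python
from typing import Callable, List, Tuple

# Hypothetical model interface: maps a token prefix to its greedy next token.
Model = Callable[[List[int]], int]


def greedy_decode(target: Model, prompt: List[int], n_new: int) -> Tuple[List[int], int]:
    tokens, passes = list(prompt), 0
    for _ in range(n_new):
        tokens.append(target(tokens))
        passes += 1  # one full read of the weights per generated token
    return tokens, passes


def speculative_decode(target: Model, draft: Model, prompt: List[int],
                       n_new: int, k: int = 4) -> Tuple[List[int], int]:
    tokens, passes = list(prompt), 0
    goal = len(prompt) + n_new
    while len(tokens) < goal:
        # 1. The cheap draft model guesses the next k tokens one by one.
        guess: List[int] = []
        for _ in range(k):
            guess.append(draft(tokens + guess))
        # 2. The target scores all k+1 positions. In a real engine this is ONE
        #    batched forward pass (weights read once); here it is simulated with
        #    k+1 calls to the toy target and counted as a single pass.
        passes += 1
        verified = [target(tokens + guess[:i]) for i in range(k + 1)]
        # 3. Keep the longest prefix where draft and target agree, then take the
        #    target's own token at the first mismatch (or a bonus token if all k matched).
        n_ok = 0
        while n_ok < k and guess[n_ok] == verified[n_ok]:
            n_ok += 1
        tokens += guess[:n_ok] + [verified[n_ok]]
    return tokens[:goal], passes


def toy_target(tokens: List[int]) -> int:
    # Hypothetical stand-in for the big model's greedy choice.
    return (3 * tokens[-1] + len(tokens)) % 10


def toy_draft(tokens: List[int]) -> int:
    # Hypothetical draft model: agrees with the target except at every 5th position.
    t = toy_target(tokens)
    return t if len(tokens) % 5 else (t + 1) % 10


prompt = [1, 2, 3]
ref, ref_passes = greedy_decode(toy_target, prompt, 20)
spec, spec_passes = speculative_decode(toy_target, toy_draft, prompt, 20, k=4)
print(ref_passes, spec_passes)
```

With these toy models the speculative output matches plain greedy decoding token for token, but it needs 5 target passes instead of 20, because each pass can accept up to k draft tokens plus one token produced by the target itself.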
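The memory-bound vs compute-bound point can be put into rough numbers with a crude roofline estimate. The model and hardware figures below (7B parameters in fp16, 2 TB/s memory bandwidth, 1 PFLOP/s peak compute) are illustrative assumptions, not measurements of any particular GPU, and the estimate ignores KV-cache and activation traffic entirely.

```python
def forward_pass_time(n_tokens: int, n_params: float = 7e9, bytes_per_param: float = 2,
                      mem_bw: float = 2e12, peak_flops: float = 1e15) -> float:
    """Crude roofline estimate, in seconds, for one forward pass over n_tokens.

    Assumed (illustrative) model: the weights are read from slow memory once per
    pass, each token costs ~2 FLOPs per parameter, and the pass takes whichever
    is slower, moving the weights or doing the math.
    """
    load_time = n_params * bytes_per_param / mem_bw      # moving weights to fast memory
    compute_time = 2 * n_params * n_tokens / peak_flops  # matmul work for all tokens
    return max(load_time, compute_time)


for n in [1, 5, 64, 512, 2048]:
    t = forward_pass_time(n)
    print(f"{n:5d} tokens: {t * 1e3:7.3f} ms/pass, {t / n * 1e3:8.4f} ms/token")
```

Under these assumed numbers, reading the weights takes about 7 ms per pass regardless, while the compute for one token is only about 0.014 ms, so a pass over 1 or 5 tokens costs the same until somewhere around 500 tokens. That flat region is what batching and speculative verification exploit; a long prompt during prefill lands on the compute-bound side.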