by janalsncm 38 days ago
The small draft model proposes a sequence of tokens d1 d2 d3.

The big target model calculates

P(d1)

P(d2|d1)

P(d3|d1 d2)

All in parallel, in a single forward pass. If we were just doing greedy decoding, it would be simple: stop at the first position where the draft model didn’t predict the most likely token as judged by the target model. At that point, append the correct token from the target model and kick off both models again.
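To make the greedy case concrete, here's a rough sketch. The function name and the list-of-token-ids inputs are made up for illustration. It also assumes the target's forward pass returns one extra prediction for the position after the last draft token, which you get for free when every draft token matches.

```python
def greedy_verify(draft_tokens, target_argmax):
    """Return the tokens kept after one verification step (greedy decoding).

    draft_tokens:  tokens proposed by the draft model, e.g. [d1, d2, d3]
    target_argmax: the target model's most likely token at each of the
                   len(draft_tokens) + 1 positions, all from one forward pass
                   (assumption: the last entry is the prediction that follows
                   the full draft).
    """
    accepted = []
    for drafted, best in zip(draft_tokens, target_argmax):
        if drafted != best:
            # First mismatch: keep the target's token instead and stop.
            accepted.append(best)
            return accepted
        accepted.append(drafted)
    # Every draft token matched, so the target's extra prediction is kept too.
    accepted.append(target_argmax[len(draft_tokens)])
    return accepted
```

Either way you get at least one new token per target pass, so the worst case is about the same as plain decoding plus the cost of the draft.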

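In practice we aren’t using greedy decoding. We are sampling and we need to match the target model’s distribution. To do this, we accept tokens from the draft model probabilistically, which is possible because we have the logits of both the draft model and the target at that point. The ratio of their softmax probabilities is used for this: a draft token x is accepted with probability min(1, p(x)/q(x)), where p is the target’s probability and q is the draft’s. On a rejection, the replacement token is sampled from the target’s distribution with the draft’s subtracted off, clipped at zero and renormalized.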
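Here's a minimal sketch of that accept/reject step for a single position, using the standard speculative sampling rule. The function name and the plain-list probability vectors are assumptions for illustration; a real implementation would work on batched tensors.

```python
import random

def accept_or_resample(token, p_target, q_draft, rng=random):
    """Decide the fate of one drafted token.

    p_target, q_draft: probabilities over the vocabulary (softmax outputs)
    from the target and the draft model at this position.
    Returns (accepted, token_to_emit).
    """
    # q_draft[token] > 0 because the draft model sampled this token.
    ratio = p_target[token] / q_draft[token]
    if rng.random() < min(1.0, ratio):
        return True, token
    # Rejected: resample from the normalized positive part of (p - q).
    # A rejection implies p < q at `token`, so this residual has nonzero mass.
    residual = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    total = sum(residual)
    weights = [r / total for r in residual]
    new_token = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return False, new_token
```

Tokens the draft over-predicts get rejected some of the time, and the resampling step fills in the mass the draft under-predicted, which is why the output still follows the target's distribution. After the first rejection you throw away the rest of the draft.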

You are right that actually accepting tokens has to happen sequentially but that’s a heck of a lot faster than a forward pass.

2 comments

nice ... i think i get the idea - it's effectively the same / similar benefit as batching, but you're batching against your own speculated future path. Which would be pointless if you didn't have a high probability path to evaluate against - but the draft gives you that.
I'll add an expansion here. It's more useful to you locally, as you have excess compute that's generally wasted. If you're serving multiple users and trying to max out throughput, speculation might cost you some in that case.
An obvious approach: if you have enough concurrent batches to max out performance, use those and don't speculate. But if compute would otherwise sit idle waiting on memory, fill the excess with speculation.
while I understand that we are computing the tokens in parallel to get the "faster" result, is there a tradeoff where we're actually utilizing more compute resources by running multiple instances of the large model? That is, while it's faster, is it more efficient?

edit: doing some more of my own research, it sounds like the bottleneck in doing it sequentially is in shifting weights around in memory, so while it uses more compute it doesn't oversubscribe compute resources because the bottleneck is not in supply of compute but in supply and speed of memory. The GPU has a massive supply of compute but sequential decoding only demands a relatively small amount of it. Time is primarily spent waiting on loading values from VRAM.
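A back-of-envelope version of that argument, as a sketch. The hardware and model numbers in the usage note are round illustrative values, not measurements of any real GPU. The estimate assumes every weight is read from memory once per forward pass and uses the usual approximation of about 2 FLOPs per parameter per token. It ignores the KV cache and activations.

```python
def decode_step_times(n_params, bytes_per_param, mem_bandwidth, peak_flops, k_tokens):
    """Rough lower bounds (in seconds) for one forward pass over k_tokens.

    weight_read_s: time to stream all the weights from memory once
    compute_s:     time for the matmuls, at ~2 FLOPs per parameter per token
    """
    weight_read_s = n_params * bytes_per_param / mem_bandwidth
    compute_s = 2 * n_params * k_tokens / peak_flops
    return weight_read_s, compute_s

# Hypothetical setup: 7e9 parameters at 2 bytes each,
# 1e12 bytes/s memory bandwidth, 1e14 FLOP/s peak compute.
one_token = decode_step_times(7e9, 2, 1e12, 1e14, k_tokens=1)
four_tokens = decode_step_times(7e9, 2, 1e12, 1e14, k_tokens=4)
```

With those made-up numbers the weight read takes about 14 ms regardless of k, while the compute is about 0.14 ms for one token and about 0.56 ms for four. Verifying a handful of speculated tokens barely moves the total.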

It’s not really multiple instances of the same model. Model weights aren’t replicated in VRAM. The result of multiplying k sequences through the model is larger, but that’s pretty small compared with the model weights themselves.

The bigger constraint is the target model and the draft model needing to share VRAM.