by zmmmmm
45 days ago
> it's fast to check that they are actually correct with the main model because you can run the checks in parallel.

Can you give an intuition as to why it's faster? I would have thought regardless how many you run in parallel, the successful check has to execute the full model to generate the full sequence so you will have exactly the same time needed? Or is it by process of elimination so it terminates early once it eliminates the non-viable choices? (in which case, how do you guarantee the correct output was speculatively generated at all to be the last survivor?)
|
The big target model calculates
P(d1)
P(d2|d1)
P(d3|d1 d2)
in parallel, in a single forward pass over the draft tokens d1, d2, d3. If we were just greedy decoding, it would be simple: stop at the first draft token that isn’t the most likely token as judged by the target model. At that point, append the correct token from the target model and kick off both models again.
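To make the greedy case concrete, here is a minimal sketch. The function name `greedy_verify` and the input layout are my own assumptions for illustration, not any library's API: `target_top_tokens[i]` is taken to be the target model's most likely token after the prefix plus the first `i` draft tokens, all obtained from that one parallel pass. Handing back an extra token when every draft token matches is how this is commonly done, but treat that as an assumption too.

```python
def greedy_verify(draft_tokens, target_top_tokens):
    """Hypothetical helper: keep draft tokens until the first mismatch.

    target_top_tokens must have len(draft_tokens) + 1 entries, one per
    draft position plus one for the position after the last draft token.
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        if tok != target_top_tokens[i]:
            # First disagreement: take the target's token and stop.
            accepted.append(target_top_tokens[i])
            return accepted
        accepted.append(tok)
    # Every draft token matched, so the target's next token comes for free.
    accepted.append(target_top_tokens[len(draft_tokens)])
    return accepted


print(greedy_verify([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
print(greedy_verify([5, 7], [5, 7, 3]))        # [5, 7, 3]
```

Either way you get at least one token per target forward pass, so the output is the same as plain greedy decoding with the target model; the draft only decides how many tokens you get per pass.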
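In practice we aren’t using greedy decoding. We are sampling, and the output has to match the target model’s distribution. To do this, we accept each draft token probabilistically, which is possible because we have the logits of both the draft model and the target model at that position. The ratio of their softmax probabilities sets the acceptance chance: a draft token x is kept with probability min(1, p_target(x) / p_draft(x)), and on a rejection the replacement token is sampled from whatever part of the target distribution the draft under-covered.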
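A rough sketch of that acceptance rule for a single position, as I understand the standard speculative sampling scheme. The name `accept_or_resample` is hypothetical, and `p_target` / `p_draft` are assumed to be plain lists of softmax probabilities over the same vocabulary, with the draft token having been sampled from `p_draft`.

```python
import random


def accept_or_resample(draft_token, p_target, p_draft, rng=random):
    """Hypothetical helper: returns (token, accepted) for one position."""
    # Keep the draft token with probability min(1, p_target / p_draft).
    ratio = p_target[draft_token] / p_draft[draft_token]
    if rng.random() < min(1.0, ratio):
        return draft_token, True
    # Rejected: sample from the normalized residual max(0, p_target - p_draft).
    # A rejection implies p_target < p_draft at draft_token, so the residual
    # has positive mass somewhere else.
    residual = [max(0.0, pt - pd) for pt, pd in zip(p_target, p_draft)]
    token = rng.choices(range(len(residual)), weights=residual, k=1)[0]
    return token, False
```

The point of the residual step is that accepted plus resampled tokens together come out distributed exactly like samples from the target model. After a rejection, the draft tokens past that position are thrown away, same as in the greedy case.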
You are right that actually accepting tokens has to happen sequentially, but that’s only a cheap comparison per token, which is a heck of a lot faster than a forward pass.