Hacker News
by WithinReason 616 days ago
Hmmm, this could be expressed as 2 consecutive attentions in a residual branch:

Simplified, the Differential Transformer attention looks like: (softmax(Q₁K₁) − λ softmax(Q₂K₂)) V

You can expand this (distribute V over the difference) into:

    x = softmax(Q₁K₁)V
    x += -λ softmax(Q₂K₂)V
which is like two consecutive regular attentions that share V, with their outputs added together.
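
A quick numerical check of that expansion, as a minimal sketch in dependency-free Python. The helper names (`attention`, `fused`, `two_step`), the matrix sizes, and λ = 0.8 are made up for illustration, and the 1/√d scaling is dropped just like in the simplified formula:

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(a):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(val - m) for val in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(q, k, v):
    """Regular attention softmax(Q K^T) V (no 1/sqrt(d) scaling, as in the comment)."""
    return matmul(softmax_rows(matmul(q, transpose(k))), v)

def fused(q1, k1, q2, k2, v, lam):
    """(softmax(Q1 K1^T) - lam * softmax(Q2 K2^T)) V, computed in one go."""
    a1 = softmax_rows(matmul(q1, transpose(k1)))
    a2 = softmax_rows(matmul(q2, transpose(k2)))
    diff = [[x - lam * y for x, y in zip(r1, r2)] for r1, r2 in zip(a1, a2)]
    return matmul(diff, v)

def two_step(q1, k1, q2, k2, v, lam):
    """x = attention 1; x += -lam * attention 2; both read the same V."""
    x = attention(q1, k1, v)
    y = attention(q2, k2, v)
    return [[a + (-lam) * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]

rng = random.Random(0)
n, d = 4, 3  # hypothetical sequence length and head dimension
q1, k1, q2, k2, v = (
    [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)] for _ in range(5)
)
lam = 0.8  # arbitrary illustrative value, not taken from the paper

out_fused = fused(q1, k1, q2, k2, v, lam)
out_split = two_step(q1, k1, q2, k2, v, lam)

# Largest elementwise difference between the two forms.
max_gap = max(
    abs(a - b) for rf, rs in zip(out_fused, out_split) for a, b in zip(rf, rs)
)
print(max_gap)
```

The two forms should agree up to floating-point rounding, since the only difference is whether V is multiplied before or after the subtraction.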

You could also extrapolate this into more than two terms by squinting your eyes and saying that λ ∈ {1, -1} is close enough to λᵢ ∈ ℝ^d, ‖λᵢ‖ = 1. No idea if it would result in better performance, but that's research babyyyy!

Now I'm wondering, isn't there usually a `num_heads x value_dim -> model_dim` projection that goes after an MHA? The W in `softmax(QK)VW`? That one can play the role of this subtraction in a vanilla transformer, no? So I wonder what kind of advantage splitting things up like this brings.
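
To make the output-projection point concrete, here is a rough sketch in plain Python. It assumes the two heads happen to use the same V (vanilla MHA normally learns a separate value projection per head), and `w_out`, the sizes, and λ = 0.8 are hypothetical. With W set to an identity block stacked on top of a −λ·identity block, the projection of the concatenated heads reproduces the subtraction:

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(a):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(val - m) for val in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def head(q, k, v):
    """One attention head: softmax(Q K^T) V, scaling omitted."""
    k_t = [list(col) for col in zip(*k)]
    return matmul(softmax_rows(matmul(q, k_t)), v)

rng = random.Random(1)
n, d = 4, 3  # hypothetical sequence length and value_dim
q1, k1, q2, k2, v = (
    [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)] for _ in range(5)
)
lam = 0.8  # arbitrary illustrative value

h1 = head(q1, k1, v)
h2 = head(q2, k2, v)

# Concatenate head outputs along the feature axis: shape n x (2 * d).
concat = [r1 + r2 for r1, r2 in zip(h1, h2)]

# W: (num_heads * value_dim) x model_dim, here [I; -lam * I] with model_dim = d.
identity = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
neg_lam_identity = [[-lam if i == j else 0.0 for j in range(d)] for i in range(d)]
w_out = identity + neg_lam_identity

via_projection = matmul(concat, w_out)
direct = [[a - lam * b for a, b in zip(ra, rb)] for ra, rb in zip(h1, h2)]

# Largest elementwise difference between the projected and the direct result.
proj_gap = max(
    abs(a - b) for rp, rd in zip(via_projection, direct) for a, b in zip(rp, rd)
)
print(proj_gap)
```

This only shows that a vanilla block can represent the subtraction under the shared-V assumption; it says nothing about whether training actually finds such a W, which is the open question in the comment.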