by andy12_
520 days ago
Eq. 12 is a loss function used to associate a given key with a value in the memory MLP, via test-time training with gradient descent. Eq. 15 is simply the operation that queries a value inserted by earlier tokens through eq. 12. Basically, for each autoregressively processed segment you do: 1) Test-time inference: query values from memory with eq. 15. 2) Test-time training: associate the new keys and values into the memory with the loss from eq. 12. As for the forget and remember gates, that's because... well, the architecture in general is very similar to an LSTM, except that it uses test-time gradient descent to decide what to insert into the long-term memory.
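A minimal sketch of that per-segment loop, under loud assumptions: the memory is a single linear map M (a d x d matrix) rather than the paper's MLP, the squared error ||M k - v||^2 stands in for the eq. 12 loss, and the gates are fixed scalars (alpha as a decay/forget term, eta as momentum on the update, theta as step size). The names `retrieve`, `write`, and `process_segments` are hypothetical, not from the paper's code.

```python
# Hypothetical sketch: linear memory M instead of an MLP, fixed scalar gates.
# alpha = decay/forget term, eta = momentum decay, theta = step size.

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def retrieve(M, q):
    # Test-time inference (eq. 15 analogue): read M q, no weight update.
    return matvec(M, q)

def grad_assoc_loss(M, k, v):
    # Gradient of ||M k - v||^2 with respect to M is 2 (M k - v) k^T.
    err = [p - t for p, t in zip(matvec(M, k), v)]
    return [[2 * e * kj for kj in k] for e in err]

def write(M, S, k, v, alpha=0.0, eta=0.9, theta=0.1):
    # Test-time training (eq. 12 analogue): one momentum step on the loss,
    # then decay the old memory by (1 - alpha) and add the update.
    g = grad_assoc_loss(M, k, v)
    S = [[eta * s - theta * gij for s, gij in zip(srow, grow)]
         for srow, grow in zip(S, g)]
    M = [[(1 - alpha) * m + s for m, s in zip(mrow, srow)]
         for mrow, srow in zip(M, S)]
    return M, S

def process_segments(segments, d):
    # segments: list of (q, k, v) triples, one per processed segment.
    M = [[0.0] * d for _ in range(d)]
    S = [[0.0] * d for _ in range(d)]
    outputs = []
    for q, k, v in segments:
        outputs.append(retrieve(M, q))  # 1) read with the current memory
        M, S = write(M, S, k, v)        # 2) write the new association
    return outputs, M
```

With alpha = 0 and the same (k, v) pair written repeatedly, `retrieve(M, k)` converges to v; a nonzero alpha keeps decaying old content, trading exact recall for forgetting.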
It seems the implicit assumption, then, is that M(q) -> v 'looks like' or 'is smooth like' the dot product; otherwise 'train on keys, run inference on queries' wouldn't work? (A safe assumption imo, given that l2 norm and in general; unsafe if q and k come from different distributions.)
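In the simplest case the assumption holds exactly. A sketch, assuming a linear memory and unit-norm keys (both simplifications of the real setup): if M stores a single pair (k, v), every gradient of ||M k - v||^2 is a multiple of k^T in each row, so gradient descent from M = 0 (with a small enough step) converges to M = v k^T, and retrieval becomes M q = (k . q) v, a literal dot-product weighting. All helper names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalize(x):
    # l2-normalize, mirroring the normalized keys/queries discussed above.
    n = math.sqrt(dot(x, x))
    return [xi / n for xi in x]

def retrieve(M, q):
    # Linear-memory read: M q.
    return [dot(row, q) for row in M]

def store_single(k, v):
    # Closed-form limit for a unit-norm k: M = v k^T, so M k = v.
    return [[vi * kj for kj in k] for vi in v]

def train_single(k, v, lr=0.25, steps=100):
    # Plain gradient descent from M = 0 on ||M k - v||^2 (gradient 2 (M k - v) k^T).
    M = [[0.0] * len(k) for _ in v]
    for _ in range(steps):
        err = [p - t for p, t in zip(retrieve(M, k), v)]
        M = [[m - lr * 2 * e * kj for m, kj in zip(row, k)] for row, e in zip(M, err)]
    return M

k = normalize([1.0, 1.0, 0.0])
v = [2.0, -1.0]
M = train_single(k, v)

# Query with the key itself, a nearby query, and an orthogonal query:
# the read-out is (k . q) * v, so near-keys recover nearly v and orthogonal queries recover nothing.
for q in (k, normalize([1.0, 0.9, 0.1]), normalize([1.0, -1.0, 0.0])):
    print(round(dot(k, q), 3), [round(x, 3) for x in retrieve(M, q)])
```

For an MLP memory this is only approximately true, which is where queries drawn from a different distribution than the keys could hurt.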
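Correct me if I'm wrong, but k and v are typically generated via linear projections K, V of the tokens (ignoring any bias terms). If M is matrix-valued and there are no forget and remember gates (to somehow approximate the softmax?), then requiring M k = v for every token gives M K = V, so M = V K^-1 (assuming K is invertible).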
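A quick numerical check of that claim, assuming a linear memory, an invertible K, no gates, and tokens that span the input space (here the standard basis, so the summed loss reduces to ||M K - V||_F^2). The matrices K, V and the function `fit_memory` are made-up illustrations. Plain gradient descent from M = 0 should land on V K^-1:

```python
def matmul(A, B):
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def inv2(A):
    # Inverse of a 2x2 matrix (assumes det != 0).
    (a, b), (c, d) = A
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def fit_memory(K, V, lr=0.05, steps=2000):
    # Gradient descent from M = 0 on ||M K - V||_F^2, i.e. the sum of
    # ||M k - v||^2 over basis tokens x with k = K x and v = V x.
    n = len(K)
    M = [[0.0] * n for _ in range(n)]
    Kt = transpose(K)
    for _ in range(steps):
        MK = matmul(M, K)
        E = [[MK[i][j] - V[i][j] for j in range(n)] for i in range(n)]
        G = matmul(E, Kt)  # gradient of the loss is 2 E K^T
        M = [[M[i][j] - lr * 2 * G[i][j] for j in range(n)] for i in range(n)]
    return M

K = [[2.0, 0.0], [1.0, 1.0]]  # hypothetical key projection
V = [[1.0, 2.0], [3.0, 4.0]]  # hypothetical value projection
print(fit_memory(K, V))
print(matmul(V, inv2(K)))
```

If K is not invertible, V K^-1 is undefined and the least-squares M generally depends on the token distribution.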