| > But Transformers have one core problem. In a transformer, every token can look back at every previous token when making predictions. Lately I've been wondering... is this a problem, or a strength? It might be a fallacy to compare how LLMs "think" with how humans think. But humor me for a second. When you are speaking, each time you emit a word, you are not attending to every previous word in your sentence (like transformers), rather you have a state in your mind that represents the grammar and concepts, which is continuously updated as you speak (more similar to SSMs). Similarly, when you read a book, every time you read a word, you are not attending to every previous word in the book. Your model of "the book" is rather a fuzzy/approximate state that is updated with new information every time a new word appears. Right? (I'm sorry I know this is very handwavy and pseudoscientific but bear with me). Ok, so if (big if) you feel like the above is true, then to match human-type language modelling, SSMs seem more human-like than transformers. BUT... then aren't transformers strictly better in terms of accuracy? Because a transformer never "forgets" information, as long as it is within the context window, because it revisits that information every time it emits a new token. So let's say we can remove the "quadratic attention" problem of transformers with SSMs. That's a nice training/inference performance boost. But... look at where we got with "naive" attention. GPT-4, Claude 3. It's not like we're hitting a wall with quadratic attention. It's absurdly more expensive than SSMs, but GPUs certainly aren't getting slower. If all AI work stops now, and only hardware improves, it wouldn't be long until GPT-4 could run on local hardware, right, provided Moore's law holds? /end rant, not really sure what my point was, I'm not against SSMs (they're cool) but rather I'm wondering if the SOTA will ever be SSM when attention is so damn good |
It probably depends. But here's an idea I've been playing with: because transformers have such strong recall at inference time, they might introduce a strong inductive bias toward memorization over generalization. Why bother building a complete world model when you can just attend to the answer? The global minimum of the loss (at least on the training dataset) would favor those memorizing and interpolating circuits over ones that generalize well. This seems consistent with LLMs as they exist today: superhuman at recall, very mediocre at reasoning. Though, for what it's worth, existing SSMs haven't yet shown they can outperform (or even match) transformers when it comes to reasoning.
If this hypothesis were true, you might expect to see grokking in state space models more quickly than in transformer models.
(Even if it's hard to train transformers to generalize, superhuman recall is still incredibly valuable, and likely a hybrid system would offer the best of both worlds.)
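To make the recall-vs-state contrast a bit more concrete, here's a toy Python sketch. It is not a real transformer or SSM: `KVCache`, `FixedState`, and the two cost functions are hypothetical names made up for illustration, and the scalar exponential-average state is an assumption standing in for the much richer, learned state of an actual SSM. All it shows is that one kind of memory grows with the sequence and recalls exactly, while the other stays constant-size and blurs old inputs together.

```python
def causal_attention_ops(n_tokens):
    """Query-key scores computed if token t attends to itself and every
    earlier token: 1 + 2 + ... + n = n(n+1)/2, i.e. quadratic in n."""
    return n_tokens * (n_tokens + 1) // 2


def recurrent_ops(n_tokens):
    """One state update per token for a fixed-size state: linear in n."""
    return n_tokens


class KVCache:
    """Transformer-like memory (toy): keeps every past token around."""

    def __init__(self):
        self.tokens = []

    def write(self, x):
        self.tokens.append(x)

    def recall(self, position):
        # Exact recall: the original value is still stored.
        return self.tokens[position]

    def size(self):
        return len(self.tokens)


class FixedState:
    """SSM-like memory (toy assumption): a single scalar updated as
    h <- decay * h + (1 - decay) * x. Old inputs fade geometrically."""

    def __init__(self, decay=0.9):
        self.decay = decay
        self.h = 0.0

    def write(self, x):
        self.h = self.decay * self.h + (1.0 - self.decay) * x

    def size(self):
        return 1  # constant, no matter how many tokens were seen


if __name__ == "__main__":
    kv, fs = KVCache(), FixedState()
    for i in range(1000):
        kv.write(float(i))
        fs.write(float(i))

    print(causal_attention_ops(1000), recurrent_ops(1000))  # 500500 1000
    print(kv.size(), fs.size())                             # 1000 1
    print(kv.recall(3))                                     # 3.0
    print(fs.h)  # a blend dominated by recent inputs; token 3 is not recoverable
```

The cache can hand back token 3 exactly, but pays for it in memory and per-token work. The fixed state is cheap, but whatever it "knows" about token 3 has been folded into a summary, which is roughly the fuzzy-book-model intuition from the quoted comment.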