by swyx 916 days ago
things I'd like a non-ML-researcher explanation of about Mamba:

1. what is the overall insight of state space models beyond transformers? (i know this is somewhat covered in the paper but still a bit inaccessible)

2. what was the incremental innovation/result that is making Mamba more successful/interesting than its predecessors? (S4, H3, Monarch etc)

3. what are the implications beyond subquadratic scaling of context? say if i don't really care about context length > 100k tokens. what other benefits are there - for example, is Mamba potentially more compute-efficient to train for a similar size of model/dataset?

just offering 3 prompts for knowledgeable people to drop some alpha

8 comments

My IQ is orders of magnitude lower than the authors of the paper, but I did my best to work through it anyway. I studied CE and have the basic control theory background and undergrad level discrete time systems intuition. It would take much additional studying to understand state space models enough to really parse this paper. But I tried anyway. Take my comment here with a big grain of salt.

The overall insight of Mamba is to solve a longstanding problem with state space models. They are good at compressing the input context, but the compression of input into a hidden state erases information needed to make use of the context effectively as Transformers do.

Their solution to this problem is what they call a selection mechanism. The mechanism is input-dependent, allowing the model to adjust its output at each step as the input changes. They do this by making a few of the state space parameters input-dependent instead of input-invariant: linear layers (and similar) project the input onto those parameters at each time step. These layers are trained, so they learn how to transform the input so that the model produces useful output.
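To make the selection idea concrete, here is a minimal NumPy sketch of a recurrence whose parameters are recomputed from the input at every step. It is a toy under stated assumptions, not the paper's implementation: the names `selective_scan`, `W_B`, `W_C` and `W_delta` are made up for illustration, the state matrix `A` is assumed diagonal, and the discretization is the simplified form A_bar = exp(delta * A), B_bar = delta * B.

```python
import numpy as np

def softplus(x):
    # Smooth, always-positive function; keeps the step size delta > 0.
    return np.log1p(np.exp(x))

def selective_scan(x, A, W_B, W_C, W_delta):
    """Toy selective state space recurrence (illustrative, not the paper's code).

    x:       (L, D) input sequence, L steps of D channels
    A:       (D, N) diagonal state matrix per channel, NOT input-dependent
    W_B:     (D, N) hypothetical linear layer producing B_t from x_t
    W_C:     (D, N) hypothetical linear layer producing C_t from x_t
    W_delta: (D, D) hypothetical linear layer producing the step size delta_t
    """
    L, D = x.shape
    N = A.shape[1]
    h = np.zeros((D, N))              # hidden state: N numbers per channel
    y = np.zeros((L, D))
    for t in range(L):
        B_t = x[t] @ W_B              # (N,)  input-dependent
        C_t = x[t] @ W_C              # (N,)  input-dependent
        delta_t = softplus(x[t] @ W_delta)         # (D,) input-dependent
        A_bar = np.exp(delta_t[:, None] * A)       # (D, N) assumed discretization
        B_bar = delta_t[:, None] * B_t[None, :]    # (D, N) assumed discretization
        h = A_bar * h + B_bar * x[t][:, None]      # update the compressed state
        y[t] = h @ C_t                             # read out, shape (D,)
    return y

# Tiny usage example with random weights (no training involved).
rng = np.random.default_rng(0)
L, D, N = 12, 3, 4
x = rng.normal(size=(L, D))
A = -rng.uniform(0.5, 1.5, size=(D, N))   # negative entries keep the state stable
W_B = rng.normal(size=(D, N))
W_C = rng.normal(size=(D, N))
W_delta = rng.normal(size=(D, D))
y = selective_scan(x, A, W_B, W_C, W_delta)
```

Because `B_t`, `C_t` and `delta_t` are recomputed from `x[t]`, the model can decide per token how much to write into the state and how much of the old state to keep, which is roughly what "selection" means here.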

But making the state space parameters input-dependent creates a problem in terms of computation overhead. They fix the computation problem by designing a hardware-aware algorithm that makes the most of modern GPU memory architecture, avoiding moving things in and out of HBM as much as possible.

Tri Dao came up with Flash Attention, which is basically a way to use hardware more efficiently in a Transformer. So this is his jam 100%.

I know this doesn’t add much to understanding the paper, but hopefully it’s better than nothing.

Is this similar to subset selection with the concrete distribution?
I don’t know enough to answer your question, sadly.
1. Attention is quadratic in context length; RNNs with gating (LSTM, GRU, etc.) are linear, as are all these new architectures. Early RNNs used gating to avoid unstable (vanishing or exploding) gradients; these new ideas use theory from dynamical systems that guarantees stability, so the gating can focus on memory rather than solving two problems at once.

2. The models released in the last couple of weeks running up to NeurIPS 2023 (Mamba and Based) included multi-query associative recall (MQAR) and data-dependence in the gating/selection, inspired by multi-headed attention. It turned out these were the main missing ingredients compared to earlier state-space architectures (Hyena and earlier), and they made these new models as good as attention in associative recall tasks, and potentially even slightly better than attention in other non-lookup tasks. Of course the huge detail in Mamba is the efficient implementation in CUDA; without it the architecture may not make much sense for tasks where transformers are already appropriate.

3. If one does not have to worry too much about context length, a lot of new domains open up: DNA-sequence analysis is a linear task with long dependence; think of analyzing images, videos, or higher dimensional info in terms of streams of tokens (scan the pixels in the way of an old CRT monitor). The early dreams of AI included a continuously evolving single learning trajectory of an agent interacting with an environment continuously, so maybe such dreams will be easier to realize with these infinite-context-length models.

bonus: you didn't ask for it, but as of today the downstream applications of these models for important/practical tasks are largely untested/untuned compared to the rather mature applications of attention, so there may be a little delay before people figure out all the tricks for how to use large pre-trained models of these types. The analogy to the old RNNs helps to a degree, but people have specialized heavily in attention and transformers over the last 5 years, so there is a lot of momentum in favor of transformers.
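The "scan the pixels like an old CRT monitor" idea from point 3 can be sketched in a couple of lines. Treating one pixel as one token is an assumption made purely for illustration, and `raster_scan_tokens` is a made-up helper name; real vision models often tokenize differently (for example with patches).

```python
import numpy as np

def raster_scan_tokens(image):
    """Flatten an (H, W, C) image into H*W tokens of C values each,
    reading row by row, left to right, like a CRT beam."""
    H, W, C = image.shape
    return image.reshape(H * W, C)

# Tiny usage example: a 2x3 "image" with 3 channels becomes 6 tokens.
image = np.arange(2 * 3 * 3).reshape(2, 3, 3)
tokens = raster_scan_tokens(image)
```

The sequence length is H*W, so even a modest image yields a long sequence, which is why linear scaling in length matters for this kind of use.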

Can you cite the "Based" paper mentioned here?
This blog post about “Based” came out just before NeurIPS: https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-b...
And this is the zoology paper: https://arxiv.org/abs/2312.04927
> is Mamba potentially more compute-efficient to train for a similar size of model/dataset?

I would like to understand this as well ...

Here is the relevant quote from the original paper:

"Computation. After the parameters have been transformed from (∆, A, B, C) ↦ (Ā, B̄, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3). Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time)."

So the training is parallelizable, like in RetNet with parallel forward mode. By default, inference is done in the recurrent mode, to have the longest possible context. No chunking is available, so it is difficult for me to say how much RAM and VRAM it will consume during inference ...
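The two modes in the quoted passage can be checked against each other on a toy example. This is a rough NumPy sketch under assumptions: a single input channel, already-discretized parameters called `A_bar`, `B_bar` and `C` here, and made-up function names. It only illustrates the time-invariant case the quote describes.

```python
import numpy as np

def ssm_recurrent(x, A_bar, B_bar, C):
    """Recurrent mode: h_t = A_bar h_{t-1} + B_bar x_t, y_t = C h_t."""
    h = np.zeros(A_bar.shape[0])
    ys = []
    for x_t in x:                 # one step per token, constant-size state
        h = A_bar @ h + B_bar * x_t
        ys.append(C @ h)
    return np.array(ys)

def ssm_convolutional(x, A_bar, B_bar, C):
    """Convolutional mode: y = x * K with kernel K_k = C A_bar^k B_bar."""
    L = len(x)
    K = np.empty(L)
    v = B_bar.copy()
    for k in range(L):            # build the length-L kernel once
        K[k] = C @ v
        v = A_bar @ v
    return np.convolve(x, K)[:L]  # causal convolution over the whole sequence

# Tiny usage example with a random stable system.
rng = np.random.default_rng(0)
N, L = 4, 32
A_bar = np.diag(rng.uniform(0.1, 0.9, size=N))
B_bar = rng.normal(size=N)
C = rng.normal(size=N)
x = rng.normal(size=L)
y_rec = ssm_recurrent(x, A_bar, B_bar, C)
y_conv = ssm_convolutional(x, A_bar, B_bar, C)
```

Both functions produce the same outputs up to floating point error. The convolutional form only exists because `A_bar`, `B_bar` and `C` are the same at every step; once they depend on the input there is no single fixed kernel to convolve with.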

I did some minimal testing: during inference, Mamba uses about 60% of the VRAM that RetNet (parallel forward mode) uses, with a model of the same size and a vocabulary of the same size.
I think this video is exactly what you’re looking for.

He explains the paper but also gives a lot of context, how it fits into the big picture, etc.

It’s actually kind of exciting hearing the plot unfold.

https://youtu.be/ouF-H35atOY?si=y2Ckp9MCFd7ulLL3

AFAIK Mamba is a continuation of the SSM research, which is basically something called long convolution.

Instead of doing quadratic attention (computing how much each token attends to every other token) you just "somehow" compute a long (same length as input) convolution kernel, and then you apply the conv1d.

Again, from my limited understanding, it's a bit related to applying an FFT, doing some matmul and then an IFFT back. We know that this works but it's slow. But there are many ways to compute the FFT, and one of them is with something called butterfly matrices. I think it's just an approximation, but it's good enough and it's very fast/efficient on current hardware.
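For reference, the FFT route to a long convolution looks roughly like this in NumPy. It is a generic sketch of the convolution theorem, not code from any of the papers, and `fft_causal_conv` is a made-up name. In the frequency domain the middle step is an elementwise product of the two spectra.

```python
import numpy as np

def fft_causal_conv(x, k):
    """Causal convolution of a length-L signal with a length-L kernel via FFT."""
    L = len(x)
    n = 2 * L                        # zero-pad so circular wrap-around cannot occur
    X = np.fft.rfft(x, n=n)          # FFT of the input
    K = np.fft.rfft(k, n=n)          # FFT of the kernel
    y = np.fft.irfft(X * K, n=n)     # elementwise product, then inverse FFT
    return y[:L]                     # keep the first L (causal) outputs

# Tiny usage example: compare with the direct O(L^2) convolution.
rng = np.random.default_rng(0)
L = 64
x = rng.normal(size=L)
k = rng.normal(size=L)
y_fft = fft_causal_conv(x, k)
y_direct = np.convolve(x, k)[:L]
```

The FFT version costs O(L log L) instead of O(L^2) and matches the direct result up to floating point error.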

To put this in context, quadratic sounds bad, but in practice subquadratic algorithms are often slower because of hardware limitations. So while there was a lot of excitement about SSMs, it's not so easy to say that Llama is over now. Also, we don't know if Mamba will scale up, and the only way to know that is to actually pay a few million for training. But I am optimistic.

Another interesting model from the subquadratic family is RWKV. Worth checking, but I think you had a podcast about it :)

BTW: I am self-taught and I've only skimmed the paper some time ago, so I might be very wrong.

BTW2: Another thing with attention is that there's usually a KV-cache, which helps a lot with performance, and I think you cannot do that with Mamba.

my loose understanding

1) transformers create an input x input size attention matrix that is unnecessarily large. state space models somehow compress this.

2) "The main difference is simply making several parameters [in the state space model] functions of the input"

3) i think it might be more sample efficient (requires less data)

Re 3) Even if you don't care about long context length, Mamba is much cheaper per token of auto-regressive output. Each token only has to compute the next step of a linear RNN, whereas the transformer has to attend back over all previous outputs, which rapidly grows in cost and memory.
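A back-of-the-envelope sketch of the memory side of this argument. The function names and the layer/width numbers are hypothetical, chosen only for illustration, and the count ignores details such as multi-query attention or any extra per-layer buffers.

```python
def kv_cache_floats(t, n_layers, d_model):
    """Numbers a transformer keeps after t tokens: one key and one value
    vector of size d_model per token, per layer."""
    return 2 * n_layers * t * d_model

def ssm_state_floats(n_layers, d_inner, d_state):
    """Numbers a recurrent SSM keeps: a fixed (d_inner x d_state) state per
    layer, independent of how many tokens have been generated."""
    return n_layers * d_inner * d_state

# Hypothetical sizes, for illustration only.
n_layers, d_model, d_inner, d_state = 48, 2048, 4096, 16
for t in (1_000, 10_000, 100_000):
    kv = kv_cache_floats(t, n_layers, d_model)
    ssm = ssm_state_floats(n_layers, d_inner, d_state)
    print(t, kv, ssm)
```

The first count grows linearly with the number of generated tokens, and attention has to read all of it at every step; the second stays the same size no matter how long the sequence gets.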
For 2, Mamba takes some of the weights that are time-invariant in S4 and makes them functions of the input (in the paper these are Δ, B and C, while A stays a learned parameter), which makes it more powerful.