Hacker News
by whimsicalism 808 days ago
I think more complicated routing is absolutely going to become more common.

Specifically, I think at some point we are going to move to recursive routing, i.e., passing back through a set of experts again. In the future, 'chain-of-thought' will happen internally to the model, recursively.
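A minimal sketch of what routing recursively through the same expert pool could look like, in plain Python with toy vectors. Everything here (the linear router, the residual update, the fixed `depth`, the names `recursive_route` and `experts`) is an illustrative assumption rather than any published architecture. On each pass the router re-scores the current hidden state, the top-1 expert is applied, and the result is fed back in for another pass.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def recursive_route(
    h: Vector,
    experts: List[Callable[[Vector], Vector]],
    router: List[Vector],
    depth: int,
) -> Tuple[Vector, List[int]]:
    """Send h back through the same pool of experts `depth` times (top-1 per pass).

    Returns the final state and the sequence of expert indices that were chosen.
    """
    path = []
    for _ in range(depth):
        # the router re-scores the *current* state, so the path can change per pass
        probs = softmax([dot(w, h) for w in router])
        k = max(range(len(probs)), key=probs.__getitem__)
        path.append(k)
        out = experts[k](h)
        # residual update, scaled by the gate probability as in typical MoE layers
        h = [x + probs[k] * y for x, y in zip(h, out)]
    return h, path
```

With `experts = [lambda v: [2 * x for x in v], lambda v: [-x for x in v]]` and an identity `router`, a state starting at `[0.0, 1.0]` goes through expert 1 on every pass. A fixed depth is the simplest choice here; deciding how many passes to take is the hard part.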

8 comments

We can name these hypothetical objects Recursive Neural Networks.
I know you're jesting, but RNNs recur along the sequence length, whereas I am describing recursion along the depth.
Recursive NNs are not the same as Recurrent NNs:

https://en.wikipedia.org/wiki/Recursive_neural_network

Well, ish. The article above explains that recursive NNs are hierarchical whereas RNNs are linear. I guess the distinction is a little on the fine side.

Anyway carry on. Pedantic moment over.

The recursive neural networks described there are a failed academic project from more than a decade ago, predating modern deep learning. Basically everyone using the phrase "recursive NN" nowadays is probably just misspeaking for RNN. RNNs also are not linear.
I don't know about "everybody nowadays", but I remember Recursive Neural Nets as an architecture introduced by Christopher Manning, with the argument that it was better suited to the hierarchical structure of language than existing architectures. I did find it a bit of a bad choice of name, given that it's so close to Recurrent Neural Nets. All this is from memory, though; I might check the internets later to see what I misremember.

RNNs are a large class of architectures of varying complexity, from Kalman filters to LSTMs. It's not clear to me exactly what the Wikipedia article means by "linear", but LSTMs, for example, treat their inputs as sequences and don't try to deconstruct them into parts, as e.g. Convolutional Neural Nets do. So maybe that's what's meant by "linear".

No opinion on the specifics of this distinction, but it's worth noting that in research, an awful lot of successful projects have their origins in failed projects of decades ago...
My experience working in machine learning academia is that there was an overfocus on failed projects from the 90s to early 00s that only really stopped in 2020+.

We can often trace back successful projects to failed precursors, but often the people behind the successful project are not even familiar with the failed precursor and the 'connection to the past' only really occurs in retrospect. See the 'adjoint state method' and connections with backprop.

Depthwise RNN?
Like decode the next token, then adjust what you're paying attention to, then decode it again?
Isn't it the only way to, say, understand a pun?
That is exactly how LLM inference is performed, so I'm being cheeky (I'm 99% sure anyone proposing anything in this thread is someone handwaving based on limited understanding).
You would be wrong, but that is fine. Been working with attention since 2018.

Why assume I know little and leave snarky comments (and basically a repetition of the prior joke at that, subbing RNN for transformer)?

What you describe here sounds a little like the line of work centered around Universal Transformers, which basically process the input embeddings through a single transformer block multiple times with a separate module deciding when the embeddings have been cooked enough and can be pulled out of the oven so to speak.

Even more in line with the idea of "experts" there's a paper from last year on Sparse Universal Transformers in which they combine a universal transformer with sparse mixture of experts, so it's up to the gating mechanism to decide which transformer blocks and in which order are to be used in shaping the embeddings.

This really isn't my specialty but from what I gathered these are tricky to train properly, and require more overall compute during inference to reach comparable results to their vanilla transformer counterparts. It's an interesting direction nonetheless, having an upper bound on the number of computation steps per token is, in my opinion, one of the major downsides of the classical transformer architecture.
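For a concrete picture of the "cooked enough" mechanism, here is a simplified, single-position sketch in the style of Adaptive Computation Time, the halting scheme Universal Transformers build on. The function name, the constant `eps`, and the toy `step` and `halt_prob` callables are assumptions for illustration; in the real model both are learned and applied to every position in parallel.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def act_loop(
    h: Vector,
    step: Callable[[Vector], Vector],
    halt_prob: Callable[[Vector], float],
    max_steps: int = 8,
    eps: float = 0.01,
) -> Tuple[Vector, int]:
    """Apply one shared block repeatedly until the accumulated halting probability
    reaches 1 - eps (or max_steps is hit).

    Returns the halting-weighted mix of the intermediate states and the number of
    passes taken. Assumes max_steps >= 1.
    """
    cumulative = 0.0
    output = [0.0] * len(h)
    for n in range(1, max_steps + 1):
        h = step(h)       # same weights on every pass: recursion along depth
        p = halt_prob(h)  # a small learned unit in the real model
        if cumulative + p >= 1.0 - eps or n == max_steps:
            remainder = 1.0 - cumulative  # leftover probability mass goes to the last state
            output = [o + remainder * x for o, x in zip(output, h)]
            return output, n
        cumulative += p
        output = [o + p * x for o, x in zip(output, h)]
    raise ValueError("max_steps must be >= 1")
```

With a `step` that adds 1 and a constant halting probability of 0.3, the loop stops after four passes, with weights 0.3, 0.3, 0.3 and a remainder of 0.1. Note that `max_steps` still caps the computation per token; the halting unit only lets easier tokens use fewer passes below that cap.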

I think the reason this hasn't been done is you have no way to decide how many recursions are necessary at train time.

And if you pick a random number/try many different levels of recursion, you 'blur' the output. I.e., the output of a layer doesn't know whether it should be outputting info important for the final result or the output that is the best possible input to another round of recursion.

Yes, I think training this model would be hard. Perhaps something akin to how MoEs are trained where you impose some sort of loss distribution to encourage equitable routing, but for recursion.
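For reference, the MoE-side balancing term is commonly written (e.g. in the Switch Transformer paper) as alpha * N * sum_i f_i * P_i, where N is the number of experts, f_i is the fraction of tokens whose top-1 choice is expert i, and P_i is the mean router probability for expert i. A sketch of that computation follows; how to carry it over to recursion depth (for instance by treating each possible depth as an "expert") is an open question that this code does not settle.

```python
from typing import List


def load_balancing_loss(router_probs: List[List[float]], alpha: float = 1.0) -> float:
    """Switch-Transformer-style auxiliary loss: alpha * N * sum_i f_i * P_i.

    router_probs[t][i] is the router's softmax probability of sending token t to expert i.
    f_i: fraction of tokens whose top-1 choice is expert i.
    P_i: mean router probability assigned to expert i.
    With N experts the loss is alpha * 1.0 when f and P are both uniform, and
    alpha * N when every token goes to one expert with probability 1.
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for probs in router_probs:
        top = max(range(num_experts), key=probs.__getitem__)  # top-1 dispatch
        counts[top] += 1
        for i, p in enumerate(probs):
            prob_sums[i] += p
    f = [c / num_tokens for c in counts]
    P = [s / num_tokens for s in prob_sums]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))
```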
Look at the human brain for useful analogies?

The default mode network does recursive/looping processing in the absence of external stimuli and world interaction. Multiple separate modules outside of the network are responsible for stopping and regulating this activity.

You could just learn the right estimated number of recursions, also passing 'backtracking'/'state' information to the next nested level. Kind of like how state space models encode extractable information via a basis function representation, you could encode extractable recursion state information into the embedding. See also transformers that can learn to recognize balanced parentheses of bounded depth (Dyck languages).
I have been thinking about this topic for some time. It might be done using the energy of the token. If it's still higher than an energy limit, then process it again, and increase the energy limit. The energy could be computed using log-sum-exp: https://openreview.net/pdf?id=Hkxzx0NtDB
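A sketch of that idea, assuming the log-sum-exp energy over a logit vector, E = -T * log(sum_i exp(z_i / T)), and a hypothetical `refine_by_energy` loop. The `step` and `readout` callables, the starting `limit` and the `limit_step` schedule are all made up for illustration. Lower energy here simply means the readout's logits are larger overall; raising the limit on each pass makes it easier for the token to exit.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def energy(logits: Vector, temperature: float = 1.0) -> float:
    """E = -T * log(sum_i exp(z_i / T)), computed stably by factoring out the max."""
    m = max(logits)
    lse = m + temperature * math.log(sum(math.exp((z - m) / temperature) for z in logits))
    return -lse


def refine_by_energy(
    h: Vector,
    step: Callable[[Vector], Vector],
    readout: Callable[[Vector], Vector],
    limit: float = -2.0,
    limit_step: float = 0.5,
    max_passes: int = 10,
) -> Tuple[Vector, int]:
    """Hypothetical loop: while the token's energy is above the limit, run it through
    the block again and raise the limit; max_passes guarantees termination."""
    passes = 0
    while energy(readout(h)) > limit and passes < max_passes:
        h = step(h)
        passes += 1
        limit += limit_step
    return h, passes
```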
This is actually how EfficientNet trains, using random truncation of the network during training. It does just fine... The game is that each layer needs to get as close as it can to good output, improving on the previous activation's quality.
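A toy illustration of training with a randomly sampled number of passes through one shared block, in the spirit of the comment (this is not EfficientNet's own code). The scalar block f(h) = h + w * (target - h) and the function name `train_random_depth` are hypothetical; the block is chosen because, after k passes, h_k - target = (1 - w)**k * (h0 - target), so the squared-error gradient has a closed form.

```python
import random


def train_random_depth(h0: float, target: float, max_depth: int,
                       lr: float = 0.1, steps: int = 200, seed: int = 0) -> float:
    """Learn the single weight w of a shared toy block f(h) = h + w * (target - h),
    unrolling it a randomly sampled number of times on each training step."""
    rng = random.Random(seed)
    w = 0.0
    err0 = h0 - target
    for _ in range(steps):
        k = rng.randint(1, max_depth)  # depth sampled fresh each step (inclusive range)
        # loss = ((1 - w)**k * err0)**2, so dloss/dw = -2k (1 - w)**(2k - 1) * err0**2
        grad = -2 * k * (1 - w) ** (2 * k - 1) * err0 ** 2
        w -= lr * grad
    return w
```

In this toy every sampled depth prefers the same w (w = 1 lands on the target in a single pass), so random depth pushes each pass to get as close to the answer as it can. The "blur" worry raised earlier in the thread concerns models where the best final output and the best input for another pass differ, which this toy does not capture.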
Attention is basically routing; these other routing schemes give the model a less fine-grained choice, which potentially makes it easier to train.
How is attention basically routing?
It routes values based on linear combinations taken from the attention map.
But all of those values are created using an MLP with the same parameters, so there is no routing to different parameters.
You have to look at it as a sequence of time steps which can interact. You can implement this interaction in many ways, such as transformer, mamba, rwkv or mlp-mixer. But the purpose is always to allow communication across time.

You use three distinct linear projections: one for queries, one for keys and one for values. From Q and K you compute the attention matrix A, and using A you construct linear combinations of V. Depending on A, the output for a token i might draw on the values of two other tokens, V_j and V_k, so information is moved between the tokens.

Think of it like an edge flow matrix
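A minimal single-head sketch of that view in plain Python, assuming Q, K and V have already been projected (the three linear projections are omitted). Row i of A is a set of non-negative weights summing to 1 that say how much of each token's value vector flows into token i, which is the "edge flow matrix" reading.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q: Matrix, K: Matrix, V: Matrix) -> Tuple[Matrix, Matrix]:
    """Scaled dot-product attention: A = softmax(Q K^T / sqrt(d)), out = A V."""
    d = len(Q[0])
    # one softmax per query row: the "outgoing edges" into token i
    A = [softmax([sum(qc * kc for qc, kc in zip(q, k)) / math.sqrt(d) for k in K]) for q in Q]
    # each output row is a convex combination of the value rows
    out = [
        [sum(A[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
        for i in range(len(Q))
    ]
    return out, A
```

Note that the same projection weights produce Q, K and V for every token; what A changes is which tokens' states get mixed, not which parameters are used.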
That doesn't clarify it for me. The same parameters are used for every token at a given layer. Yes, there is a differentiable lookup in attention, like in MoE, but routing is about more than just differentiable lookup: it is about selecting parameters, not state.
The trendline is definitely toward increasing dynamic routing, but I suspect it's more so that MoE/MoD/MoDE enable models to embed additional facts with less superposition within their weights than enable deeper reasoning. Instead I expect deeper reasoning will come through token-wise dynamism rather than layer-wise -- e.g., this recent Quiet-STaR paper in which the model outputs throwaway rationale tokens: https://arxiv.org/abs/2403.09629
There are already some implementations out there which attempt to accomplish this!

Here's an example: https://github.com/silphendio/sliced_llama

A gist pertaining to said example: https://gist.github.com/silphendio/535cd9c1821aa1290aa10d587...

Here's a discussion about integrating this capability with ExLlama: https://github.com/turboderp/exllamav2/pull/275

And same as above but for llama.cpp: https://github.com/ggerganov/llama.cpp/issues/4718#issuecomm...

See, this is where my understanding of LLMs breaks down. I can understand one token going through the model, but I can't understand a model that has different "experts" internally.

Do you have any resources or links to help explain that concept?

The "mixture of experts" goal is to add more parameters to the model to make it more powerful, without requiring any more compute. The way this is done is by having sections of the model ("experts") that are in parallel with each other, and each token only going through one of them. Think of it like a multi-lane highway with a toll booth on each lane - each car only drives on one lane rather than using them all, so only pays one toll.

The name "experts" is a bit misleading, since each expert ("highway lane") is not really specialized in any obviously meaningful way. There is a routing/gating component in front of the experts that chooses on a token by token basis (not sentence by sentence!) which "expert" to route the token to, with the goal of roughly load balancing between the experts so that they all see the same number of tokens, and the parameters in each expert are therefore all equally utilized.

The fact that the tokens in a sentence will be somewhat arbitrarily sent through different "experts" makes it an odd kind of expertise - not directly related to the sentence as a whole! There has been experimentation with a whole bunch of routing (expert selection) schemes.

It is still just one token going through the model.
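A minimal sketch of the "one lane per car" picture: a top-1 router scores each token separately and only the chosen expert runs on it. The linear router, the toy lambda experts and the per-token Python loop are simplifications; real implementations batch tokens per expert, often add capacity limits, and some route each token to its top-2 experts.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def moe_layer(
    tokens: List[Vector],
    experts: List[Callable[[Vector], Vector]],
    router: List[Vector],
) -> Tuple[List[Vector], List[int]]:
    """Top-1 mixture-of-experts layer: each token pays for exactly one expert's
    compute, however many experts (parameters) the layer holds."""
    outputs, assignment = [], []
    for x in tokens:
        probs = softmax([dot(w, x) for w in router])  # per-token decision, not per-sentence
        k = max(range(len(probs)), key=probs.__getitem__)
        assignment.append(k)
        outputs.append([probs[k] * y for y in experts[k](x)])  # only expert k runs
    return outputs, assignment
```

Adding a third expert would add parameters, but each token would still run through only one of them.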

I actually think mixture-of-experts is a bit of a misnomer; the 'experts' do not necessarily have super distinct expertise. Think of it more like how neurons activate in the brain: your entire brain doesn't light up for every query, and in these neural networks the same thing happens (the network doesn't fully light up for every query).

Don't really know a resource besides the seminal Noam Shazeer paper, sorry; I'm sure others have higher-level ones.

Most of the original MoE implementations around LLMs were in fact recursive.
Could you please elaborate?
The original MoE research done by Google around LLMs involved nested transformers to scale them. It was a layered approach where at each layer you would have a set of experts, generally routed to by simple heuristics, and then each of those models would call into its own series of experts and combine the data in various ways.

These models were SOTA for their time
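One way to picture the layered setup described above is a two-level gate: a top-level router picks a group of experts, and that group's own router picks one expert inside it. This sketch assumes hard top-1 choices at both levels and toy linear gates; the names `hierarchical_route`, `top_router` and `sub_routers` are hypothetical, and it is not a reconstruction of Google's models.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def argmax(xs: List[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)  # first index on ties


def hierarchical_route(
    x: Vector,
    top_router: List[Vector],
    sub_routers: List[List[Vector]],
    experts: List[List[Callable[[Vector], Vector]]],
) -> Tuple[Vector, Tuple[int, int]]:
    """Two-level (tree-shaped) routing: pick a group, then an expert inside it.

    Calls only flow downward; a leaf expert never routes back to the top-level gate.
    """
    g = argmax([dot(w, x) for w in top_router])      # level 1: choose a group
    e = argmax([dot(w, x) for w in sub_routers[g]])  # level 2: choose within the group
    return experts[g][e](x), (g, e)
```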

Interesting, but that isn't recursive as the sub-model cannot invoke a model higher up in the invoke graph/tree.