by newhouseb 1113 days ago
The algorithm learned here actually makes a lot of sense when you spend more time understanding how transformers typically work.

Namely: once you include layer normalization, your model is more or less forced to represent absolute quantities in a way that won't be normalized away, and a great way to achieve this is to... store things as rotations of a unit vector! With that as your primitive, it's fairly natural to rotate around a circle to compute modular addition.
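
To make that concrete, here's a rough sketch in plain Python (my own toy, not the author's model). The `layer_norm` here has no learned gain or bias, the modulus `p = 13` is arbitrary, and `encode`, `decode` and `mod_add` are hypothetical helper names. It only illustrates two things: scaling a vector is erased by normalization, while an angle on the unit circle is not, and composing two rotations gives you addition mod p.

```python
import cmath
import math

def layer_norm(x, eps=1e-5):
    """Plain layer norm over one vector, without learned gain or bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def encode(a, p):
    """Represent the residue a (mod p) as a point on the unit circle."""
    return cmath.exp(2j * math.pi * a / p)

def decode(z, p):
    """Return the residue whose encoding points most nearly along z."""
    return max(range(p), key=lambda c: (z * encode(c, p).conjugate()).real)

def mod_add(a, b, p):
    """Add residues by composing rotations (multiplying unit complex numbers)."""
    return decode(encode(a, p) * encode(b, p), p)

# Magnitude is normalized away: x and 3x land on (almost) the same vector.
x = [1.0, 2.0, 4.0]
scaled = [3.0 * v for v in x]
same = all(abs(u - v) < 1e-4 for u, v in zip(layer_norm(x), layer_norm(scaled)))

# Angles survive: composing rotations reproduces (a + b) % p for every pair.
p = 13  # arbitrary small modulus, chosen only for illustration
ok = all(mod_add(a, b, p) == (a + b) % p for a in range(p) for b in range(p))
```

Multiplying two unit complex numbers adds their angles, and angles wrap around at 2π, which is exactly the wraparound that modular addition needs. A trained network would presumably do something messier across many dimensions, but the primitive is the same.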

I'd be curious whether a different algorithm is learned if one stops normalizing at various points. I wouldn't be surprised if a large hurdle to mechanistic interpretability turns out to be that the models have learned complicated rotations in non-obvious coordinate spaces that are tricky to identify after the fact.

1 comment

In the model he was using he didn't use layer norm, right?
Oh, good catch! The author defines a layer norm layer but then... comments it out in the actual implementation (I missed that). So that answers my second question of what happens without it.

Anecdotally, in my own interpretability work (without layer norm), my models also learn rotations fairly frequently. I attributed this to the way I was doing positional embeddings (as rotations), but perhaps there's more to it.

Thinking about this more, softmax is also a form of normalization that could plausibly contribute to this phenomenon.
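
For what I mean by softmax being a normalization, here's a minimal sketch (plain Python, made-up logits). It only shows the well-known property that adding the same constant to every logit leaves the output unchanged, so the absolute offset is thrown away and only differences between logits survive. Whether that actually pushes models toward rotations is speculation on my part.

```python
import math

def softmax(x):
    """Numerically stable softmax over a list of logits."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]                 # arbitrary example values
shifted = [v + 100.0 for v in logits]     # same logits plus a constant offset

probs = softmax(logits)
probs_shifted = softmax(shifted)

# The outputs always sum to 1, and the +100 offset leaves no trace.
offset_erased = all(abs(a - b) < 1e-9 for a, b in zip(probs, probs_shifted))
```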