| This is a great set of comments/questions! To answer briefly: the input string is first tokenized into a sequence of token indices (integers). For example, "Hello World" is tokenized to: [15496, 2159]
The first step in a transformer network is to embed the tokens. Each token index is mapped to a (learned or fixed) embedding (a vector of floats) via the embeddings table. PyTorch's torch.nn.Embedding module is commonly used for this. After mapping, the matrix of embeddings will look something like:
[[-0.147, 2.861, ..., -0.447],
 [-0.517, -0.698, ..., -0.558]]
where the number of columns is the model dimension.
A single transformer block takes a matrix of embeddings and transforms it into a matrix of identical dimensions. An important property of the block is that if you reorder the rows of the matrix (which can be done by reordering the input tokens), the rows of the output are reordered in the same way but are otherwise identical. (The formal name for this is permutation equivariance.)
In problems related to language it seems inappropriate for the order of tokens not to matter, so to address this we adjust the initial token embeddings based on their position. There are a few common ways you might see this done, but they broadly work by assigning a fixed or learned embedding to each position in the input token sequence. These embeddings can be added to our matrix above, so that the first row gets the embedding for the first position added to it, the second row gets the embedding for the second position, and so on. Now if the tokens are reordered, the combined embedding matrix is no longer just a row-reordering of the original. Alternatively, these embeddings can be concatenated horizontally to our matrix: this keeps the positional information in separate columns from the token (linguistic) information at the input, at the cost of a larger model dimension.
I put together this repository at the end of last year to help visualize the internals of a transformer block when applied to a toy problem: https://github.com/rstebbing/workshop/tree/main/experiments/.... It is not super long, and the point is to better distinguish between the quantities you referred to by seeing them (which is possible when the embeddings are low-dimensional). I hope this helps! |
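
For anyone who wants to poke at the reordering property directly, here is a rough sketch in plain Python with no dependencies. It uses a single unmasked self-attention layer as a stand-in for a full block (the other pieces of a block act on each row independently, so they should not change the picture). The toy sizes, the random tables, and the helper names (embed, self_attention, close) are all assumptions made for illustration; none of this is taken from the linked repository.

```python
import math
import random

random.seed(0)

D_MODEL = 4   # model dimension (number of columns), kept tiny for readability
VOCAB = 6     # hypothetical toy vocabulary size
MAX_LEN = 5   # hypothetical maximum sequence length


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Randomly initialised stand-ins for learned parameters (assumption).
token_table = random_matrix(VOCAB, D_MODEL)
position_table = random_matrix(MAX_LEN, D_MODEL)
w_q, w_k, w_v = (random_matrix(D_MODEL, D_MODEL) for _ in range(3))


def embed(tokens, use_positions):
    """Look up one row per token; optionally add the row for its position."""
    rows = []
    for pos, tok in enumerate(tokens):
        row = list(token_table[tok])
        if use_positions:
            row = [t + p for t, p in zip(row, position_table[pos])]
        rows.append(row)
    return rows


def self_attention(x):
    """Single-head, unmasked scaled dot-product self-attention."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    k_t = [list(col) for col in zip(*k)]
    scores = matmul(q, k_t)
    weights = [softmax([s / math.sqrt(D_MODEL) for s in row]) for row in scores]
    return matmul(weights, v)


def close(a, b, tol=1e-9):
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


tokens = [3, 1, 4]
perm = [2, 0, 1]                      # new row j holds old row perm[j]
shuffled = [tokens[i] for i in perm]

# Without positions: shuffling the input just shuffles the output rows.
out = self_attention(embed(tokens, use_positions=False))
out_shuffled = self_attention(embed(shuffled, use_positions=False))
equivariant_without_positions = close(out_shuffled, [out[i] for i in perm])

# With added positions: the shuffled output is no longer a row-reordering.
out_pos = self_attention(embed(tokens, use_positions=True))
out_pos_shuffled = self_attention(embed(shuffled, use_positions=True))
equivariant_with_positions = close(out_pos_shuffled, [out_pos[i] for i in perm])

print(equivariant_without_positions)  # True
print(equivariant_with_positions)     # False
```

The concatenation variant would be the same sketch with the position row appended to the token row instead of added, with the weight matrices widened to match.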
Yes, the entire description is helpful, but I especially appreciate this validation that concatenating the position encoding is a valid option.
I've been thinking a lot about aggregation functions, usually summation since it's the most basic aggregation function. After adding the token embedding and the positional encoding together, it seems information has been lost, because the resulting sum cannot be separated back into the original values. And yet, that seems to be what they do in most transformers, so it must be worth the trade-off.
It reminds me of being a kid, when you first realize that zipping a file produces a smaller file and you think "well, what if I zip the zip file?" At first you wonder if you can eventually compress everything down to a single byte. I wonder the same with aggregation / summation, "if I can add the position to the embedding, and things still work, can I just keep adding things together until I have a single number?" Obviously there are some limits, but I'm not sure where those are. Maybe nobody knows? I'm hoping to study linear algebra more and perhaps I will find some answers there?
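
One rough way to probe where the limit sits is to ask whether the (token, position) pair can still be identified from the sum when the tables themselves are known. The sketch below is only an illustration under assumptions: the sizes and random tables are made up, and the brute-force exact-match search is not something a transformer actually does. It compares an 8-dimensional random embedding against the extreme case of a single structured number (the raw integers added together).

```python
import itertools
import random

random.seed(0)

VOCAB, MAX_LEN, D_MODEL = 50, 10, 8  # hypothetical toy sizes

token_table = [[random.gauss(0.0, 1.0) for _ in range(D_MODEL)] for _ in range(VOCAB)]
position_table = [[random.gauss(0.0, 1.0) for _ in range(D_MODEL)] for _ in range(MAX_LEN)]


def combine(tok, pos):
    """Sum of the token embedding and the position embedding."""
    return [t + p for t, p in zip(token_table[tok], position_table[pos])]


def candidates(summed):
    """All (token, position) pairs whose summed embedding equals `summed` exactly."""
    return [
        (tok, pos)
        for tok, pos in itertools.product(range(VOCAB), range(MAX_LEN))
        if combine(tok, pos) == summed
    ]


# Eight random dimensions: only the original pair reproduces the sum.
vector_matches = candidates(combine(17, 3))

# One structured dimension (the raw integers themselves): many pairs collide,
# for example 17 + 3 == 18 + 2.
scalar_matches = [
    (tok, pos)
    for tok, pos in itertools.product(range(VOCAB), range(MAX_LEN))
    if tok + pos == 17 + 3
]

print(vector_matches)       # [(17, 3)]
print(len(scalar_matches))  # 10
```

So the sum on its own has many possible decompositions, but with a finite set of known candidates and enough dimensions the original pair is still pinned down here, whereas squeezing everything into one small integer makes different pairs indistinguishable. How well a trained network can exploit this in practice is a separate question that this sketch does not answer.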