This brought back memories of Hungarian notation. I think I'll now try using it in my PyTorch code to solve a problem I keep running into with NN code: keeping track of tensor shapes and what each dimension means. The usual style:

```python
import math

B, T, E = x.size()  # batch size, sequence length, embedding dimensionality
q, k, v = self.qkv(x).split(self.embedding, dim=-1)
# split heads: (B, T, E) -> (B, T, heads, E // heads) -> (B, heads, T, E // heads)
q, k, v = map(lambda y: y.view(B, T, self.heads, E // self.heads).transpose(1, 2), (q, k, v))
attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
...
```
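versus the same code with shape prefixes:

```python
# prefixes: b = batch, t = sequence, e = embedding, h = head, i = head size
B, T, E = bteX.size()
iHeadSize = E // self.heads
bteQ, bteK, bteV = self.qkv_E_3E(bteX).split(E, dim=-1)
bhtiQ, bhtiK, bhtiV = map(lambda y: y.view(B, T, self.heads, iHeadSize).transpose(1, 2), (bteQ, bteK, bteV))
# (b, h, t, i) @ (b, h, i, t) -> (b, h, t, t)
bhttAttention = (bhtiQ @ bhtiK.transpose(-2, -1)) * (1.0 / math.sqrt(iHeadSize))
```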
It looks uglier, but it might be easier to reason about.
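If the prefixes carry shape information, they can also be checked mechanically instead of only by eye. The following is a minimal sketch in plain Python (no PyTorch needed), assuming that each leading lowercase letter of a name labels exactly one dimension. The helpers `dims_of`, `check_shape`, `transpose_prefix`, and `matmul_prefix` are hypothetical and not part of PyTorch. `check_shape` also accepts a `torch.Size`, since that is a tuple subclass. The matmul rule is deliberately simplified: it requires the batch labels to match exactly rather than modelling PyTorch's broadcasting.

```python
import re
from typing import Dict, Sequence

# Hypothetical convention: the leading lowercase letters of a variable name
# are dimension labels, one per axis, e.g. "bhtiQ" -> "bhti".
_PREFIX = re.compile(r"^[a-z]+")


def dims_of(name: str) -> str:
    """Return the lowercase dimension prefix of a Hungarian-style name."""
    match = _PREFIX.match(name)
    if match is None:
        raise ValueError(f"{name!r} has no lowercase dimension prefix")
    return match.group(0)


def check_shape(name: str, shape: Sequence[int], sizes: Dict[str, int]) -> None:
    """Raise ValueError if `shape` disagrees with the labels encoded in `name`.

    Labels missing from `sizes` are only counted, not size-checked.
    """
    prefix = dims_of(name)
    if len(prefix) != len(shape):
        raise ValueError(f"{name}: prefix has {len(prefix)} dims, shape has {len(shape)}")
    for label, actual in zip(prefix, shape):
        expected = sizes.get(label)
        if expected is not None and expected != actual:
            raise ValueError(f"{name}: dim {label!r} is {actual}, expected {expected}")


def transpose_prefix(prefix: str, dim0: int, dim1: int) -> str:
    """Swap two labels, mirroring tensor.transpose(dim0, dim1)."""
    labels = list(prefix)
    labels[dim0], labels[dim1] = labels[dim1], labels[dim0]
    return "".join(labels)


def matmul_prefix(left: str, right: str) -> str:
    """Derive the prefix of `left @ right` (simplified: no broadcasting).

    Leading batch labels must be identical, and the contracted labels
    (last of `left`, second-to-last of `right`) must match.
    """
    if len(left) < 2 or len(right) < 2:
        raise ValueError("both operands need at least two dimensions")
    if left[:-2] != right[:-2]:
        raise ValueError(f"batch dims differ: {left[:-2]!r} vs {right[:-2]!r}")
    if left[-1] != right[-2]:
        raise ValueError(f"cannot contract {left[-1]!r} with {right[-2]!r}")
    return left[:-2] + left[-2] + right[-1]


if __name__ == "__main__":
    q = transpose_prefix("bthi", 1, 2)   # after view(B, T, heads, i): "bhti"
    k_t = transpose_prefix(q, -2, -1)    # K transposed: "bhit"
    print(matmul_prefix(q, k_t))         # bhtt
    sizes = {"b": 2, "h": 4, "t": 8, "i": 16}
    check_shape("bhttAttention", (2, 4, 8, 8), sizes)  # passes silently
```

Tracing the attention snippet this way, `bthi` transposed on dims 1 and 2 gives `bhti`, transposing K's last two dims gives `bhit`, and `bhti @ bhit` yields `bhtt`, which agrees with the name `bhttAttention`. A misspelled or mismatched prefix shows up as a `ValueError` instead of a silent shape bug.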