Hacker News
by lostmsu 172 days ago
This brought back memories of Hungarian notation. I think I will now try using it in my PyTorch code to solve a common problem I have with NN code: keeping track of tensor shapes and what each dimension means.

  B, T, E = x.size() # batch size, sequence length, embedding dimensionality
  
  q, k, v = self.qkv(x).split(self.embedding, dim=-1)
  q, k, v = map(lambda y: y.view(B, T, self.heads, E // self.heads).transpose(1, 2), (q, k, v))

  attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
  ...
vs

  B, T, E = bteX.size()

  iHeadSize = E // self.heads
  bteQ, bteK, bteV = self.qkv_E_3E(bteX).split(E, dim=-1)
  bhtiQ, bhtiK, bhtiV = map(lambda y: y.view(B, T, self.heads, iHeadSize).transpose(1, 2), (bteQ, bteK, bteV))

  bhttAttention = (bhtiQ @ bhtiK.transpose(-2, -1)) * (1.0 / math.sqrt(iHeadSize))
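It looks uglier, but it might be easier to reason about, since each name's prefix spells out the tensor's dimensions (b = batch, t = sequence position, e = embedding, h = head, i = per-head size).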
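If the prefixes are going to carry meaning, they could also be checked at runtime. Here is a rough sketch of that idea. `check_prefix` and the `sizes` mapping are hypothetical names made up for illustration, not part of PyTorch. It only looks at a plain shape tuple, so it should work with `tensor.shape` as well, since `torch.Size` is a tuple subclass.

```python
def check_prefix(name, shape, sizes):
    """Check that the lowercase prefix of `name` matches `shape`.

    `name` follows the convention above: lowercase dimension letters
    followed by a capitalized description, e.g. "bhtiQ".
    `sizes` maps each dimension letter to its expected size.
    (Hypothetical helper, not a PyTorch API.)
    """
    # Collect the leading lowercase letters, e.g. "bhti" from "bhtiQ".
    prefix = ""
    for ch in name:
        if not ch.islower():
            break
        prefix += ch

    # Translate the letters into the shape they promise.
    expected = tuple(sizes[ch] for ch in prefix)
    if tuple(shape) != expected:
        raise ValueError(f"{name}: expected shape {expected}, got {tuple(shape)}")
    return True


# Example sizes: batch 2, sequence length 5, 4 heads, head size 8, embedding 32.
sizes = {"b": 2, "t": 5, "h": 4, "i": 8, "e": 32}
check_prefix("bhtiQ", (2, 4, 5, 8), sizes)          # passes
check_prefix("bhttAttention", (2, 4, 5, 5), sizes)  # passes
```

With this in place, a name like `bhtiK` attached to a tensor that is still laid out as (batch, time, heads, head size) would raise a `ValueError` instead of silently producing wrong attention scores.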