Hacker News
by rahimnathwani 2176 days ago
I'm not talking about inference on a 600B-parameter model. GP said they can't do inference on a 32-layer network with 2048 neurons per layer. Let's assume every layer is fully connected, so each neuron has 2048 parameters. That gives 32 * 2048 * 2048, or about 134MM parameters, in 11GB of RAM: roughly 82 bytes per parameter. If each parameter takes 4 bytes (which seems like a lot of precision), plus 4 bytes per calculated value, you're still only using about 10% of the GPU's RAM. You should be able to run inference on a batch of 16-20 examples at a time.
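A quick sketch of the arithmetic above, using the post's own assumptions: 32 square fully connected layers of width 2048, biases ignored, 11GB taken as 11e9 bytes, and the post's generous 8 bytes per parameter (4 for the weight, 4 for a calculated value). The function names are just for illustration.

```python
# Back-of-envelope memory estimate, under the post's assumptions:
# 32 fully connected layers of width 2048, no biases, 11 GB = 11e9 bytes,
# and 4 bytes per weight plus 4 bytes per calculated value.

LAYERS = 32
WIDTH = 2048
GPU_BYTES = 11e9
BYTES_PER_PARAM = 4 + 4  # weight + calculated value, as in the post


def param_count(layers: int, width: int) -> int:
    """Weights in a stack of square fully connected layers (biases ignored)."""
    return layers * width * width


def bytes_available_per_param(gpu_bytes: float, params: int) -> float:
    """How many bytes of GPU memory each parameter could use."""
    return gpu_bytes / params


params = param_count(LAYERS, WIDTH)
per_param = bytes_available_per_param(GPU_BYTES, params)
used_fraction = BYTES_PER_PARAM / per_param

print(f"parameters: {params:,}")                           # 134,217,728
print(f"bytes available per parameter: {per_param:.0f}")   # 82
print(f"fraction of RAM used: {used_fraction:.1%}")        # 9.8%
```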

What have I missed?

1 comment

"2048 neurons per layer" isn't really an accurate description; what he means is 2048-dimensional embeddings at each layer. The multihead attention layers in a transformer aren't just a 2048*2048 feed-forward matrix; they actually have many more parameters. That's why there are 600B in total.
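A rough sketch of why a transformer layer holds more than 2048*2048 weights. The shapes here are assumptions, not from the thread: a standard dense block with d_model = 2048, four d x d attention projections (Q, K, V, output), a feed-forward sublayer with hidden size 4 * d, and biases and layer-norm parameters ignored.

```python
# Rough per-layer parameter count for a standard dense transformer block.
# Assumed (not from the thread): d_model = 2048, Q/K/V/output projections of
# size d x d, feed-forward hidden size 4 * d, biases and layer norms ignored.

D_MODEL = 2048
LAYERS = 32
FFN_MULT = 4


def attention_params(d: int) -> int:
    # Q, K, V and output projections. Splitting into heads (head_dim = d / heads)
    # reshapes these matrices but does not change the total count.
    return 4 * d * d


def ffn_params(d: int, mult: int = FFN_MULT) -> int:
    # Two projections: d -> mult*d and mult*d -> d.
    return 2 * d * (mult * d)


per_layer = attention_params(D_MODEL) + ffn_params(D_MODEL)
naive_per_layer = D_MODEL * D_MODEL  # the "2048 x 2048 per layer" reading

print(per_layer)                     # 50331648
print(per_layer // naive_per_layer)  # 12
print(per_layer * LAYERS)            # 1610612736
```

Under these assumptions each layer holds about 12 times as many weights as a single 2048x2048 matrix, which is the point of the reply. A dense stack of this shape still totals only about 1.6B, so how the rest of the 600B is distributed depends on architecture details the thread doesn't give.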