by kouteiheika
64 days ago
> You can only give it a try, but don't get your hopes up about a large context. You may or may not know this, but when training off-the-shelf LLMs (i.e. ones with a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss (which gets worse the more tokens you stuff into your batch), so always use a fused cross-entropy kernel. For example, for a Gemma 2 model with 2B parameters at a batch size of 8k, this consumes 24GB of VRAM by default (!); you can fuse your cross-entropy loss with @torch.compile, which can cut this memory usage down to something like a few gigabytes, but with a dedicated kernel it becomes a few megabytes.
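For a rough sense of where the memory goes, here is a minimal pure-Python sketch. It assumes a vocabulary of 256,000 tokens, fp32 logits, and reads "batch size of 8k" as 8,192 tokens; none of these numbers come from the quote. Under those assumptions a single materialized logits tensor is about 8.4 GB, so a few such copies (logits, softmax, gradient) would be in the same ballpark as the quoted 24GB. This is an estimate, not a measurement. The sketch then contrasts a naive loss that builds the whole logits matrix with a chunked one that keeps only `vocab_chunk` logits per token alive and uses an online logsumexp. That is one way to avoid the full tensor; it is not necessarily how torch.compile or any particular dedicated kernel does it (those also handle the backward pass), and all function names here are hypothetical.

```python
import math


def logits_bytes(num_tokens: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Size of one materialized [num_tokens, vocab_size] logits tensor (fp32 by default)."""
    return num_tokens * vocab_size * bytes_per_elem


def naive_ce(hidden, weight, targets):
    """Hypothetical baseline: materializes every logit, then averages -log softmax[target]."""
    logits = [[sum(h * w for h, w in zip(row, wrow)) for wrow in weight] for row in hidden]
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[t]
    return total / len(targets)


def chunked_ce(hidden, weight, targets, vocab_chunk=2):
    """Hypothetical fused-style loss: computes the projection chunk by chunk over the
    vocabulary and keeps a running max/sum (online logsumexp), so at most
    `vocab_chunk` logits per token exist at any time."""
    total = 0.0
    for row, t in zip(hidden, targets):
        m = -math.inf  # running max of logits seen so far
        s = 0.0        # running sum of exp(logit - m)
        target_logit = None
        for start in range(0, len(weight), vocab_chunk):
            chunk = [sum(h * w for h, w in zip(row, wrow))
                     for wrow in weight[start:start + vocab_chunk]]
            if start <= t < start + vocab_chunk:
                target_logit = chunk[t - start]
            new_m = max(m, max(chunk))
            # Rescale the old sum to the new max before adding this chunk.
            s = s * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in chunk)
            m = new_m
        total += m + math.log(s) - target_logit
    return total / len(targets)


if __name__ == "__main__":
    print(logits_bytes(8192, 256000) / 1e9, "GB per fp32 logits copy (assumed sizes)")
```

Both functions return the same loss; the difference is only in how much intermediate state is held at once, which is the whole point of fusing.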
https://pytorch.org/blog/peak-performance-minimized-memory/
Is this the same thing you discuss above?