by DavidSJ
953 days ago
Question for rwitten or anyone else involved in this project: I see a per-device batch size of 6 for the 16B model. With 256x199 = 50944 TPUs and a sequence length of 2048, this works out to 104M tokens per batch. This is much larger than typical for training runs of dense LMs of this size, which are usually closer to ~4M tokens per batch. Was your critical batch size really this large? In other words, did you really see a benefit as compared to a much smaller batch size (and probably many fewer TPUs)? Did you use some special learning rate schedule or optimizer to achieve this?
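
For reference, here is a quick sketch of the tokens-per-batch arithmetic above. It assumes the per-device batch size counts sequences per chip and that every chip holds a distinct shard of the global batch (pure data parallelism). The function name is made up for illustration and is not taken from the project's code.

```python
def tokens_per_batch(num_devices: int, per_device_batch: int, seq_len: int) -> int:
    """Global batch size in tokens, assuming each device holds its own sequences."""
    return num_devices * per_device_batch * seq_len


num_devices = 256 * 199  # 50944 chips, as quoted above
seq_len = 2048           # tokens per sequence

# Compare one sequence per device against six sequences per device.
for per_device_batch in (1, 6):
    total = tokens_per_batch(num_devices, per_device_batch, seq_len)
    print(f"per-device batch {per_device_batch}: {total:,} tokens (~{total / 1e6:.0f}M)")
```

Under these assumptions, 104M tokens corresponds to one sequence per chip, while six sequences per chip would come to roughly 626M tokens, so the exact figure depends on how the per-device batch size is counted. Either way it is far above ~4M.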