| Token count. But tokens per batch is a bad metric. I learned this the hard way through experience. It turns out that what matters is step count. Increasing batch size makes the model train faster, but increasing seq length from 1024 to 2048 doesn’t make it train twice as fast. So saying 104M tokens rather than batch size 50k x 6 is misleading to yourself. (One of the most surprising aspects of learning ML was how easy it is for me to trick myself in various silly ways like that.) The mental model to make this easy to remember: progress happens in discrete quantities called steps. Training on 104M tokens per step means it’s embedding 104M tokens of knowledge every step. This isn’t the same thing as requiring an special case optimizer due to large batch sizes — sequence length adds "knowledge bandwidth", but otherwise doesn’t mess with the training dynamics. As far as large batch optimizers, there’s LARS, which google used for their MLPerf results. I imagine they stuck with that. It creates a per-layer confidence metric, so that when the massive batch size makes a massive change, it dampens the change to smooth out the effect across the network. And since it’s a multiply, the shape of the gradient (by "shape" I mean in 3D space, where the Z axis is the intensity of the gradient) remains the same, so it doesn’t harm any knowledge transfer. It’s purely a stabilization aid. Kind of weird I remember that after three years. |