by felarof
639 days ago
Oops, my calculation was wrong. Let me add an edit to the blog; thanks for pointing it out! My train step was taking 30 s, with a batch size of 16 and a sequence length of 64, so the training throughput was (16 * 64 / 30) ≈ 34 tokens per second (for fine-tuning in JAX eager mode). (I haven't compared against 8x H100.)
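A quick Python check of the throughput arithmetic in this reply. The function name `tokens_per_second` is illustrative only; the constants are the batch size, sequence length, and step time quoted above.

```python
# Figures quoted in the reply above.
BATCH_SIZE = 16      # sequences per train step
SEQ_LEN = 64         # tokens per sequence
STEP_SECONDS = 30.0  # measured wall-clock time per train step


def tokens_per_second(batch_size: int, seq_len: int, step_seconds: float) -> float:
    """Training throughput: tokens processed in one step divided by the step time."""
    return batch_size * seq_len / step_seconds


if __name__ == "__main__":
    tps = tokens_per_second(BATCH_SIZE, SEQ_LEN, STEP_SECONDS)
    print(f"{tps:.2f} tokens/s")  # 1024 / 30 -> 34.13 tokens/s
```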
405e9 parameters
2 FLOPs per parameter per token for each matrix multiply
3 matrix multiplies per parameter: the forward pass, the backward pass for parameter gradients, and the backward pass for activation gradients
batch size 16
seq length 64
1.3 petaflops per second per GPU in bfloat16
8 GPUs
30 seconds
So the MFU is about 0.8% = (405e9 * 2 * 3 * 16 * 64 / 30) / (1.3e15 * 8)
Note that I'm ignoring the attention FLOPs in this simplified calculation; at this sequence length they would be a second-order effect.
Also note that I'm assuming full-weight training, not LoRA. With LoRA, the MFU would be lower.
These MI300X results are promising functionally (it's tough to get any model this big running), but they have a long way to go on performance. It's also a single-node result. The biggest issues I've seen on MI300X are related to scaling to multiple nodes.
EDIT: The blog seems to indicate it is using LoRA, so we should remove the backward param pass from the equation above. The backward param pass only applies to the adapter weights, which are more than 10x smaller than the base model, so we set it to 0 in the approximation. That gives
0.53% = (405e9 * 2 * 2 * 16 * 64 / 30) / (1.3e15 * 8)
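The same back-of-envelope MFU estimate as a small Python sketch. The function name `mfu` and its parameter names are illustrative, not from the blog; the constants are the figures listed above, and, as in the comment, attention FLOPs are ignored. `passes=3` is the full-weight case and `passes=2` is the LoRA approximation that drops the backward param pass.

```python
def mfu(params, passes, batch_size, seq_len, step_seconds,
        peak_flops_per_gpu, num_gpus):
    """Model FLOPs utilization, ignoring attention FLOPs.

    Each parameter costs 2 FLOPs per token per matmul pass.
    passes=3: forward, backward param, backward activation (full-weight training).
    passes=2: backward param dropped (LoRA, adapter weights approximated as 0).
    """
    achieved = params * 2 * passes * batch_size * seq_len / step_seconds
    peak = peak_flops_per_gpu * num_gpus
    return achieved / peak


# Figures from the comment: 405B params, 8 GPUs at 1.3 PFLOP/s bf16, 30 s per step.
common = dict(params=405e9, batch_size=16, seq_len=64, step_seconds=30.0,
              peak_flops_per_gpu=1.3e15, num_gpus=8)

full = mfu(passes=3, **common)
lora = mfu(passes=2, **common)
print(f"full-weight MFU: {full:.2%}")  # 0.80%
print(f"LoRA MFU: {lora:.2%}")         # 0.53%
```

Keeping `passes` as a parameter makes it explicit that the LoRA number differs only by counting two matmul passes instead of three, which is why it comes out at exactly 2/3 of the full-weight figure.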