Oh, so I think float8's L2-norm error relative to float32 is around 1e-4, whilst float16's is around 1e-6. Sadly, attention is quite sensitive to this. There are some hybrid methods where, just before the fp8 attention kernel, Q and K coming out of the RoPE kernel are upcast to float16, while V is left in float8. Everything else is done in fp8 on the fly, and the output is fp8. This brings the error down to around 1e-6.
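A rough way to see why that hybrid layout helps is to simulate it. The sketch below is a toy under several assumptions, not the actual kernel: it rounds Q, K and V to E4M3 or float16 in pure Python, does all accumulation in float64 (standing in for the fp32 reference), leaves the attention output unquantized, and uses relative L2 error as the metric. The shapes, random data and helper names (`to_fp8`, `to_fp16`, `attention`) are hypothetical. The absolute numbers may not match the 1e-4 / 1e-6 figures above, since those depend on the data and on how the error is measured; the ordering is the point.

```python
import bisect
import math
import random
import struct

def e4m3_grid():
    """All finite non-negative E4M3 values (bias 7, no infinities)."""
    vals = []
    for exp in range(16):
        for man in range(8):
            if exp == 15 and man == 7:
                continue  # S.1111.111 is the only NaN pattern
            if exp == 0:
                vals.append(man / 8 * 2.0 ** -6)  # subnormals
            else:
                vals.append((1 + man / 8) * 2.0 ** (exp - 7))
    return sorted(vals)

GRID = e4m3_grid()

def to_fp8(x):
    """Round to the nearest E4M3 value, saturating at +/-448 (simulated in float64)."""
    mag = min(abs(x), GRID[-1])
    i = bisect.bisect_left(GRID, mag)
    nearest = min(GRID[max(i - 1, 0):i + 1], key=lambda g: abs(g - mag))
    return math.copysign(nearest, x)

def to_fp16(x):
    """Round to the nearest IEEE float16 value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def cast(m, f):
    """Apply a rounding function elementwise to a list-of-lists matrix."""
    return [[f(x) for x in row] for row in m]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d)) V, accumulated in float64."""
    d = len(Q[0])
    out = []
    for q in Q:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(logits)
        w = [math.exp(l - m) for l in logits]
        s = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / s for c in range(len(V[0]))])
    return out

def rel_l2(ref, approx):
    """Relative L2 error between two matrices."""
    num = sum((a - b) ** 2 for ra, pa in zip(ref, approx) for a, b in zip(ra, pa))
    den = sum(a * a for ra in ref for a in ra)
    return math.sqrt(num / den)

def rand_matrix(n, d):
    return [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]

random.seed(0)
n, d = 16, 16  # hypothetical sequence length and head dimension
Q, K, V = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(n, d)

ref = attention(Q, K, V)  # float64 stands in for the fp32 reference
err_fp16 = rel_l2(ref, attention(cast(Q, to_fp16), cast(K, to_fp16), cast(V, to_fp16)))
err_fp8 = rel_l2(ref, attention(cast(Q, to_fp8), cast(K, to_fp8), cast(V, to_fp8)))
# Hybrid: Q and K kept at float16 precision, only V rounded to E4M3.
err_hybrid = rel_l2(ref, attention(cast(Q, to_fp16), cast(K, to_fp16), cast(V, to_fp8)))
print(f"all fp16: {err_fp16:.2e}  all fp8: {err_fp8:.2e}  hybrid: {err_hybrid:.2e}")
```

The all-fp8 run should show the largest error and the all-fp16 run the smallest, with the hybrid in between: rounding Q and K perturbs every softmax weight, whereas V's rounding error partly averages out in the weighted sum.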
Yes, but it's a bit more complicated. There are two FP8 formats: E5M2 (5 exponent bits, 2 mantissa bits) and E4M3 (4 exponent bits, 3 mantissa bits).
E5M2 behaves like an IEEE 754 format, with infinities and NaNs. E4M3 has a smaller exponent, so to compensate, "E4M3’s dynamic range is extended by not representing infinities and having only one mantissa bit-pattern for NaNs".
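To make the two layouts concrete, here is a small pure-Python decoder that enumerates all 256 bit patterns of each format. It's a sketch based on the bit layouts (E4M3 with bias 7, E5M2 with bias 15) and the special-value rules quoted above; the names `decode_fp8` and `stats` are my own, not from any library. The final loop checks the "like IEEE 754" point: every E5M2 pattern decodes to the same value as a float16 whose top byte is that pattern.

```python
import math
import struct

def decode_fp8(byte, exp_bits, man_bits, bias, ieee_specials):
    """Decode one 8-bit pattern (sign | exponent | mantissa) to a Python float.

    ieee_specials=True:  E5M2 style, an all-ones exponent means inf (mantissa 0) or NaN.
    ieee_specials=False: E4M3 style, no infinities; only S.1111.111 is NaN.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    all_ones = (1 << exp_bits) - 1
    if ieee_specials and exp == all_ones:
        return sign * math.inf if man == 0 else math.nan
    if not ieee_specials and exp == all_ones and man == (1 << man_bits) - 1:
        return math.nan
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def e4m3(b):
    return decode_fp8(b, exp_bits=4, man_bits=3, bias=7, ieee_specials=False)

def e5m2(b):
    return decode_fp8(b, exp_bits=5, man_bits=2, bias=15, ieee_specials=True)

def stats(decode):
    """Range and precision summary over all 256 bit patterns."""
    vals = [decode(b) for b in range(256)]
    finite = [v for v in vals if math.isfinite(v)]
    return {
        "max": max(finite),
        "min_subnormal": min(v for v in finite if v > 0),
        "epsilon": min(v for v in finite if v > 1.0) - 1.0,  # gap just above 1.0
        "nan_patterns": sum(math.isnan(v) for v in vals),
        "inf_patterns": sum(math.isinf(v) for v in vals),
    }

for name, decode in [("E4M3", e4m3), ("E5M2", e5m2)]:
    print(name, stats(decode))

# E5M2 matches the top byte of an IEEE float16 (same bias, same special values).
for b in range(256):
    v = e5m2(b)
    h = struct.unpack("<e", bytes([0, b]))[0]
    assert (math.isnan(v) and math.isnan(h)) or v == h
```

E4M3 tops out at 448 with an epsilon of 0.125, while E5M2 reaches 57344 with an epsilon of 0.25. Without the extended-range trick, the E4M3 pattern `0x78` would be infinity; here it decodes to 256.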
Some people have reported that E4M3 is better for the forward pass (smaller range, more precision) and E5M2 is better for the backward pass (bigger range, less precision). Most implementations also use some form of scaling or other numerical tricks to reduce the error.
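As a rough illustration of the scaling idea, the sketch below applies per-tensor amax scaling: multiply the tensor so its largest magnitude maps onto E4M3's maximum (448), round to E4M3, then divide the scale back out. This is a generic technique assumed for illustration, not PyTorch's or any specific library's recipe; the Gaussian "gradients" with standard deviation 1e-4 and the helper names are hypothetical.

```python
import bisect
import math
import random

def e4m3_grid():
    """All finite non-negative E4M3 values (bias 7, max 448)."""
    vals = []
    for exp in range(16):
        for man in range(8):
            if exp == 15 and man == 7:
                continue  # the single NaN pattern
            base = man / 8 if exp == 0 else 1 + man / 8  # subnormals lack the implicit 1
            vals.append(base * 2.0 ** (max(exp, 1) - 7))
    return sorted(vals)

GRID = e4m3_grid()
E4M3_MAX = GRID[-1]  # 448.0

def to_e4m3(x):
    """Round to the nearest E4M3 value, saturating at +/-448 (ties go to the smaller magnitude here)."""
    mag = min(abs(x), E4M3_MAX)
    i = bisect.bisect_left(GRID, mag)
    nearest = min(GRID[max(i - 1, 0):i + 1], key=lambda g: abs(g - mag))
    return math.copysign(nearest, x)

def fake_quant(xs, scaled):
    """Quantize to E4M3 and back, optionally with a per-tensor amax scale."""
    amax = max(abs(x) for x in xs)
    scale = E4M3_MAX / amax if (scaled and amax > 0) else 1.0
    return [to_e4m3(x * scale) / scale for x in xs]

def rel_l2(ref, approx):
    num = sum((a - b) ** 2 for a, b in zip(ref, approx))
    den = sum(a * a for a in ref)
    return math.sqrt(num / den)

random.seed(0)
grads = [random.gauss(0.0, 1e-4) for _ in range(1000)]  # hypothetical gradient-sized values
err_unscaled = rel_l2(grads, fake_quant(grads, scaled=False))
err_scaled = rel_l2(grads, fake_quant(grads, scaled=True))
print(f"unscaled: {err_unscaled:.3f}  amax-scaled: {err_scaled:.3f}")
```

Unscaled, values this small sit below half of E4M3's smallest subnormal (2^-9), so they all round to zero and the relative error is essentially 1. With the amax scale, each value lands in E4M3's normal range, so the per-element relative error is bounded by about 2^-4.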
Fair points! Yeah, PyTorch's experimental fp8 support does scale the gradients. Interesting point on a smaller range with more precision for the forward pass, and a larger range for the gradients! I did not know that, so I learnt something today!! Thanks! I'll definitely read that paper!