Hacker News
by AlanSE 1058 days ago
Yeah, good to bring it back to the original point. Reading the article felt exciting, but in hindsight I am now missing a key detail.

The equations all seem to be matrix operations with a fixed number of rows / columns (you can take me as a real layman here). Unless you change that, I don't understand _how_ you can reduce memory needs. Granted, I'm probably putting my foot in my mouth not understanding transformers.


This is more ELI5 than the other comments. Consider the softmax network:

During quantization we find that values in the network vary from 0->5000, but 95% of values are <100. Quantizing this to 8bits would mean that our values would be in increments of about 20. Remembering that 95% of our values are below 100, we would only have about 5 discrete values for 95% of our values - so we would be losing a lot of "resolution" (entropy/information). For example (assuming rounding is used), an original value of 19 would be quantized to 20 and 30 would be quantized to 40. The original values differ by 11, but the quantized values differ by 20!
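To put numbers on this, here's a minimal Python sketch of even-step 8-bit quantization over 0 to 5000. It assumes round-to-nearest and a step of 5000/255, so that the 256 codes span exactly 0 to 5000. Both choices are my assumptions; the comment doesn't pin them down.

```python
LEVELS = 256                      # 8 bits -> 256 codes
LO, HI = 0.0, 5000.0              # assumed value range from the example
STEP = (HI - LO) / (LEVELS - 1)   # ~19.6: codes 0..255 span exactly 0..5000

def quantize(x: float) -> int:
    """Round to the nearest code, clamped to the valid 0..255 range."""
    code = round((x - LO) / STEP)
    return max(0, min(LEVELS - 1, code))

def dequantize(code: int) -> float:
    """Return the value that a code stands for."""
    return LO + code * STEP

# How many of the 256 codes are spent on the region holding 95% of the values
below_100 = sum(1 for c in range(LEVELS) if dequantize(c) < 100)

for x in (19, 30):
    print(f"{x} -> code {quantize(x)} -> {dequantize(quantize(x)):.2f}")
print(f"codes available below 100: {below_100}")
```

19 and 30 end up one full step apart, even though they were only 11 apart to begin with.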

This is where exotic encodings come into play. We might try a logarithmic scheme, for example. This would put more representable values at the low end of the range, but we would probably still waste bits, and encoding and decoding would take more compute cycles.
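A rough sketch of one such logarithmic scheme. It assumes codes spaced evenly in log(1 + x) over 0 to 5000, which is a hypothetical choice; there are many other ways to build a log encoding.

```python
import math

LEVELS = 256
HI = 5000.0
# Hypothetical log-spaced scheme: codes are evenly spaced in log1p(x) = ln(1 + x),
# so the gap between neighbouring levels grows with the value.
LOG_STEP = math.log1p(HI) / (LEVELS - 1)

def quantize_log(x: float) -> int:
    """Nearest code in log space, clamped to 0..255 (x assumed >= 0)."""
    code = round(math.log1p(x) / LOG_STEP)
    return max(0, min(LEVELS - 1, code))

def dequantize_log(code: int) -> float:
    """Undo the log mapping: expm1(y) = exp(y) - 1."""
    return math.expm1(code * LOG_STEP)

below_100 = sum(1 for c in range(LEVELS) if dequantize_log(c) < 100)

for x in (19, 30, 4000):
    print(f"{x} -> {dequantize_log(quantize_log(x)):.2f}")
print(f"codes available below 100: {below_100}")
```

Small values now come back within a fraction of a unit, while values in the thousands get much coarser. The extra cost shows up in `dequantize_log`: every decode needs an `exp`, where the even-step version needs only a multiply and an add.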

Now switch to the softmax1 network:

The range of values is less important than the distribution: instead of 95% of the values falling in a small range, the values would be more evenly spread out. Assuming the range is now 0 to 105 (so the 5% of values that were outliers in the softmax network are still >100), the step is about 0.41 and we would have 243 levels to represent everything under 100. With round-to-nearest, the same example with 19 and 30 would quantize to about 18.94 and 30.06 respectively, a difference of about 11.12, which is very close to the unquantized difference of 11. We have retained more information in the quantized version of the network.
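Here's a small Python sketch that runs the 19 and 30 example through an even-step 8-bit quantizer for both ranges. It assumes round-to-nearest and a step of range/255, so the exact quantized values depend on those choices, but the gap between them comes out near 20 for the 0 to 5000 range and near 11 for the 0 to 105 range.

```python
LEVELS = 256  # 8 bits

def roundtrip(x: float, hi: float) -> float:
    """Evenly quantize x over [0, hi] with round-to-nearest, then decode it."""
    step = hi / (LEVELS - 1)
    code = max(0, min(LEVELS - 1, round(x / step)))
    return code * step

def codes_below(limit: float, hi: float) -> int:
    """How many of the 256 codes decode to a value below `limit`."""
    step = hi / (LEVELS - 1)
    return sum(1 for code in range(LEVELS) if code * step < limit)

for hi in (5000.0, 105.0):
    a, b = roundtrip(19, hi), roundtrip(30, hi)
    print(f"range 0..{hi:g}: 19 -> {a:.2f}, 30 -> {b:.2f}, "
          f"difference {b - a:.2f}, codes below 100: {codes_below(100, hi)}")
```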

Information is lost either way, but what's important is how much information is lost.

The large values appear because the heads attempt to "scream really loud" when they are certain that they are right. This is an emergent behavior caused by softmax, which ironically sucks at paying attention to just a few of the heads: it boosts the volume of the heads that are trying to abstain and mutes the volume of the heads that are trying to vote.
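The softmax1 from the article adds 1 to the denominator: softmax1(x)_i = exp(x_i) / (1 + Σ_j exp(x_j)). Below is a small Python sketch comparing it with plain softmax for a head that wants to abstain (all scores very negative) and one that wants to vote (one clear winner). The score vectors are made up for illustration.

```python
import math

def softmax(xs):
    """Standard softmax; subtracting the max keeps exp() from overflowing."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmax1(xs):
    """exp(x_i) / (1 + sum_j exp(x_j)), computed stably.

    Multiplying top and bottom by exp(-m), with m = max(max(xs), 0), gives
    exp(x_i - m) / (exp(-m) + sum_j exp(x_j - m)); every exponent is <= 0.
    """
    m = max(max(xs), 0.0)
    exps = [math.exp(x - m) for x in xs]
    total = math.exp(-m) + sum(exps)
    return [e / total for e in exps]

abstain = [-10.0, -10.0, -10.0]  # hypothetical head with nothing to say
vote = [5.0, 0.0, 0.0]           # hypothetical head with one clear answer

for name, scores in (("abstain", abstain), ("vote", vote)):
    print(name, "softmax total:", round(sum(softmax(scores)), 6),
          "softmax1 total:", round(sum(softmax1(scores)), 6))
```

With plain softmax the abstaining head still hands out a full unit of attention (1/3 each); with softmax1 its total is on the order of 1e-4, so it can stay quiet without pushing anything to extreme values.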

> During quantization we find that values in the network vary from 0->5000, but 95% of values are <100. Quantizing this to 8bits would mean that our values would be in increments of about 20.

Instead of using an 8-bit integer with evenly spaced quantization steps, wouldn't they still use an 8-bit float?

Possibly, it depends on the distribution of the values. It would also make my examples far less straightforward :)

Either way you would still only have 256 discrete values.
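To see how an 8-bit float spends its 256 codes, here's a Python sketch that enumerates a simplified E4M3-like format (4 exponent bits, 3 mantissa bits, IEEE-style bias). Loud assumption: it treats every exponent code as finite, which real FP8 formats don't do exactly, and the per-tensor scale that maps the largest value to 5000 is hypothetical.

```python
def minifloat_values(exp_bits: int, man_bits: int) -> list[float]:
    """All non-negative values of a simplified IEEE-like minifloat.

    Assumption: no inf/NaN codes; every exponent code is a finite value.
    """
    bias = 2 ** (exp_bits - 1) - 1
    values = []
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            frac = m / 2 ** man_bits
            if e == 0:  # subnormal: no implicit leading 1
                values.append(2.0 ** (1 - bias) * frac)
            else:       # normal: implicit leading 1
                values.append(2.0 ** (e - bias) * (1 + frac))
    return values

vals = minifloat_values(4, 3)   # 128 non-negative codes (the sign bit gives the rest)
scale = 5000 / max(vals)        # hypothetical scale so the largest code means 5000
below_100 = sum(1 for v in vals if v * scale < 100)
print(f"{below_100} of {len(vals)} non-negative codes land below 100")
```

Under these assumptions 82 of the 128 non-negative codes land below 100, versus only a handful with even steps, because float spacing follows the magnitude. It is still only 256 codes in total, though.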

No one quantizes blindly without accounting for the data. If 95% of your values are in 0-100, you’ll probably spend most of your 256 levels on 0-100 and only the remainder on 101-5000. You don’t have to space the levels uniformly, and you shouldn’t when your data is that concentrated.
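A sketch of that kind of non-uniform 8-bit codebook in Python. The split (224 levels on 0-100, 32 on 100-5000) is a hypothetical choice for illustration; real schemes often derive the levels from the data instead.

```python
import bisect

def build_codebook(n_low: int, n_high: int, split: float, hi: float) -> list[float]:
    """Hypothetical codebook: n_low even levels on [0, split], n_high on (split, hi]."""
    low = [split * i / (n_low - 1) for i in range(n_low)]
    step_high = (hi - split) / n_high
    high = [split + step_high * (i + 1) for i in range(n_high)]
    return low + high

def quantize(x: float, codebook: list[float]) -> int:
    """Index of the nearest entry in a sorted codebook."""
    i = bisect.bisect_left(codebook, x)
    if i == 0:
        return 0
    if i == len(codebook):
        return len(codebook) - 1
    # x lies between codebook[i - 1] and codebook[i]; pick the closer one
    return i if codebook[i] - x < x - codebook[i - 1] else i - 1

codebook = build_codebook(n_low=224, n_high=32, split=100.0, hi=5000.0)
assert len(codebook) == 256  # still fits in 8 bits

for x in (19, 30, 2500):
    print(x, "->", round(codebook[quantize(x, codebook)], 2))
```

Values in the dense region now come back within about 0.22 (half of 100/223), while values above 100 get steps of about 153.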
Third paragraph.
If I'm following correctly, does this mean that with this change, combined with quantization, we could see models that are 5% of the size (on the file system and in memory) but almost identical in output?
The values I selected were arbitrary. The size reduction would be 32 bits / 8 bits, so the model would be 4 times smaller.
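The arithmetic in Python, for a hypothetical 7-billion-parameter model (weights only):

```python
def model_size_gb(n_params: int, bits_per_param: int) -> float:
    """Storage for the weights alone (ignores metadata, scales and activations)."""
    return n_params * bits_per_param / 8 / 1e9

n_params = 7_000_000_000  # hypothetical model size
fp32 = model_size_gb(n_params, 32)
int8 = model_size_gb(n_params, 8)
print(f"fp32: {fp32:.0f} GB, int8: {int8:.0f} GB, ratio {fp32 / int8:.0f}x")
```

In practice quantized checkpoints usually also store scale factors alongside the 8-bit values, so the real saving is a little under 4x.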
It has to do with the precision of the values stored in those rows and columns. If they could be coerced into a narrower range (without losing information), then we could effectively store each of them in 8 bits or so. The +1 keeps the denominator at or above 1, so when every score is very negative the outputs can all shrink toward 0 instead of being forced to sum to 1. Without the blowups that this forcing causes, we can use fewer bits, in theory.
That is only true if using the new softmax changes the dynamic range of the values. We are using floating point, not fixed point. So if our values used to go from 1 to 5000 and now go from 0.0002 to 1, we still have the same dynamic range and still need the same resolution.
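A quick way to check the floating-point half of this argument, assuming float16 as the format: with NumPy, the gap between neighbouring representable values, relative to the value itself, stays between roughly 6e-4 and 1e-3 whether you're near 1, near 5000 or near 0.0002. What changes the bit budget is the ratio between the largest and smallest values, which is 5000 in both of the comment's examples.

```python
import numpy as np

def relative_step(x: float) -> float:
    """Gap to the next float16 above x, relative to x itself."""
    h = np.float16(x)
    return float(np.spacing(h)) / float(h)

old_range = (1.0, 5000.0)   # "before" example from the comment
new_range = (0.0002, 1.0)   # "after" example from the comment

for lo, hi in (old_range, new_range):
    print(f"{lo:g}..{hi:g}: dynamic range {hi / lo:.0f}x, "
          f"relative step {relative_step(lo):.1e} at the low end, "
          f"{relative_step(hi):.1e} at the high end")
```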
The quantized versions are not floats but ints.
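For integer formats the range does matter, which is where outliers hurt. A minimal Python sketch of symmetric "absmax" int8 quantization (scale = max|x| / 127), using made-up activation values, shows how a single 5000 in the tensor wipes out the small values:

```python
def absmax_int8(xs):
    """Symmetric per-tensor int8 quantization: scale = max|x| / 127."""
    scale = max(abs(x) for x in xs) / 127
    codes = [max(-127, min(127, round(x / scale))) for x in xs]
    return codes, scale

def worst_error(xs):
    """Largest absolute error after quantizing and dequantizing xs."""
    codes, scale = absmax_int8(xs)
    return max(abs(x - c * scale) for x, c in zip(xs, codes))

typical = [0.5, -1.2, 3.3, 7.9, -0.04, 12.0]  # hypothetical "ordinary" activations
with_outlier = typical + [5000.0]             # same values plus one huge outlier

print(f"worst error without outlier: {worst_error(typical):.3f}")
print(f"worst error with outlier:    {worst_error(with_outlier):.3f}")
```

Without the outlier the worst rounding error is under 0.05; with it, the scale jumps to about 39 and every small value rounds to 0.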
The activations (outputs) of one layer must be encoded the same way as the weights of that layer and the weights of the next layer, or the computation fails. The exception is if you manage to write clever kernels that do math at different levels of precision simultaneously, but even then you're introducing more lossiness than just using a single binary representation for all those values.
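One common pattern for such kernels, sketched in NumPy (not necessarily what any particular library does): quantize both weights and activations to int8 with a per-tensor scale, multiply and accumulate in int32 so the products can't overflow, then multiply by the two scales to get back to floating point. The layer shape and random data are hypothetical.

```python
import numpy as np

def quantize(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Symmetric per-tensor int8 quantization."""
    scale = float(np.abs(x).max()) / 127
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 8)).astype(np.float32)  # hypothetical layer weights
x = rng.standard_normal(8).astype(np.float32)       # hypothetical input activations

qW, sW = quantize(W)
qx, sx = quantize(x)

# Integer matmul with a wider int32 accumulator, then rescale to real units
acc = qW.astype(np.int32) @ qx.astype(np.int32)
y_int8 = acc.astype(np.float32) * (sW * sx)

y_ref = W @ x  # full-precision result for comparison
print("max abs error:", float(np.max(np.abs(y_int8 - y_ref))))
```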

Example: multiplying a bunch of float16s together gives you a float16. That is passed on to the next layer of float16s. Why should forcing the output of the first step to be float8 confer any advantage here? The only way I can see this argument working is if you make all the layers float8 too, and the reason you can do that is that the output of the first step can be faithfully represented as float8 because it doesn't ever blow up. If that's what the author is saying, it wasn't very clear.

You can reduce the number of bits per float (scalar).
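That's the answer to the original question: the matrix keeps the same number of rows and columns, but each entry takes fewer bits. A quick NumPy illustration with a hypothetical 4096 x 4096 weight matrix:

```python
import numpy as np

W32 = np.zeros((4096, 4096), dtype=np.float32)  # hypothetical layer, 4 bytes per entry
W16 = W32.astype(np.float16)                     # same matrix, 2 bytes per entry

print("same shape:", W32.shape == W16.shape)
print("float32:", W32.nbytes // 2**20, "MiB")
print("float16:", W16.nbytes // 2**20, "MiB")
```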