Hacker News
by danielhanchen 829 days ago
On how long it takes: finetuning time depends on your dataset size (more data = slower), sequence length (attention is O(N^2)), data movement, etc., and most importantly on how many steps you want to take. For QLoRA, some runs only need a few hundred steps, which can finish in anywhere from a few minutes to an hour. Too many steps can overfit. So being able to fit it on consumer GPUs can be very cost-effective.
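A rough way to see the sequence-length and step effects: in a toy model that only counts the attention score matrix (Q K^T, about 2 * N^2 * d FLOPs per head per layer), doubling the sequence length quadruples the work and doubling the steps doubles it. This is my own back-of-envelope sketch; it ignores the MLP layers and data movement, and the head size of 128 and the step counts are made-up illustrative values.

```python
def attention_score_flops(seq_len: int, head_dim: int) -> int:
    """FLOPs for Q @ K^T in one head of one layer (one multiply + one add per term)."""
    # Q is (seq_len x head_dim) and K^T is (head_dim x seq_len):
    # seq_len^2 dot products, each of length head_dim.
    return 2 * seq_len * seq_len * head_dim


def relative_run_cost(steps: int, seq_len: int, base_steps: int, base_seq_len: int) -> float:
    """Ratio of attention-score work between two runs, ignoring everything else."""
    head_dim = 128  # hypothetical head size; it cancels out in the ratio
    run = steps * attention_score_flops(seq_len, head_dim)
    base = base_steps * attention_score_flops(base_seq_len, head_dim)
    return run / base


if __name__ == "__main__":
    # Doubling the sequence length at a fixed step count quadruples the attention-score term.
    print(relative_run_cost(steps=300, seq_len=4096, base_steps=300, base_seq_len=2048))
    # Doubling the step count at a fixed sequence length doubles it.
    print(relative_run_cost(steps=600, seq_len=2048, base_steps=300, base_seq_len=2048))
```

Treat this as the scaling trend only, not a runtime estimate.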

On the 1.58-bit paper: from what I understand, it requires retraining from scratch. Hopefully the researchers will open-source their weights :)

On the technicals: weights are encoded as (-1, 0, 1), whilst QLoRA uses a 4-bit data type (NF4) that maps each weight to one of 16 values. The only change required would be the torch.matmul(X, W) step, which would become torch.bitlinear_matmul(X, W). Before, with QLoRA, one has to do torch.matmul(X, dequantize(W)). So one has to implement torch.bitlinear_matmul oneself (it's a hypothetical kernel, not an existing PyTorch function). The backward pass (the gradient with respect to X) is then torch.bitlinear_matmul(dY, W.T).
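Here is a minimal pure-Python sketch of what such a bitlinear_matmul would have to do, assuming W is stored as plain -1/0/1 integers. A real kernel would pack the weights and run on the GPU; the function names below are hypothetical, and torch.bitlinear_matmul itself doesn't exist. The point for the backward pass is that W.T is also ternary, so the same add/subtract kernel computes dX = dY @ W.T.

```python
from typing import List

Matrix = List[List[float]]
Ternary = List[List[int]]


def bitlinear_matmul(X: Matrix, W: Ternary) -> Matrix:
    """X @ W where every entry of W is -1, 0 or 1, using only adds and subtracts.

    Hypothetical stand-in for a torch.bitlinear_matmul kernel.
    """
    n_in, n_out = len(W), len(W[0])
    out = []
    for row in X:
        assert len(row) == n_in
        acc = [0.0] * n_out
        for k, x in enumerate(row):
            for j in range(n_out):
                w = W[k][j]
                if w == 1:
                    acc[j] += x
                elif w == -1:
                    acc[j] -= x
                # w == 0: skip, no work at all
        out.append(acc)
    return out


def transpose(M: Ternary) -> Ternary:
    return [list(col) for col in zip(*M)]


def bitlinear_backward_input(dY: Matrix, W: Ternary) -> Matrix:
    """Gradient w.r.t. X for Y = X @ W: dX = dY @ W.T (W.T is still ternary)."""
    return bitlinear_matmul(dY, transpose(W))


if __name__ == "__main__":
    X = [[10.0, 20.0, 30.0]]
    W = [[-1, 0, 1], [0, 1, -1], [1, 1, 0]]
    print(bitlinear_matmul(X, W))
    print(bitlinear_backward_input([[1.0, 0.0, 0.0]], W))
```

Compared with the QLoRA path, there is no dequantize(W) step producing a float matrix first; the ternary values are consumed directly.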

What's the magic in 1.58-bit vs. 4-bit that makes it (supposedly) so much more efficient?
From what I understand, using (-1, 0, 1) removes the need for multiplications on GPUs. I.e., assume you have a weight matrix and multiply some activations by it:

                   [-1, 0,  1]
    [10, 20, 30] x [ 0, 1, -1]
                   [ 1, 1,  0]
Instead of computing 10(-1) + 20(0) + 30(1) + 10(0) + ..., since we know beforehand that the weights are only -1, 0 or 1, we can simply flip the sign and add, or have the hardware do only additions: if the weight is -1, subtract; if 0, skip it; if 1, add.
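Worked through for the example above, treating [10, 20, 30] as a row vector x multiplied by the 3x3 ternary matrix W, every output is a signed sum of activations, with no multiplications and the zero weights skipped entirely:

```latex
\begin{aligned}
xW &= \begin{bmatrix} 10 & 20 & 30 \end{bmatrix}
      \begin{bmatrix} -1 & 0 & 1 \\ 0 & 1 & -1 \\ 1 & 1 & 0 \end{bmatrix} \\
   &= \begin{bmatrix} -10 + 30 & 20 + 30 & 10 - 20 \end{bmatrix}
    = \begin{bmatrix} 20 & 50 & -10 \end{bmatrix}
\end{aligned}
```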

Floating-point multiplication adds the exponents and multiplies the mantissas. Simplifying heavily, say a multiply costs about E + M^2 (exponent add plus mantissa multiply) and an addition costs about E + M:

Float16 has E=5, M=10, i.e. roughly 5 + 10^2 = 105 units of space.

Bfloat16 has E=8, M=7, so 8 + 7^2 = 57.

Float8 (1-4-3, i.e. E4M3) has E=4, M=3, so 4 + 3^2 = 13.

1.58-bit with 16-bit activations (E=5, M=10): addition only (align the exponents, then add), so roughly 5 + 10 = 15.

1.58-bit with 8-bit activations (E=4, M=3): addition only, so roughly 4 + 3 = 7.

Obviously I'm simplifying, but with only additions, 1.58-bit needs roughly 7 units of space whilst FP8 needs 13, so in theory about 2x as many units (13/7 ≈ 1.9) can be crammed into the same transistor budget, i.e. roughly 2x more FLOPs than FP8.
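The toy cost model above, written out so the numbers can be checked. This only encodes the comment's own simplification (roughly E + M^2 for a multiply, E + M for add-only); it is not a model of any real hardware.

```python
# Toy cost model (not real hardware numbers):
# multiply ~ E (exponent add) + M^2 (mantissa multiply), add-only ~ E + M.

FORMATS = {
    # name: (exponent bits E, mantissa bits M)
    "float16": (5, 10),
    "bfloat16": (8, 7),
    "float8_e4m3": (4, 3),
}


def multiply_cost(e: int, m: int) -> int:
    return e + m * m


def add_only_cost(e: int, m: int) -> int:
    return e + m


if __name__ == "__main__":
    for name, (e, m) in FORMATS.items():
        print(f"{name:12s} multiply ~{multiply_cost(e, m):4d}  add-only ~{add_only_cost(e, m):3d}")
    # FP8 multiply vs 1.58-bit weights with FP8 (E4M3) activations, add-only.
    ratio = multiply_cost(4, 3) / add_only_cost(4, 3)
    print(f"FP8 multiply / add-only ratio: {ratio:.2f}")  # 1.86, i.e. "roughly 2x"
```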

A really simple explanation is that, for inference, feed-forward networks are threshold circuits, and by their nature ANNs have binary outputs, outputting true or false (the same as a threshold circuit).

So if you train your models with that in mind, your weights can be reduced to -1, 0, 1, reducing the space complexity.
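To make the threshold-circuit view concrete: a threshold gate outputs 1 when the weighted sum of its binary inputs reaches a threshold theta, and with weights restricted to -1, 0 and 1 that sum is again only additions and subtractions. A minimal sketch; the two example gates (3-input majority, and "a AND NOT b") are my own illustrative choices.

```python
from itertools import product
from typing import Sequence


def threshold_gate(weights: Sequence[int], inputs: Sequence[int], theta: int) -> int:
    """Return 1 if sum(w_i * x_i) >= theta, else 0. Weights must be -1, 0 or 1."""
    assert all(w in (-1, 0, 1) for w in weights)
    total = 0
    for w, x in zip(weights, inputs):
        if w == 1:
            total += x
        elif w == -1:
            total -= x
        # w == 0: input ignored
    return 1 if total >= theta else 0


def majority3(a: int, b: int, c: int) -> int:
    """Majority of three inputs: weights (1, 1, 1), threshold 2."""
    return threshold_gate((1, 1, 1), (a, b, c), theta=2)


def a_and_not_b(a: int, b: int, c: int) -> int:
    """a AND NOT b, ignoring c: weights (1, -1, 0), threshold 1."""
    return threshold_gate((1, -1, 0), (a, b, c), theta=1)


if __name__ == "__main__":
    for a, b, c in product((0, 1), repeat=3):
        print(a, b, c, "majority:", majority3(a, b, c), "a AND NOT b:", a_and_not_b(a, b, c))
```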

I don't think the costs in expressiveness have been fully captured yet, but since perplexity doesn't measure correctness, if perplexity is the metric that matters to you, this will probably reduce memory requirements for inference.
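For the memory side, a back-of-envelope sketch for a hypothetical 7B-parameter model, counting weight storage only (no quantization scales, activations or KV cache). The 1.58 in the name is log2(3), the information content of one ternary weight. One practical option is packing five ternary weights into a byte, since 3^5 = 243 fits in 256 values, giving 1.6 bits per weight; the pack5/unpack5 helpers are illustrative, not from any library.

```python
import math

# Bits needed per weight for a few encodings.
BITS_PER_WEIGHT = {
    "fp16": 16.0,
    "4-bit (QLoRA-style)": 4.0,
    "ternary, 2-bit packing": 2.0,
    "ternary, 5 trits per byte": 8 / 5,          # 3**5 = 243 <= 256
    "ternary, information limit": math.log2(3),  # ~1.585
}


def weight_memory_gb(n_params: float, bits: float) -> float:
    """Weight storage in GB (10**9 bytes), ignoring scales, activations and KV cache."""
    return n_params * bits / 8 / 1e9


def pack5(trits):
    """Pack five weights from {-1, 0, 1} into one byte value in [0, 242] (base 3)."""
    assert len(trits) == 5
    value = 0
    for t in reversed(trits):
        value = value * 3 + (t + 1)  # map -1, 0, 1 -> digits 0, 1, 2
    return value


def unpack5(value):
    """Inverse of pack5: recover the five ternary weights."""
    trits = []
    for _ in range(5):
        value, digit = divmod(value, 3)
        trits.append(digit - 1)
    return trits


if __name__ == "__main__":
    n = 7e9  # hypothetical 7B-parameter model
    for name, bits in BITS_PER_WEIGHT.items():
        print(f"{name:28s} {bits:5.3f} bits/weight  ~{weight_memory_gb(n, bits):5.2f} GB")
    w = [-1, 0, 1, 1, -1]
    assert unpack5(pack5(w)) == w
```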

Also, just to add: I think 1.58-bit is mostly faster for inference, because training still has to multiply lots of floating-point gradients by integer activations, keep floating-point weights/gradients around for rounding, and deal with norms and so on. Could be wrong about that, though.
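A sketch of why training stays floating-point heavy, assuming the common latent-weight plus straight-through-estimator approach (an assumption on my part, not necessarily the paper's exact recipe): the optimizer keeps float weights, the forward pass uses a ternarized copy, and the float gradient is applied straight to the float copy. The absmean-style ternarize (scale by mean |w|, round, clip to [-1, 1]) and the single-neuron squared-error loss are illustrative choices.

```python
from typing import List


def ternarize(w: List[float], eps: float = 1e-8) -> List[int]:
    """Absmean-style ternarization: scale by mean |w|, round, clip to [-1, 1]."""
    gamma = sum(abs(x) for x in w) / len(w)
    return [max(-1, min(1, round(x / (gamma + eps)))) for x in w]


def train_step(latent: List[float], x: List[float], target: float, lr: float) -> List[float]:
    """One SGD step on a single dot-product neuron with loss 0.5 * (y - target)^2.

    The forward pass uses the ternary weights (the gamma scale is ignored for
    simplicity); the gradient w.r.t. those ternary weights is applied unchanged
    to the float latent weights (straight-through estimator).
    """
    w_q = ternarize(latent)
    y = sum(wq * xi for wq, xi in zip(w_q, x))  # forward with ternary weights
    dy = y - target                             # dLoss/dy
    grad = [dy * xi for xi in x]                # float gradient: one multiply per weight
    return [w - lr * g for w, g in zip(latent, grad)]  # update the float copy


if __name__ == "__main__":
    latent = [0.3, -0.2, 0.05, 0.9]
    x = [1.0, 1.0, 1.0, 1.0]
    for step in range(5):
        latent = train_step(latent, x, target=0.0, lr=0.1)
        print(step, [round(w, 2) for w in latent], ternarize(latent))
```

The float weights move on every step, while the ternary weights only change when a float weight crosses a rounding boundary, which is why the float copy has to be kept during training but can be dropped for inference.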