|
Sure, I can give you detailed answers:
1- The answer is still ~1.7GB. You only need the meta-data of a single nn.Linear at a time. There are 32x(4+3) = 224 quantized layers, so you need an additional (3GB - 1.7GB)/224 = 1.3GB/224 ~ 5.8MB, which is negligible.
2- As the table says, that's the forward pass (batch-size=1, context-size=1024). Forward pass means there's no caching and no decoding logic. The actual model generation speed should be much faster with caching plus decoding logic like speculative decoding, and with vLLM instead of HF. And even with all of that, a much larger model like Mixtral with the same group-size of 8 offloaded to the CPU works quite well on a 4090. You mean why it's faster than Quip# despite being all on-device? Because dequantization with HQQ is a simple linear operation. It's not even using a fused kernel; only the dequantization part is done on CUDA.
3- LoRA absorbs -zero x shift, not the shift itself; the shift is still there, as it is in the BitNet/1.58 work. As the paragraph explains, the math ignores reshaping to keep it simple and easy to read. Let's say you have a 4096x4096 matrix with grouping done channel-wise (but no reshaping). The -zero x shift part is a rank-1 matrix (4096x1 .dot 1x4096) and the LoRA data is (4096xr .dot rx4096), so you can merge them exactly into (4096x(r+1) .dot (r+1)x4096). The point of that paragraph is to show two things:
- Compared to the BitNet formulation, the additional zero-point (which is necessary to get good quantization results on pre-trained models with minimal calibration) has a negligible overhead.
- More importantly, it explains how we even got the idea of adding low-rank adapters: it's not because LoRA is popular, it's because the zero-point alone results in a rank-1 matrix error, which is not enough to express the quantization error. As the rank tends to min(num_rows, num_cols), the error goes down. So if we increase the rank by r via low-rank adapters, we would expect better results.
Now, if we include the reshape with a group-size lower than num_rows, the -zero x shift part is a rank-n matrix (4096xn .dot nx4096). It's not possible to properly estimate the rank n because that depends heavily on the nature of the weight matrix, but in the end the LoRA part will be (4096x(n+r) .dot (n+r)x4096). We only use a LoRA rank of 8 for the MLPs, which are the larger matrices, so even if you double or even 4x that to, let's say, n+r=32, the rank is still just 32/4096 = 1/128 = 0.78% of the original matrix dimension. Whether or not you merge -zero x scale with the low-rank adapters doesn't matter much; that would depend heavily on which fused kernel implementation performs best. |
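The exact merge of a rank-1 term with a rank-r adapter mentioned in answer 3 can be checked numerically. This is only a rough NumPy sketch, not the HQQ implementation: the function names are made up for illustration, `u` and `v` are random stand-ins for the two factors of the rank-1 zero-point term, and a 256x256 matrix is used in place of 4096x4096 so that it runs quickly.

```python
import numpy as np

def merge_rank1_with_lora(u, v, A, B):
    """Fold a rank-1 term u @ v into low-rank factors A @ B.

    Shapes: u (rows, 1), v (1, cols), A (rows, r), B (r, cols).
    Returns factors of shape (rows, r+1) and (r+1, cols) whose
    product equals u @ v + A @ B.
    """
    A_merged = np.concatenate([A, u], axis=1)  # append u as an extra column
    B_merged = np.concatenate([B, v], axis=0)  # append v as an extra row
    return A_merged, B_merged

def rank_fraction(rank, dim):
    """Rank as a fraction of the matrix dimension (e.g. 32/4096)."""
    return rank / dim

def lowrank_param_fraction(rank, rows, cols):
    """Parameters held by both low-rank factors, relative to the full matrix."""
    return rank * (rows + cols) / (rows * cols)

rng = np.random.default_rng(0)
rows = cols = 256  # assumption: small stand-in for 4096 to keep the demo fast
r = 8              # LoRA rank quoted above for the MLPs

u = rng.standard_normal((rows, 1))  # hypothetical rank-1 factor (column)
v = rng.standard_normal((1, cols))  # hypothetical rank-1 factor (row)
A = rng.standard_normal((rows, r))  # hypothetical LoRA factor
B = rng.standard_normal((r, cols))  # hypothetical LoRA factor

A_merged, B_merged = merge_rank1_with_lora(u, v, A, B)

# The merged product matches the sum of the two terms up to float rounding.
merge_is_exact = np.allclose(A_merged @ B_merged, u @ v + A @ B)
print("merged shapes:", A_merged.shape, B_merged.shape)
print("merge matches sum:", merge_is_exact)
print("rank fraction for n+r=32:", rank_fraction(32, 4096))
print("param fraction for n+r=32:", lowrank_param_fraction(32, 4096, 4096))
```

The two helper ratios are there to separate the rank as a fraction of the matrix dimension (32/4096) from the parameter count of the two factors together, which for a square matrix is twice that fraction.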
Why not do the same optimization for layer weights themselves?