by ismailmaj 92 days ago
You drop the memory-throughput requirements thanks to the packed bit representation, so the FMA can become the bottleneck. With bit ops you also bypass the need to upconvert the bits to whatever floating-point format the FMA instruction expects.

Typically, for 1-bit matmul you can get away with XORs and popcounts, which should have a better throughput profile than FMA once you take the SIMD nature of the inputs and outputs into account.
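For true 1-bit (±1) values, the XOR/popcount trick can be sketched as follows. The sketch assumes bit 1 encodes +1 and bit 0 encodes -1, and uses a Python int to stand in for a SIMD register; the helper names are illustrative. Matching bits contribute +1 and differing bits contribute -1, so a length-n dot product is n - 2·popcount(a XOR b).

```python
def popcount(x):
    """Number of set bits in a non-negative int."""
    return bin(x).count("1")

def pack_pm1(v):
    """Pack a vector of +1/-1 into an int: bit i is 1 where v[i] == +1 (assumed encoding)."""
    bits = 0
    for i, x in enumerate(v):
        if x == 1:
            bits |= 1 << i
    return bits

def binary_dot(a_bits, b_bits, n):
    """Dot product of two packed +/-1 vectors of length n."""
    # d positions differ (product -1), n - d positions match (product +1):
    # dot = (n - d) - d = n - 2 * d
    d = popcount(a_bits ^ b_bits)
    return n - 2 * d

a = [1, -1, -1, 1, 1, -1, 1, 1]
b = [1, 1, -1, -1, 1, -1, -1, 1]
print(sum(x * y for x, y in zip(a, b)), binary_dot(pack_pm1(a), pack_pm1(b), len(a)))  # prints: 2 2
```

On hardware the XOR and popcount each process a whole register of packed values at once, which is the throughput argument made above.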


Yes, but this is not 1-bit matmul; it's 1.58 bits with expensive unpacking.
The title and the repo use "1-bit" when they mean 1.58-bit ternary values. That doesn't change any of my arguments (still XORs and popcounts).
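For reference, 1.58 bits is log2(3) ≈ 1.585, the information content of one ternary value. One dense packing that gets close, assumed here for illustration (the thread does not say what the repo uses), stores five ternary values per byte as a base-3 number, since 3^5 = 243 ≤ 256, for 1.6 bits per value. The sketch shows why unpacking is costlier than with 2-bit fields: each value needs a divide and modulo by 3 instead of a shift and mask. Function names are hypothetical.

```python
def pack5(trits):
    """Pack five values in {-1, 0, 1} into one byte as a base-3 number (max 242)."""
    assert len(trits) == 5
    byte = 0
    for t in reversed(trits):
        byte = byte * 3 + (t + 1)  # map -1, 0, 1 to base-3 digits 0, 1, 2
    return byte

def unpack5(byte):
    """Recover the five values in their original order: one divmod by 3 per value."""
    trits = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)
        trits.append(digit - 1)
    return trits

w = [1, 0, -1, -1, 1]
packed = pack5(w)
print(packed, unpack5(packed))  # prints: 167 [1, 0, -1, -1, 1]
```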
How do you do ternary matmul with popcnt on 1.58-bit packed data?
Assuming 2 bits per value (the first bit is the sign, the second marks whether the value is nonzero):

actv = A[_:1] & B[_:1]

sign = A[_:0] ^ B[_:0]

dot = pop_count(actv & !sign) - pop_count(actv & sign)
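The bit-plane scheme above can be checked with a small Python sketch. It assumes the sign and magnitude bits are stored as two separate planes (one Python int each), with the sign bit set only for -1; `!sign` in the pseudocode is a bitwise NOT, written `~` in Python. Function names are illustrative.

```python
def popcount(x):
    """Number of set bits in a non-negative int."""
    return bin(x).count("1")

def pack_ternary(v):
    """Split a {-1, 0, 1} vector into two bit planes packed in ints.
    mag bit i = 1 if v[i] != 0; sign bit i = 1 if v[i] == -1."""
    mag = sign = 0
    for i, x in enumerate(v):
        if x != 0:
            mag |= 1 << i
        if x < 0:
            sign |= 1 << i
    return mag, sign

def ternary_dot(a, b):
    """Dot product of two packed ternary vectors using AND, XOR and popcount."""
    a_mag, a_sign = a
    b_mag, b_sign = b
    actv = a_mag & b_mag   # positions where both values are nonzero
    neg = a_sign ^ b_sign  # positions where the product would be negative
    # ~neg is a negative Python int, but actv is non-negative, so actv & ~neg is too.
    return popcount(actv & ~neg) - popcount(actv & neg)

a = [1, 0, -1, 1, -1, 0, 1, -1]
b = [1, 1, 1, -1, -1, 0, 0, -1]
print(sum(x * y for x, y in zip(a, b)), ternary_dot(pack_ternary(a), pack_ternary(b)))  # prints: 1 1
```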

It can probably be made more efficient with a column-major layout.
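One possible reading of the column-first suggestion, assumed here rather than taken from the thread: for C = A·B, pack each row of A and each column of B once along the shared dimension, so every output element is a single bit-plane dot with no per-element unpacking. The helpers are hypothetical, and Python ints stand in for wide registers.

```python
def popcount(x):
    """Number of set bits in a non-negative int."""
    return bin(x).count("1")

def pack_ternary(v):
    """Bit planes of a {-1, 0, 1} vector: (nonzero mask, negative mask)."""
    mag = sign = 0
    for i, x in enumerate(v):
        if x != 0:
            mag |= 1 << i
        if x < 0:
            sign |= 1 << i
    return mag, sign

def ternary_matmul(A, B):
    """C = A @ B for ternary A (m x k) and B (k x n), given as lists of rows."""
    rows = [pack_ternary(r) for r in A]        # each row of A packed along k
    cols = [pack_ternary(c) for c in zip(*B)]  # each column of B packed along k
    C = []
    for a_mag, a_sign in rows:
        out = []
        for b_mag, b_sign in cols:
            actv = a_mag & b_mag   # both nonzero
            neg = a_sign ^ b_sign  # signs differ
            out.append(popcount(actv & ~neg) - popcount(actv & neg))
        C.append(out)
    return C

A = [[1, 0, -1], [-1, 1, 1]]
B = [[1, -1], [0, 1], [-1, -1]]
print(ternary_matmul(A, B))  # prints: [[2, 0], [-2, 1]]
```

The packing cost is paid once per row and column, so for large matrices the inner loop is only ANDs, XORs and popcounts.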

Since we are in CPU land, we mostly deal with dot products sized to fit in cache. I don't assume we have a tiled matmul instruction, which would be unlikely to support this unusual 1-bit format anyway.

I haven't looked closely, but on modern x86 CPUs it might be possible to do much better with the GF2P8AFFINEQB instruction (GFNI), which performs 8x8 bit-matrix multiplications over GF(2) efficiently. Not sure how you'd handle the 2-bit part, of course.
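To make the 8x8 bit-matrix idea concrete, here is a scalar Python model of the per-byte operation of GF2P8AFFINEQB, written from Intel's documented definition: output bit i is the parity of (matrix byte 7 - i AND x), XORed with bit i of the immediate. The function name is hypothetical; native code would use an intrinsic such as `_mm512_gf2p8affine_epi64_epi8`. Because the accumulation is over GF(2), each output bit is a parity, not the integer count that a popcount-based dot product needs.

```python
def parity(x):
    """1 if x has an odd number of set bits, else 0."""
    return bin(x).count("1") & 1

def gf2p8affine_byte(x, matrix_qword, imm8=0):
    """Scalar model of one byte of GF2P8AFFINEQB.
    Each byte of the 64-bit matrix is one row of an 8x8 bit matrix; AND + parity
    is a dot product over GF(2), so this is an 8x8 matrix-vector product."""
    out = 0
    for i in range(8):
        row = (matrix_qword >> (8 * (7 - i))) & 0xFF  # matrix byte 7 - i
        bit = parity(row & x) ^ ((imm8 >> i) & 1)
        out |= bit << i
    return out

IDENTITY = 0x0102040810204080     # leaves every byte unchanged
BIT_REVERSE = 0x8040201008040201  # reverses the bit order within each byte
print(hex(gf2p8affine_byte(0x35, IDENTITY)), hex(gf2p8affine_byte(0x01, BIT_REVERSE)))  # prints: 0x35 0x80
```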
This is 11 bit ops and a subtract, which I assume is ~11 clocks, while you can just do:

l1 = dot(A[:11000000], B[:11000000])

l2 = dot(A[:00110000], B[:00110000])

l3 = dot(A[:00001100], B[:00001100])

l4 = dot(A[:00000011], B[:00000011])

result = l1 + l2 * 4 + l3 * 16 + l4 * 64

which is 8 bit ops and four 8-bit dot products, likely around 8 clocks with less serial dependence.
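The field-wise idea can be sketched in Python, assuming each byte packs four unsigned 2-bit values (0 to 3) and that `dot` is an ordinary integer dot product over bytes; how signed ternary values would map onto these fields is not specified in the thread. Shifting each field down to bits 0-1 before multiplying gives unscaled partial dots that simply add. If the fields are instead masked in place, each operand keeps a factor 2^s for the field at bit offset s, so that partial dot carries 4^s (4096, 256, 16 and 1 for the masks 11000000, 00110000, 00001100 and 00000011) and has to be scaled back before summing. Function names are hypothetical.

```python
def field(x, shift):
    """Unsigned 2-bit field of byte x at bit offset `shift` (0, 2, 4 or 6)."""
    return (x >> shift) & 0b11

def packed_dot(A, B):
    """Dot product of two vectors of unsigned 2-bit values packed four per byte.
    One partial dot per field position (masks 11000000 .. 00000011), then a plain sum."""
    partials = [sum(field(a, s) * field(b, s) for a, b in zip(A, B)) for s in (6, 4, 2, 0)]
    return sum(partials)

def packed_dot_masked(A, B):
    """Same result with fields masked in place instead of shifted down.
    (a & mask) == field * 2**s, so each product carries 4**s and is shifted back."""
    total = 0
    for s in (6, 4, 2, 0):
        mask = 0b11 << s
        partial = sum((a & mask) * (b & mask) for a, b in zip(A, B))
        total += partial >> (2 * s)  # exact: partial is a multiple of 4**s
    return total

def unpack(V):
    """Unpack bytes into 2-bit values, highest field first."""
    return [field(x, s) for x in V for s in (6, 4, 2, 0)]

A = [0b11011000, 0b00011011]
B = [0b10101010, 0b11111111]
naive = sum(x * y for x, y in zip(unpack(A), unpack(B)))
print(naive, packed_dot(A, B), packed_dot_masked(A, B))  # prints: 30 30 30
```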