by epistasis
262 days ago
When thinking about block matrix multiplication, it's always a fun time to revisit Strassen's algorithm, which runs in about O(n^2.81) (that is, O(n^log2(7))) instead of O(n^3). Normal block multiplication works like:

  [ A11 A12 ] [ B11 B12 ]   [ A11*B11 + A12*B21   A11*B12 + A12*B22 ]   [ C11 C12 ]
  [ A21 A22 ] [ B21 B22 ] = [ A21*B11 + A22*B21   A21*B12 + A22*B22 ] = [ C21 C22 ]

Which takes 8 matrix multiplications on the sub-blocks. But by cleverly defining only 7 different matrix multiplications on top of block additions and subtractions, like:

  M3 = A11 * (B12 - B22)

You can make the C blocks out of just additions and subtractions of the 7 different matrix multiplications. https://en.wikipedia.org/wiki/Strassen_algorithm

As far as I know this is not used in the major GPU libraries to save bandwidth, but I have never spent the time to figure out why. My guess is that it has something to do with the ratio of bandwidth to FLOPs, which is way past my knowledge of GPUs.
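
For anyone who wants to see all 7 products and how the C blocks fall out of them, here is a rough pure-Python sketch. It uses the standard formulas from the Wikipedia article linked above; the helper names (`add`, `sub`, `split`, `naive_mul`, `strassen`) are my own, and it assumes square matrices whose size is a power of two. It is meant to show the structure, not to be fast.

```python
def add(X, Y):
    # Elementwise sum of two equally sized matrices (lists of rows).
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def sub(X, Y):
    # Elementwise difference of two equally sized matrices.
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def naive_mul(X, Y):
    # Ordinary O(n^3) multiplication, used here only as a reference.
    cols = list(zip(*Y))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in X]

def split(X):
    # Split a 2h x 2h matrix into its four h x h blocks.
    h = len(X) // 2
    return ([r[:h] for r in X[:h]], [r[h:] for r in X[:h]],
            [r[:h] for r in X[h:]], [r[h:] for r in X[h:]])

def strassen(A, B):
    # Assumes A and B are n x n with n a power of two.
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # The 7 block multiplications (instead of 8).
    M1 = strassen(add(A11, A22), add(B11, B22))
    M2 = strassen(add(A21, A22), B11)
    M3 = strassen(A11, sub(B12, B22))
    M4 = strassen(A22, sub(B21, B11))
    M5 = strassen(add(A11, A12), B22)
    M6 = strassen(sub(A21, A11), add(B11, B12))
    M7 = strassen(sub(A12, A22), add(B21, B22))
    # The C blocks use only additions and subtractions of the M's.
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(add(sub(M1, M2), M3), M6)
    # Stitch the four blocks back into one matrix.
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom
```

Calling `strassen(A, B)` on two 4x4 lists of lists should give the same result as `naive_mul(A, B)`. Note that the 7 multiplications come with 18 block additions and subtractions per level, all of which read and write whole blocks.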