by dan-robertson
1385 days ago
Maybe I’m just slow, but it wasn’t immediately obvious to me how to use this for matrix multiplication, so let me try to explain. Suppose we have two matrices we would like to multiply, a_ij and b_jk (and let’s say their sizes line up with the hardware, because I think those details aren’t so relevant). Their product is c_ik = a_ij b_jk = sum(a_ij * b_jk for all j).
The hardware lets us cheaply compute and accumulate an outer product (see picture in OP): r_ij = r’_ij + p_i * q_j, where r’ is the previous value of the accumulator.
Now start with r = 0 and accumulate:
  r_ik = a_i1 * b_1k
       + a_i2 * b_2k
       + ...
       + a_in * b_nk
       = c_ik
Each line of the sum above corresponds to one AMX op on all the cells of the matrix.

Writing it out like this, it seems quite straightforward. I think I was too caught up in thinking about the per-cell computation. When computing based on cells in the output, you take a row from the left hand side and dot it with a column from the right hand side (n*n dot products). Here, we take a column from the left hand side and a row from the right hand side, outer product them (n outer products), and add up the results.

Perhaps this is partly a victory for this kind of symbolic index notation. I think this would all be much less obvious if I wrote it out as a sum of outer products with, e.g., the tensor product symbol.
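For anyone who wants to check the two views against each other, here is a rough sketch in plain Python. It only models the arithmetic, not the real AMX instructions, and the function names (`outer_accumulate`, `matmul_by_outer_products`, `matmul_by_dot_products`) are made up for illustration.

```python
def outer_accumulate(r, p, q):
    """Model of one accumulate step: r_ij = r'_ij + p_i * q_j (in place)."""
    for i, p_i in enumerate(p):
        for j, q_j in enumerate(q):
            r[i][j] += p_i * q_j


def matmul_by_outer_products(a, b):
    """c_ik as a sum over j of (column j of a) outer (row j of b)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    r = [[0] * cols for _ in range(rows)]  # start with r = 0
    for j in range(inner):
        column_of_a = [a[i][j] for i in range(rows)]  # a_ij for fixed j
        row_of_b = b[j]                               # b_jk for fixed j
        outer_accumulate(r, column_of_a, row_of_b)    # one step per j
    return r


def matmul_by_dot_products(a, b):
    """The per-cell view: c_ik = (row i of a) dot (column k of b)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][j] * b[j][k] for j in range(inner)) for k in range(cols)]
        for i in range(rows)
    ]


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(matmul_by_outer_products(a, b))  # [[19, 22], [43, 50]]
print(matmul_by_dot_products(a, b))    # [[19, 22], [43, 50]]
```

Both functions do the same multiplications; the only difference is the loop order. The outer-product version has a single outer loop of length n, which is the part that would map onto one hardware op per iteration.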