|
Consider the function: relu(np.outer(x, y)) @ z.
This takes O(n^2) time and memory in the naive implementation, since np.outer materializes the full n × n matrix.
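To make this concrete, here is a minimal NumPy sketch of the naive version next to a row-at-a-time version. The names `relu`, `naive`, and `fused` are illustrative assumptions, not from the original. The second version builds each output element from one row of the outer product, so it only ever holds length-n temporaries:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def naive(x, y, z):
    # Materializes the full n x n outer product: O(n^2) memory.
    return relu(np.outer(x, y)) @ z

def fused(x, y, z):
    # One output element per iteration. The only temporary is a single
    # length-n row of the outer product, so memory stays O(n).
    out = np.empty(x.shape[0], dtype=np.result_type(x, y, z))
    for i in range(x.shape[0]):
        out[i] = relu(x[i] * y) @ z
    return out

rng = np.random.default_rng(0)
n = 512
x, y, z = rng.standard_normal(n), rng.standard_normal(n), rng.standard_normal(n)
print(np.allclose(naive(x, y, z), fused(x, y, z)))  # True
```

The Python loop is slow. It only demonstrates the memory argument, which a real fused kernel would realize in a single pass.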
But clearly, the memory could be reduced to O(n) with the right "fusing" of the operations.
KANs are similar. This is the forward code for KANs:
    x = einsum("bi,oik->boik", x, w1) + b1
    x = einsum("boik,oik->bo", relu(x), w2) + b2
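A self-contained, runnable version of that KAN forward pass, as a sketch. The shapes are assumptions inferred from the einsum strings: x is (b, i), w1, b1 and w2 are (o, i, k) with k basis terms per edge, and b2 is (o,). The name `kan_forward` is hypothetical. It returns the shape of the intermediate so its size is visible:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def kan_forward(x, w1, b1, w2, b2):
    # x: (b, i); w1, b1, w2: (o, i, k); b2: (o,)  (assumed shapes)
    h = np.einsum("bi,oik->boik", x, w1) + b1        # (b, o, i, k) intermediate
    y = np.einsum("boik,oik->bo", relu(h), w2) + b2  # (b, o)
    return y, h.shape

rng = np.random.default_rng(0)
b, i, o, k = 8, 16, 16, 4
x = rng.standard_normal((b, i))
w1, b1, w2 = (rng.standard_normal((o, i, k)) for _ in range(3))
b2 = rng.standard_normal(o)
y, inter_shape = kan_forward(x, w1, b1, w2, b2)
print(y.shape, inter_shape)  # (8, 16) (8, 16, 16, 4)
```

With batch n and in = out = d, the intermediate has n·d²·k elements.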
This is the forward code for an expansion / inverted-bottleneck MLP:
    x = einsum("bi,iok->bok", x, w1) + b1
    x = einsum("bok,okp->bp", relu(x), w2) + b2
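The same exercise for the inverted-bottleneck forward, again as a sketch with shapes inferred from the einsum strings: x is (b, i), w1 is (i, o, k), w2 is (o, k, p), b2 is (p,), and b1 is assumed to be (o, k). The name `ib_forward` is hypothetical:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def ib_forward(x, w1, b1, w2, b2):
    # x: (b, i); w1: (i, o, k); b1: (o, k); w2: (o, k, p); b2: (p,)  (assumed shapes)
    h = np.einsum("bi,iok->bok", x, w1) + b1        # (b, o, k) intermediate
    y = np.einsum("bok,okp->bp", relu(h), w2) + b2  # (b, p)
    return y, h.shape

rng = np.random.default_rng(0)
b, i, o, k, p = 8, 16, 16, 4, 16
x = rng.standard_normal((b, i))
w1 = rng.standard_normal((i, o, k))
b1 = rng.standard_normal((o, k))
w2 = rng.standard_normal((o, k, p))
b2 = rng.standard_normal(p)
y, h_shape = ib_forward(x, w1, b1, w2, b2)
print(y.shape, h_shape)  # (8, 16) (8, 16, 4)
```

Here the intermediate has only n·d·k elements, because the input index i is summed away in the first einsum.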
Both take O(nd^2) time, but the inverted bottleneck only takes O(nd) memory, whereas the KAN materializes an O(nd^2) intermediate (the "boik" tensor).
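Counting directly from the einsum index strings, with batch size n, input and output widths d, and the per-edge/expansion factor k treated as a constant, gives those figures:

```latex
\begin{aligned}
\text{KAN:}\quad & \text{time} \propto b\,o\,i\,k = n d^2 k, \qquad \text{memory}(h_{boik}) = b\,o\,i\,k = n d^2 k,\\
\text{Inverted bottleneck:}\quad & \text{time} \propto b\,i\,o\,k + b\,o\,k\,p = 2 n d^2 k, \qquad \text{memory}(h_{bok}) = b\,o\,k = n d k.
\end{aligned}
```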
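For KANs to match that memory usage, the two einsums must be fused. It's actually quite similar to flash-attention, which fuses the two attention matmuls (and the softmax between them) so the full score matrix is never materialized. |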
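A minimal sketch of the memory side of that fusion. It is loop-level tiling in NumPy, not a true fused kernel. Shapes are inferred from the einsum strings: x is (b, i), w1, b1 and w2 are (o, i, k), and b2 is (o,). The names `kan_forward_naive` and `kan_forward_chunked` and the `chunk` parameter are hypothetical. Working over a slice of output features at a time means only a (b, chunk, i, k) piece of the intermediate is ever alive:

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0.0)

def kan_forward_naive(x, w1, b1, w2, b2):
    h = np.einsum("bi,oik->boik", x, w1) + b1  # full (b, o, i, k) tensor
    return np.einsum("boik,oik->bo", relu(h), w2) + b2

def kan_forward_chunked(x, w1, b1, w2, b2, chunk=1):
    # Process `chunk` output features per step. With chunk=1 the live
    # intermediate is b*i*k elements, i.e. O(nd) for constant k.
    n, o = x.shape[0], w1.shape[0]
    out = np.empty((n, o), dtype=np.result_type(x, w1, b1, w2, b2))
    for s in range(0, o, chunk):
        e = min(s + chunk, o)
        h = np.einsum("bi,oik->boik", x, w1[s:e]) + b1[s:e]
        out[:, s:e] = np.einsum("boik,oik->bo", relu(h), w2[s:e]) + b2[s:e]
    return out

rng = np.random.default_rng(0)
b, i, o, k = 8, 16, 16, 4
x = rng.standard_normal((b, i))
w1, b1, w2 = (rng.standard_normal((o, i, k)) for _ in range(3))
b2 = rng.standard_normal(o)
print(np.allclose(kan_forward_naive(x, w1, b1, w2, b2),
                  kan_forward_chunked(x, w1, b1, w2, b2, chunk=4)))  # True
```

A real implementation would presumably do this tiling inside a single GPU kernel, as flash-attention does, rather than in a Python loop. The sketch only shows that the full intermediate is unnecessary.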
Personally, I think this is fine in context: it is a new formulation, and the optimization is difficult and non-obvious. It shouldn't be expected that every researcher can recognize and solve every optimization problem.