by thomasahle 685 days ago
Consider the function:

    relu(np.outer(x, y)) @ z
This takes O(n^2) time and memory in the naive implementation, because the full n x n outer product is materialized. But clearly, the memory could be reduced to O(n) with the right "fusing" of the operations.
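
To make the fusing concrete, here is a minimal sketch (not an optimized kernel) that builds one row of the outer product at a time, so only O(n) extra memory is live at once. The function names `naive` and `fused` are just illustrative, and x, y, z are assumed to be 1-D arrays.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0)

def naive(x, y, z):
    # Materializes the full outer product: O(n^2) memory.
    return relu(np.outer(x, y)) @ z

def fused(x, y, z):
    # Builds one row of the outer product at a time:
    # O(n) extra memory, still O(n^2) time.
    out = np.empty(len(x), dtype=np.result_type(x, y, z))
    for i in range(len(x)):
        out[i] = relu(x[i] * y) @ z
    return out

# Quick check that both versions agree on random inputs.
rng = np.random.default_rng(0)
x, y, z = rng.standard_normal((3, 8))
assert np.allclose(naive(x, y, z), fused(x, y, z))
```

The Python loop is slow, of course; the point is only that the n x n intermediate never has to exist in memory.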

KANs are similar. This is the forward code for KANs:

   x = einsum("bi,oik->boik", x, w1) + b1
   x = einsum("boik,oik->bo", relu(x), w2) + b2
This is the forward code for an Expansion / Inverse Bottleneck MLP:

   x = einsum("bi,iok->bok", x, w1) + b1
   x = einsum("bok,okp->bp", relu(x), w2) + b2
Both take nd^2 time, but Inverse Bottleneck only takes nd memory. For KANs to match the memory usage, the two einsums must be fused.
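
As a rough sketch of what fusing the two KAN einsums could look like, the version below loops over the input index i and accumulates straight into the output, so the (b, o, i, k) intermediate is never materialized and only a (b, o, k) slice is live at a time. The shapes are my assumptions (the snippet above doesn't state them): x is (b, i), w1, b1 and w2 are (o, i, k), and b2 is (o,). The function names are hypothetical, and a real implementation would do this inside a GPU kernel rather than a Python loop.

```python
import numpy as np

def relu(a):
    return np.maximum(a, 0)

def kan_forward_naive(x, w1, b1, w2, b2):
    # Same two einsums as above; h has shape (b, o, i, k).
    h = np.einsum("bi,oik->boik", x, w1) + b1
    return np.einsum("boik,oik->bo", relu(h), w2) + b2

def kan_forward_fused(x, w1, b1, w2, b2):
    # Assumed shapes: x (b, i); w1, b1, w2 (o, i, k); b2 (o,).
    batch, n_in = x.shape
    n_out = w1.shape[0]
    out = np.zeros((batch, n_out), dtype=np.result_type(x, w1, w2))
    for i in range(n_in):
        # Slice of the intermediate for input feature i only: (b, o, k).
        h = x[:, i, None, None] * w1[None, :, i, :] + b1[:, i, :]
        # Contract over k and accumulate the sum over i into the output.
        out += np.einsum("bok,ok->bo", relu(h), w2[:, i, :])
    return out + b2
```

Chunking over k as well would shrink the live slice further to (b, o), at the cost of more loop iterations.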

It's actually quite similar to flash-attention.

1 comment

Which is to say, a big part is lack of optimization.

Personally, I think this is fine in context: it is a new formulation, and the optimization is difficult and non-obvious. It shouldn't be expected that every researcher can recognize and solve every optimization problem.