Hacker News
by shoyer 2256 days ago
Right, you still need to write derivative rules by hand for the primitive operations of an auto-diff system. Automatic differentiation provides composition; it doesn't solve the root mathematical problem of differentiating operations at the lowest level.

So yes, if you need a new primitive in order to add an efficient CUDA kernel, you will probably have to write its derivative manually too. JAX has a few shortcuts that occasionally make this easier, but fundamentally it faces the same challenge as any auto-diff system.
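
For anyone who hasn't seen what "write its derivative manually" looks like in JAX, here is a minimal sketch using `jax.custom_vjp`. The function `scaled_tanh` is a hypothetical stand-in I made up for illustration: in practice the forward body would call into your custom kernel, whereas here it is plain `jax.numpy` so the snippet runs anywhere.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scaled_tanh(x, scale):
    # Hypothetical stand-in for a hand-written kernel (assumption:
    # a real one would dispatch to custom CUDA code instead).
    return scale * jnp.tanh(x)


def scaled_tanh_fwd(x, scale):
    t = jnp.tanh(x)
    # Return the output plus the residuals the backward pass needs.
    return scale * t, (t, scale)


def scaled_tanh_bwd(residuals, g):
    t, scale = residuals
    # d/dx     [scale * tanh(x)] = scale * (1 - tanh(x)^2)
    # d/dscale [scale * tanh(x)] = tanh(x), summed because scale is a scalar
    grad_x = g * scale * (1.0 - t ** 2)
    grad_scale = jnp.sum(g * t)
    return grad_x, grad_scale


# Register the hand-written rule; jax.grad now composes through it.
scaled_tanh.defvjp(scaled_tanh_fwd, scaled_tanh_bwd)

x = jnp.array([0.1, -0.5, 2.0])
scale = jnp.array(2.0)
grads = jax.grad(lambda x, s: jnp.sum(scaled_tanh(x, s)), argnums=(0, 1))(x, scale)
```

The derivative rule is still yours to get right; JAX only takes care of chaining it with everything else in the program.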

2 comments

I still strongly disagree. Few of the hand-written CUDA kernels outside of the frameworks are about implementing derivative rules. They're about eliminating CUDA call overhead or avoiding the layered computational and memory inefficiencies that existing ML compilers have trouble handling.

Almost none of the frameworks can yet JIT you a performant RNN, even though RNNs use only very standard components[1]. OpenAI got a massive speed and memory-usage boost for attention by implementing together what amounts to a few standard primitives[2].

There are massive gaps in the optimizations that existing ML compilers provide. The landscape is starting to get better, but it's still full of potholes.

[1]: https://twitter.com/stanfordnlp/status/1224106217192087552

[2]: https://openai.com/blog/sparse-transformer/

It depends on what you define as a primitive. I've had plenty of compositions of existing primitives for which the auto-derived backprop was orders of magnitude slower than a hand-written one. I didn't strictly need to write my own backprop, but I benefited tremendously from doing so. I don't think my experience is particularly rare.
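
To make that concrete, here is a rough sketch of swapping the auto-derived backward pass of a composition for a hand-written one, without introducing any new primitive. Softmax cross-entropy is my own choice of illustration, not the case from the comment above, and the names `xent_composed` and `xent_manual` are made up. Whether the manual version is actually faster depends on the compiler, the shapes, and the hardware, so treat this as showing the mechanism only.

```python
import jax
import jax.numpy as jnp
from jax.nn import log_softmax, softmax


def xent_composed(logits, labels):
    # Pure composition of existing primitives; backprop is auto-derived.
    return -jnp.sum(labels * log_softmax(logits, axis=-1), axis=-1)


@jax.custom_vjp
def xent_manual(logits, labels):
    # Same forward computation, but with a hand-written backward pass.
    return xent_composed(logits, labels)


def xent_manual_fwd(logits, labels):
    # Keep only the inputs as residuals and recompute in the backward pass.
    return xent_composed(logits, labels), (logits, labels)


def xent_manual_bwd(residuals, g):
    logits, labels = residuals
    # Closed form: dL/dlogits = softmax(logits) * sum(labels) - labels
    label_mass = jnp.sum(labels, axis=-1, keepdims=True)
    grad_logits = g[..., None] * (softmax(logits, axis=-1) * label_mass - labels)
    # dL/dlabels = -log_softmax(logits)
    grad_labels = -g[..., None] * log_softmax(logits, axis=-1)
    return grad_logits, grad_labels


xent_manual.defvjp(xent_manual_fwd, xent_manual_bwd)

logits = jnp.array([[1.0, 2.0, 0.5], [0.1, -1.0, 3.0]])
labels = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
auto_grad = jax.grad(lambda z: jnp.sum(xent_composed(z, labels)))(logits)
manual_grad = jax.grad(lambda z: jnp.sum(xent_manual(z, labels)))(logits)
```

Both gradients should agree numerically; the difference is only in which backward computation gets compiled and run.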