by shoyer
2256 days ago
Right, you still need to write derivative rules by hand for the primitive operations of an auto-diff system. Automatic differentiation provides composition; it doesn't solve the root mathematical problem of differentiating operations at the lowest level. So yes, if you need a new primitive in order to add an efficient CUDA kernel, you will probably also have to write its derivative by hand. JAX has a few shortcuts that occasionally make this easier, but fundamentally it has the same challenge as any auto-diff system.
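To make the split between "hand-written rules" and "composition" concrete, here is a minimal sketch of a toy forward-mode differentiator in plain Python. It is not how JAX or any real framework is implemented; the names `PRIMITIVES`, `apply_primitive`, and `derivative` are made up for illustration. Each primitive ships with a derivative someone wrote by hand, and the only thing the system automates is chaining them.

```python
import math

# Hand-written derivative rules: name -> (function, its derivative).
# (Hypothetical table for illustration; a real system registers rules per primitive.)
PRIMITIVES = {
    "sin": (math.sin, math.cos),
    "exp": (math.exp, math.exp),
    "square": (lambda x: x * x, lambda x: 2.0 * x),
}


def apply_primitive(name, pair):
    """Push a (value, tangent) pair through one primitive using its hand-written rule."""
    f, df = PRIMITIVES[name]
    x, dx = pair
    # Chain rule: this multiplication is the part that is actually "automatic".
    return f(x), df(x) * dx


def derivative(pipeline, x):
    """Differentiate a composition of primitives, listed innermost first."""
    pair = (x, 1.0)  # seed tangent dx/dx = 1
    for name in pipeline:
        pair = apply_primitive(name, pair)
    return pair[1]


# d/dx exp(sin(x^2)) = exp(sin(x^2)) * cos(x^2) * 2x
grad = derivative(["square", "sin", "exp"], 0.5)
```

Using a primitive that has no entry in the table (say `"tanh"`) fails with a `KeyError` until someone adds both the function and its derivative, which is the same situation as adding a new kernel to a real system.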
|
Almost none of the frameworks can yet JIT-compile a performant RNN, even though RNNs use only very standard components[1]. OpenAI got a massive speed and memory-usage improvement for attention by implementing what amounts to a few standard primitives together[2].
There are massive gaps in the optimizations that existing ML compilers provide. The landscape is starting to improve, but it's still full of pitfalls.
[1]: https://twitter.com/stanfordnlp/status/1224106217192087552
[2]: https://openai.com/blog/sparse-transformer/