by Smerity
2255 days ago
For high-performance CUDA kernels, people still need to write derivatives by hand. I know this from my own research, and for many production systems I'd still need to write them myself. Many of my architectures wouldn't have been possible without writing the CUDA myself (Quasi-Recurrent Neural Network[1]) or using optimized, hand-written black boxes (cuDNN RNN). The lack of open, optimized, hand-written CUDA kernels has actually been an impediment to progress in the field. Automatic differentiation allows for great flexibility and composability, but its performance is still far from that of hand-written kernels, even with the various JITs available. For now, however, JAX seems to be one of the most flexible and best-optimized options for many use cases.

[1]: https://github.com/salesforce/pytorch-qrnn
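To make "writing the derivative by hand" concrete, here is a worked illustration (not taken from the pytorch-qrnn code) for the QRNN's f-pooling recurrence, where $f_t$ are forget gates and $z_t$ candidate values. It assumes a scalar loss $L$ that may depend on every hidden state $h_t$, writes $g_t$ for the gradient reaching $h_t$ directly from the loss and $\hat g_t$ for the total gradient including what flows back from later steps:

$$
\begin{aligned}
h_t &= f_t \odot h_{t-1} + (1 - f_t) \odot z_t, \qquad t = 1, \dots, T \\
\hat g_t &= g_t + f_{t+1} \odot \hat g_{t+1}, \qquad \text{with } f_{T+1} \odot \hat g_{T+1} := 0 \\
\frac{\partial L}{\partial z_t} &= \hat g_t \odot (1 - f_t) \\
\frac{\partial L}{\partial f_t} &= \hat g_t \odot (h_{t-1} - z_t) \\
\frac{\partial L}{\partial h_0} &= f_1 \odot \hat g_1
\end{aligned}
$$

A fused kernel computes this as one reverse sweep over time, rather than letting a generic autodiff system record and replay every elementwise op.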
So yes, if you need to add a new primitive backed by an efficient CUDA kernel, you will probably have to write its derivative manually as well. JAX has a few shortcuts that occasionally make this easier, but fundamentally it faces the same challenge as any autodiff system.