| > It compiles a custom kernel for every operation, allowing extreme shape specialization. This doesn't matter. Just look at the performance achieved by cuDNN kernels (which back PyTorch): they're dynamically shaped and hit near peak. For dense linear algebra at the size of modern neural networks, specializing on the loop bound conditions won't help much. > All tensors are lazy, so it can aggressively fuse operations. This matters. PyTorch teams are trying to implement that now (they have LazyTensor, AITemplate, TorchDynamo), but I'm not sure of the status (it's been tried repeatedly). > The backend is 10x+ simpler, meaning optimizing one kernel makes everything fast. The first part of that sentence matters; the second part doesn't. Kernels are already fast, and they see little reuse other than being fused into each other (which takes a full linear algebra compiler). If you make sum fast, you have not made matrix multiplication fast, even though MM has a sum in it. It just isn't that easy to compose operations and still hit 80+% of hardware efficiency. But it is easier to iterate fast and build a seamless lazy compiler if your backend is simple. You can pattern match more easily and be sure you handle edge cases without insanely complicated machinery like alias analysis (which PyTorch has to do). |
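To make the laziness point in the comment above concrete, here is a toy sketch in plain Python. The `LazyTensor` class and the `eager` helper are hypothetical names invented for this illustration; they are not the PyTorch LazyTensor or any real library's API. Recording elementwise ops instead of running them lets the whole chain execute in a single pass over the data, with no intermediate buffers.

```python
class LazyTensor:
    """Hypothetical 1-D lazy tensor: records elementwise ops, runs them later."""

    def __init__(self, data, ops=()):
        self.data = list(data)
        self.ops = tuple(ops)

    def map(self, fn):
        # No computation happens here; the op is only appended to the chain.
        return LazyTensor(self.data, self.ops + (fn,))

    def realize(self):
        # One loop over the data applies every recorded op: the "fused kernel".
        out = []
        for x in self.data:
            for fn in self.ops:
                x = fn(x)
            out.append(x)
        return out


def eager(data, fns):
    """Unfused baseline: each op makes a full pass and a new intermediate buffer."""
    buffers = 0
    cur = list(data)
    for fn in fns:
        cur = [fn(x) for x in cur]
        buffers += 1
    return cur, buffers


# scale, shift, then ReLU
fns = [lambda x: x * 2.0, lambda x: x + 1.0, lambda x: max(x, 0.0)]
data = [-2.0, -0.5, 0.0, 3.0]

lazy = LazyTensor(data)
for fn in fns:
    lazy = lazy.map(fn)

fused = lazy.realize()                # single pass, no intermediates
unfused, n_buffers = eager(data, fns) # three passes, three intermediates
```

Both paths produce the same values. The only difference is how many times memory is traversed, which is the part that tends to dominate for cheap elementwise ops. This says nothing about fusing into a matmul, which is the hard case the comment describes.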
While this is true for most common GEMM-like ops, things get slow once you step off the beaten path (odd channel sizes, odd batch sizes, etc.). Right now in PyTorch, GroupNorm is 2x slower than BatchNorm. There's no fundamental reason for that; the kernels just loop over the axes in a less-than-ideal order. Dynamic recompilation lets you change the loop order too, not just deal with boundary conditions.
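As a rough illustration of why loop order matters, here is a toy reduction in plain Python. It is not how PyTorch's GroupNorm or any real kernel is written; `column_sums` and `pick_order` are hypothetical helpers made up for this sketch. It sums a row-major buffer over axis 0 with two loop orders and counts how often consecutive accesses are non-contiguous. A compiler that knows the shape at compile time could pick the friendlier order per call.

```python
def column_sums(buf, rows, cols, order):
    """Sum a row-major (rows x cols) buffer over axis 0.

    Returns (sums, jumps), where jumps counts consecutive accesses that do
    not land on the next address, a crude stand-in for cache-unfriendliness.
    """
    if order == "rows_outer":
        indices = (r * cols + c for r in range(rows) for c in range(cols))
    else:  # "cols_outer"
        indices = (r * cols + c for c in range(cols) for r in range(rows))

    sums = [0.0] * cols
    jumps = 0
    last = None
    for i in indices:
        if last is not None and i != last + 1:
            jumps += 1
        sums[i % cols] += buf[i]
        last = i
    return sums, jumps


def pick_order(buf, rows, cols):
    """Hypothetical shape specializer: try both orders, keep the one with fewer jumps."""
    return min(("rows_outer", "cols_outer"),
               key=lambda o: column_sums(buf, rows, cols, o)[1])


rows, cols = 4, 3
buf = [float(i) for i in range(rows * cols)]

sums_a, jumps_a = column_sums(buf, rows, cols, "rows_outer")
sums_b, jumps_b = column_sums(buf, rows, cols, "cols_outer")
best = pick_order(buf, rows, cols)
```

The two orders compute identical sums, but the second one strides through memory on every step. Real kernels add vectorization, tiling, and parallelism on top of this, so the sketch only shows the access pattern, not the actual slowdown.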