by brrrrrm 1327 days ago
> It compiles a custom kernel for every operation, allowing extreme shape specialization.

This doesn't matter. Just look at the performance achieved by cuDNN kernels (which back PyTorch): they're dynamically shaped and hit near peak. For dense linear algebra at the size of modern neural networks, optimizing for the loop bound condition won't help much.

> All tensors are lazy, so it can aggressively fuse operations.

This matters. PyTorch teams are trying to implement that now (they have LazyTensor, AITemplate, TorchDynamo), but I'm not sure of the status (it's been tried repeatedly).

> The backend is 10x+ simpler, meaning optimizing one kernel makes everything fast.

The first part of that sentence matters, the second part doesn't. Kernels are already fast and their reuse outside of being fused into each other (which you need a full linear algebra compiler to do) isn't very high. If you make sum fast, you have not made matrix multiplication fast even though MM has a sum in it. It just isn't that easy to compose operations and still hit 80+% of hardware efficiency.
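
To make the sum-vs-matmul point concrete, here is a minimal pure-Python sketch (the function names are made up for illustration, and this is nothing like a real kernel). Composing matmul out of a separate multiply and a separate sum forces an n*m*k intermediate into memory, so a fast standalone sum does not help unless the two ops are fused into one loop nest.

```python
def matmul_unfused(A, B):
    """Matmul as two separate ops: an elementwise multiply, then a sum."""
    n, k, m = len(A), len(B), len(B[0])
    # Op 1: materialize every product A[i][p] * B[p][j] (n * m * k values).
    prod = [[[A[i][p] * B[p][j] for p in range(k)]
             for j in range(m)] for i in range(n)]
    # Op 2: a standalone sum over the last axis of that intermediate.
    out = [[sum(prod[i][j]) for j in range(m)] for i in range(n)]
    return out, n * m * k  # result, size of the temporary


def matmul_fused(A, B):
    """Same math with the multiply fused into the reduction loop."""
    n, k, m = len(A), len(B), len(B[0])
    out = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0  # accumulator stays local, no temporary tensor
            for p in range(k):
                acc += A[i][p] * B[p][j]
            out[i][j] = acc
    return out, 0


A = [[1, 2, 3], [4, 5, 6]]
B = [[7, 8], [9, 10], [11, 12]]
unfused_out, temp_elems = matmul_unfused(A, B)
fused_out, _ = matmul_fused(A, B)
print(unfused_out == fused_out, temp_elems)
```

Even the fused version here is far from an efficient matmul (no tiling, no vectorization), which is the other half of the argument.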

But it is easier to iterate fast and build a seamless lazy compiler if your backend is simple. You can pattern match more easily and ensure you handle edge cases without insanely complicated things like alias analysis (which PyTorch has to do).


> they're dynamically shaped and hit near peak

While this is true for most common GEMM-looking ops, if you tread off the beaten path things get slow (odd channel sizes, batch sizes, etc.). Right now in PyTorch, GroupNorm is 2x slower than BatchNorm. There's no fundamental reason, just that the kernels loop over axes in a less than ideal order. Dynamic recompilation allows you to change the loop order too, not just deal with boundary conditions.

> tread off the beaten path things get slow

Yea, makes sense. I think there's something to be said for dynamic compilation solving this problem more elegantly than providing tons of hand-tuned kernels (PyTorch is 890MB lmao https://pypi.org/project/torch/#files), but I don't think it's a strict reason for a performance win.

> change the loop order too

Memory layout as well! I'm 100% for dynamic compilation, but I'm claiming that it really finds its stride when you fuse things.

Agreed. For anything at all common, most of the gains will be from fusion, the rest is just free. PyTorch also uses tons of GPU memory after only initializing, I wonder if it's copying all the kernels in?
JAX preallocates 90% of available GPU memory when the first operation is run, to minimize allocation overhead. Could PyTorch be grabbing that VRAM for a similar reason?
Yes, PyTorch uses what they call a caching memory allocator[0]. It basically seems like they are allocating a very large chunk of GPU memory and implementing a heap with it. If needed, they expose some knobs and functions that let you control it and observe the memory usage.

[0]: https://pytorch.org/docs/stable/notes/cuda.html#memory-manag...
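
As a rough illustration of the caching-allocator idea, here is a toy model in pure Python. It is not PyTorch's implementation; `ToyCachingAllocator` and its fields are hypothetical, and "device memory" is just a byte counter. The two counters loosely mirror what `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` report, and `empty_cache` plays the role of `torch.cuda.empty_cache()`.

```python
class ToyCachingAllocator:
    """Toy model: freed blocks are cached for reuse instead of being
    returned to the (simulated) device right away."""

    def __init__(self):
        self.free_blocks = {}  # block size -> number of cached free blocks
        self.reserved = 0      # bytes obtained from the device so far
        self.allocated = 0     # bytes currently handed out to the user

    def malloc(self, size):
        if self.free_blocks.get(size, 0) > 0:
            self.free_blocks[size] -= 1  # cache hit: no device call
        else:
            self.reserved += size        # cache miss: "expensive" device alloc
        self.allocated += size
        return size  # the toy "pointer" is just the block size

    def free(self, size):
        # Keep the block in the cache; reserved memory does not shrink.
        self.allocated -= size
        self.free_blocks[size] = self.free_blocks.get(size, 0) + 1

    def empty_cache(self):
        # Hand all cached (unused) blocks back to the device.
        cached = sum(size * count for size, count in self.free_blocks.items())
        self.reserved -= cached
        self.free_blocks.clear()


alloc = ToyCachingAllocator()
block = alloc.malloc(1024)
alloc.free(block)
print(alloc.allocated, alloc.reserved)  # memory looks "used" after a free
```

This is why a framework can appear to hold a lot of VRAM that no live tensor is using: the gap between reserved and allocated is the cache.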

> Right now in PyTorch, GroupNorm is 2x slower than BatchNorm

How did you benchmark this? I think there are like 3 or 4 different GN implementations in PyTorch.

Whole-net performance at comma: when we switch from BatchNorm to GroupNorm it adds 70ms to the training step time, and it's -70ms for no norm. We also wrote a custom AllNorm that's like 10% slower than BatchNorm (and I put several hours into trying to optimize it). Obviously not indicative of everyone's experience, but my point is BatchNorm is hyperoptimized and others, which are pretty much the same thing, aren't.
Thanks, that's certainly helpful anecdotal evidence. Yeah, it seems like there should be an "AllNorm" implementation that covers all cases and is just fast. I was wondering because I'm currently looking at math_group_norm, which was ported from PyTorch/XLA, and it results in a really weird decomposition that I'm astonished works at all. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen...

I'm also wondering if the handcoded backward passes are actually "numerically correct", because e.g. epsilon doesn't appear in it at all. Someone worked out the gradients manually for BN here: https://web.archive.org/web/20180826123459/http://cthorey.gi...

You can clearly see epsilon appearing in the output. And of course there's the whole training vs. eval mode thing with BN which GN doesn't have.

In any case, thanks again.
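
On the epsilon question, one possibility worth checking (a sketch under assumptions, not a reading of PyTorch's actual kernels): if the forward pass saves the reciprocal standard deviation `rstd = 1/sqrt(var + eps)`, the backward pass can be written purely in terms of `rstd` and never mention epsilon by name while still being exact. The minimal pure-Python version below assumes training-mode BN over a 1-D batch with scalar `gamma` and `beta`; all function names are made up for illustration.

```python
import math


def batchnorm_forward(x, gamma, beta, eps):
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n  # biased variance, as in training
    rstd = 1.0 / math.sqrt(var + eps)        # eps enters here, and only here
    xhat = [(v - mu) * rstd for v in x]
    y = [gamma * h + beta for h in xhat]
    return y, xhat, rstd


def batchnorm_backward(dy, xhat, rstd, gamma):
    """Gradient w.r.t. x given dy = dL/dy. Note: no eps argument at all."""
    n = len(dy)
    s1 = sum(dy)
    s2 = sum(g * h for g, h in zip(dy, xhat))
    return [gamma * rstd / n * (n * g - s1 - h * s2)
            for g, h in zip(dy, xhat)]


def numeric_grad(x, w, gamma, beta, eps, h=1e-6):
    """Central finite differences of L = sum(w_i * y_i) w.r.t. each x_k."""
    def loss(xs):
        y, _, _ = batchnorm_forward(xs, gamma, beta, eps)
        return sum(wi * yi for wi, yi in zip(w, y))
    grads = []
    for k in range(len(x)):
        xp, xm = list(x), list(x)
        xp[k] += h
        xm[k] -= h
        grads.append((loss(xp) - loss(xm)) / (2 * h))
    return grads


x = [0.5, -1.2, 3.0, 0.7]
w = [0.3, -0.8, 1.1, 0.4]  # dL/dy for the loss L = sum(w_i * y_i)
gamma, beta, eps = 1.5, 0.1, 0.1  # deliberately large eps so it matters
_, xhat, rstd = batchnorm_forward(x, gamma, beta, eps)
analytic = batchnorm_backward(w, xhat, rstd, gamma)
numeric = numeric_grad(x, w, gamma, beta, eps)
```

Comparing `analytic` against `numeric` is a cheap way to test any handwritten backward pass, including ones where epsilon is only implicit.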

What does it mean to "fuse operations"?
avoiding writes to memory and reducing the number of loops (although not FLOPs)

    for j in range(10):
      c[j] = a[j] + b[j]
    for j in range(10):
      d[j] = c[j] * 2
becomes

    for j in range(10):
      d[j] = (a[j] + b[j]) * 2
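
A runnable version of the same idea, as a small sketch that also counts memory writes and arithmetic ops (the counters and function names are illustrative only). Fusion halves the writes here but leaves the op count unchanged.

```python
def unfused(a, b):
    n = len(a)
    c, d = [0] * n, [0] * n  # c is an intermediate buffer
    writes = flops = 0
    for j in range(n):
        c[j] = a[j] + b[j]
        writes += 1
        flops += 1
    for j in range(n):
        d[j] = c[j] * 2
        writes += 1
        flops += 1
    return d, writes, flops


def fused(a, b):
    n = len(a)
    d = [0] * n  # no intermediate buffer needed
    writes = flops = 0
    for j in range(n):
        d[j] = (a[j] + b[j]) * 2
        writes += 1
        flops += 2  # one add and one multiply per element
    return d, writes, flops


a = list(range(10))
b = [10 * j for j in range(10)]
d1, writes1, flops1 = unfused(a, b)
d2, writes2, flops2 = fused(a, b)
print(d1 == d2, (writes1, flops1), (writes2, flops2))
```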
Or, better, identifying that the machine has a primitive that is better than doing each op individually. For example, a multiply-accumulate instruction vs a multiply and separate accumulate. The source code still says "a*b+c", the compiler is just expected to infer the MAC instruction.
Yep! This is an assumed optimization when it comes to modern linear algebra compilers. New primitives go way beyond FMAs: full matrix multiplies on nvidia/Intel and outer product accumulates on Apple silicon. It’s also expected that these are used nearly optimally (or you’ve got a bug).
I am extremely familiar with how far these primitives go, ha. I develop kernels professionally for AWS ML accelerators.
Any more writing on laziness in frameworks? I'm trying to implement it myself.
The only thing I'd recommend is exposing "eval()" or something to let users tell you when they want you to evaluate things. It'll save a ton of time when it comes to hot-fixing performance and memory use issues. It's really hard to determine when to evaluate, and although it's a fun problem to figure out, it's nice to have an escape hatch for users to just tell you. (Flashlight has explored this and written about it here: https://fl.readthedocs.io/en/latest/debugging.html?highlight...)

If you're interested, I've looked into symbolic laziness, which allows you to infer correct input sizes even when the constraints happen later. Can be useful for errors. https://dev-discuss.pytorch.org/t/loop-tools-lazy-frontend-e...
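
For the "expose eval()" suggestion, a minimal sketch of what that escape hatch might look like in a toy lazy frontend. The `Lazy` class is hypothetical and works on plain Python lists; a real framework would build kernels from the recorded graph rather than interpret it.

```python
class Lazy:
    """Toy lazy value: operators only record a graph, eval() computes it."""

    def __init__(self, op, inputs=(), value=None):
        self.op = op
        self.inputs = inputs
        self.value = value  # None until evaluated

    @staticmethod
    def const(data):
        return Lazy("const", value=list(data))

    def __add__(self, other):
        return Lazy("add", (self, other))  # nothing is computed here

    def __mul__(self, other):
        return Lazy("mul", (self, other))

    def eval(self):
        """User-controlled evaluation point."""
        if self.value is None:
            x, y = (t.eval() for t in self.inputs)
            if self.op == "add":
                self.value = [p + q for p, q in zip(x, y)]
            elif self.op == "mul":
                self.value = [p * q for p, q in zip(x, y)]
            else:
                raise ValueError(f"unknown op {self.op!r}")
            self.inputs = ()  # drop the graph so inputs can be freed
        return self.value


a = Lazy.const([1, 2, 3])
b = Lazy.const([4, 5, 6])
d = (a + b) * Lazy.const([2, 2, 2])
still_lazy = d.value is None  # True: only the graph exists so far
result = d.eval()             # the user decides when work happens
```

Caching the result and dropping `inputs` after evaluation is one simple design choice; it gives users a way to bound graph size and memory by calling `eval()` at points they choose.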