For those of us using the 2D NATTEN kernel from their library along with torch.compile, is this faster, especially given all their tricks (e.g., the non-deterministic KV-parallelism)?
In my (very amateurish) testing, the performance seemed pretty comparable (for non-dilated NATTEN). I need to do some proper benchmarking, though!
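For anyone who wants to try the proper benchmarking themselves, here is a rough sketch of a timing harness, not a definitive methodology. The names `bench` and `compare` are made up for illustration. The `sync` hook is an assumption about how you would use it: on a GPU you would pass `torch.cuda.synchronize` so that asynchronously queued kernels are counted, and the warmup iterations are there so that compilation time (e.g. from torch.compile) is not included in the measurement.

```python
import statistics
import time


def bench(fn, *, warmup=3, iters=20, sync=None):
    """Return the median wall-clock time of fn() in milliseconds.

    sync is an optional zero-argument callable run before each clock read.
    On GPU, pass torch.cuda.synchronize so queued kernels are included.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):  # warmup absorbs one-time compile/autotune cost
        fn()
    sync()
    times = []
    for _ in range(iters):
        sync()  # make sure nothing from the previous run is still pending
        start = time.perf_counter()
        fn()
        sync()  # wait for the work launched by fn() to finish
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)


def compare(candidates, **kwargs):
    """Time every named callable and return {name: median_ms}."""
    return {name: bench(fn, **kwargs) for name, fn in candidates.items()}
```

You would call it with something like `compare({"natten": run_natten, "other": run_other}, sync=torch.cuda.synchronize)`, where `run_natten` and `run_other` are placeholder zero-argument closures over the same inputs, one per implementation. The median is used rather than the mean so that a few slow outlier runs do not skew the comparison.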