by augment_me
31 days ago
TLDR: The authors realize that functions depending on a global per-row statistic, like RMSNorm/LayerNorm, apply a per-row scale that in certain setups commutes with the following linear projection. That scale can therefore be moved out to after the projection, and the statistic behind it can be partially aggregated over tiles of each row. So ((W1 @ gamma) * globally_computed_scale) @ W2 can be written as ((W1 @ gamma) @ W2) * globally_computed_scale, as long as the scale only involves interactions within its own row (one scalar per row). This usually wasn't done before because graph compilers that fuse left to right, like torch.compile, can't assume that a global row-wise reduction sitting between two GEMMs commutes with the second GEMM.
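A minimal NumPy sketch of the reordering, assuming plain RMSNorm (no mean subtraction) with a per-column gamma, where x stands in for whatever feeds the norm (e.g. the output of W1). The function names and the K-tile loop are hypothetical illustrations, not the authors' kernel: the "fused" version runs the GEMM on un-normalized rows and accumulates each row's sum of squares from the same K tiles, then applies the per-row 1/rms scale at the end.

```python
import numpy as np

def rmsnorm_then_project(x, gamma, w2, eps=1e-6):
    # Reference order: normalize each row, apply gamma, then project with w2.
    rms = np.sqrt(np.mean(x * x, axis=1, keepdims=True) + eps)
    return ((x / rms) * gamma) @ w2

def fused_project_then_scale(x, gamma, w2, eps=1e-6, k_tile=64):
    # Hypothetical reordered version: GEMM first, per-row scale last.
    n, d = x.shape
    y = np.zeros((n, w2.shape[1]))
    sumsq = np.zeros((n, 1))
    for k0 in range(0, d, k_tile):
        xk = x[:, k0:k0 + k_tile]
        # Partial GEMM over this K tile; gamma is per column, so it is
        # applied to the tile before the matmul.
        y += (xk * gamma[k0:k0 + k_tile]) @ w2[k0:k0 + k_tile, :]
        # Partial per-row sum of squares from the same tile.
        sumsq += np.sum(xk * xk, axis=1, keepdims=True)
    inv_rms = 1.0 / np.sqrt(sumsq / d + eps)
    # One scalar per row commutes with right-multiplication by w2.
    return y * inv_rms

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 200))
gamma = rng.standard_normal(200)
w2 = rng.standard_normal((200, 16))
print(np.allclose(rmsnorm_then_project(x, gamma, w2),
                  fused_project_then_scale(x, gamma, w2)))  # True
```

The two orders agree up to floating-point rounding because the per-row scale is a diagonal matrix on the left, so diag(s) @ X @ W2 can be grouped either way. LayerNorm's mean subtraction would need extra handling that this sketch does not cover.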