Yeah, this is a good point. IIRC, I wasn't able to get the network to train very well at all with standard SGD. I don't think I thought to try Adam with β1 = 0; I will try it (& recompute brackets) if I get some time.

If we have built up a momentum M, then the two orderings are as follows (v1, v2 and their gradients are evaluated at θ, and the last step of each expands to first order).

v1 first, then v2:

M' = M + εv1

θ' = θ + M' = θ + M + εv1

M'' = M' + εv2(θ') = M + εv1 + ε(v2 + (M + εv1)⋅∇v2)

v2 first, then v1:

M' = M + εv2

θ' = θ + M' = θ + M + εv2

M'' = M' + εv1(θ') = M + εv2 + ε(v1 + (M + εv2)⋅∇v1)

Then the resulting difference in momenta M'' (first ordering minus second) is:

ε^2 [v1, v2] + ε(M⋅∇)(v2 - v1)

So there is an extra term which is not actually a Lie bracket itself. I think the bracket can still be informative on its own, but it's definitely no longer the sole component of what happens when order is swapped.

One other inconsistency that is a little less bad is BatchNorm. Since it needs a whole batch to work, and we're just comparing individual examples, I computed the Lie brackets with the BatchNorm layers in eval mode, not train mode.

I don't know if this has any relevance to Muon; even if it does, it would likely be very messy to compute.
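For anyone who wants to sanity-check the momentum algebra above, here is a minimal sketch in plain Python. It is not the code used for the actual bracket computations. Everything in it is an assumption made for illustration: the fields are affine toy fields v_i(θ) = A_i θ + b_i (so the first-order expansion is exact), the matrices, θ, M and ε are arbitrary, there is no momentum decay (matching the update written above), and the bracket convention is [v1, v2] = (v1⋅∇)v2 - (v2⋅∇)v1, which is the one the derivation implies.

```python
# Toy check of: M''(v1 then v2) - M''(v2 then v1)
#             = eps^2 [v1, v2] + eps (M . grad)(v2 - v1)
# All names and numbers here are hypothetical, chosen only for illustration.

def matvec(A, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]

def add(x, y):
    return [a + b for a, b in zip(x, y)]

def sub(x, y):
    return [a - b for a, b in zip(x, y)]

def scale(c, x):
    return [c * a for a in x]

def make_field(A, b):
    """Affine vector field v(theta) = A theta + b; its Jacobian is A."""
    return lambda theta: add(matvec(A, theta), b)

def two_steps(theta, M, first, second, eps):
    """Apply `first` then `second` using the momentum update from the comment."""
    M1 = add(M, scale(eps, first(theta)))        # M'  = M + eps * first(theta)
    theta1 = add(theta, M1)                      # theta' = theta + M'
    return add(M1, scale(eps, second(theta1)))   # M'' = M' + eps * second(theta')

A1, b1 = [[0.0, 1.0], [-1.0, 0.0]], [0.5, 0.0]
A2, b2 = [[2.0, 0.0], [0.0, -1.0]], [0.0, 0.3]
v1, v2 = make_field(A1, b1), make_field(A2, b2)

theta, M, eps = [1.0, 2.0], [0.1, -0.2], 0.01

# Actual difference between the two orderings.
diff = sub(two_steps(theta, M, v1, v2, eps), two_steps(theta, M, v2, v1, eps))

# Predicted difference. For affine fields, (u . grad) v_i = A_i u.
bracket = sub(matvec(A2, v1(theta)), matvec(A1, v2(theta)))  # [v1, v2] at theta
extra = sub(matvec(A2, M), matvec(A1, M))                    # (M . grad)(v2 - v1)
predicted = add(scale(eps ** 2, bracket), scale(eps, extra))

print("actual   :", diff)
print("predicted:", predicted)
```

With M set to zero the extra term drops out and only the ε^2 bracket term is left, which is the no-momentum case. For non-affine fields the two sides should agree only up to higher-order terms.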
However, what's harder to interpret is how this field transports with respect to θ, since the momentum vector and θ are themselves inextricably linked. If you somehow arrived at a different θ, then you'd have a different momentum. (On the gripping hand, the bracket is a construct of infinitesimals, so maybe that doesn't matter.)