Wait a second, they define the induced vector field (and consequently the Lie bracket) in terms of batch-size-1 SGD:

> In particular, if x is a training example and L(x) is the per-example loss for the training example x, then this vector field is: v^(x)(θ) = -∇_θ L(x). In other words, for a specific training example, the arrows of the resulting vector field point in the direction that the parameters should be updated.

But for the MXResNet example:

> The optimizer is Adam, with the following parameters: lr = 5e-3, betas = (0.8, 0.999)

This changes the direction of the updates, so I'm not completely sure the intuitive equivalence holds. If it were just SGD with momentum, the measured update directions would be a combination of the momentum vector M and v1 or v2, so [M + v1, M + v2] = [v1, M] + [M, v2] + [v1, v2] (by bilinearity, since [M, M] = 0). The Lie bracket is then no longer "just" a function of the model parameters and the training examples; it becomes inherently path dependent, because M depends on the optimization history. For Adam, the per-parameter normalization by the running second-moment estimate also changes the update directions, nonlinearly and in a history-dependent way (that estimate is an exponential moving average with decay β2). The interpretation is strained further by fancier optimizers like Muon, which combines momentum with an (approximate) SVD-based orthogonalization of the update, so I'm really not sure what to expect.
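To make the Adam point concrete, here is a minimal pure-Python sketch of the standard bias-corrected Adam update, using the lr and betas quoted above. The eps value and the gradients are assumptions (toy numbers, not taken from MXResNet). It shows that the first step is roughly lr·sign(−g) per coordinate rather than a multiple of −g, and that the same current gradient yields different update directions depending on the gradient history.

```python
import math

def adam_updates(grads, lr=5e-3, beta1=0.8, beta2=0.999, eps=1e-8):
    """Return the parameter update produced at each step of bias-corrected Adam."""
    n = len(grads[0])
    m = [0.0] * n  # first-moment (momentum) estimate
    v = [0.0] * n  # second-moment estimate, used for per-parameter normalization
    updates = []
    for t, g in enumerate(grads, start=1):
        m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
        v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
        bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t  # bias corrections
        updates.append([-lr * (mi / bc1) / (math.sqrt(vi / bc2) + eps)
                        for mi, vi in zip(m, v)])
    return updates

def cosine(a, b):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

# First step: the update is about lr * sign(-g), not a multiple of -g.
g = [3.0, -0.5]
first = adam_updates([g])[0]
print(first, cosine([-x for x in g], first))

# Same current gradient [1, 1], different histories -> different update directions.
last_a = adam_updates([[10.0, 0.1], [1.0, 1.0]])[-1]
last_b = adam_updates([[0.1, 10.0], [1.0, 1.0]])[-1]
print(last_a, last_b)
```

With these toy numbers the first Adam step has a cosine of only about 0.81 with the plain SGD direction −g, and the two histories give mirror-image updates for an identical final gradient.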
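If we have built up a momentum M (plain heavy-ball momentum with no decay, step size ε), then the two orderings are as follows, where v1 and v2 without an argument are evaluated at θ and each ≈ is a first-order Taylor expansion around θ.
Applying v1 first, then v2:
M' = M + εv1
θ' = θ + M' = θ + M + εv1
M'' = M' + εv2(θ') ≈ M + εv1 + ε(v2 + (M + εv1)⋅∇v2)
Applying v2 first, then v1:
M' = M + εv2
θ' = θ + M' = θ + M + εv2
M'' = M' + εv1(θ') ≈ M + εv2 + ε(v1 + (M + εv2)⋅∇v1)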
Subtracting the second ordering from the first, the difference in M'' is:
ε^2 [v1, v2] + ε(M⋅∇)(v2 − v1)
where [v1, v2] = (v1⋅∇)v2 − (v2⋅∇)v1.
So there is an extra term, which is not itself a Lie bracket. Since it is first order in ε (scaled by M), it need not be small compared with the ε^2 bracket term. I think the bracket can still be informative on its own, but it is definitely no longer the only thing that changes when the order is swapped.
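As a sanity check on the algebra, here is a small pure-Python sketch that runs both orderings exactly (no Taylor expansion) on two made-up smooth 2-D vector fields with hand-computed Jacobians, and compares the difference in M'' with ε^2 [v1, v2] + ε(M⋅∇)(v2 − v1). It assumes the convention [v1, v2] = (v1⋅∇)v2 − (v2⋅∇)v1 and takes M to be about one step's worth of update (M = ε·m); the fields, θ, and m are hypothetical.

```python
import math

# Hypothetical smooth 2-D vector fields (not from the post) and their Jacobians,
# with J[i][j] = d v_i / d p_j, so that (u . grad) v = J u.
def v1(p):
    x, y = p
    return (math.sin(y), x * x)

def J1(p):
    x, y = p
    return ((0.0, math.cos(y)), (2 * x, 0.0))

def v2(p):
    x, y = p
    return (x * y, math.cos(x))

def J2(p):
    x, y = p
    return ((y, x), (-math.sin(x), 0.0))

def add(a, b): return tuple(ai + bi for ai, bi in zip(a, b))
def sub(a, b): return tuple(ai - bi for ai, bi in zip(a, b))
def scale(c, a): return tuple(c * ai for ai in a)
def matvec(J, u): return tuple(sum(jk * uk for jk, uk in zip(row, u)) for row in J)

def two_steps(theta, M, eps, first, second):
    """Two heavy-ball steps (no decay): M <- M + eps*v(theta); theta <- theta + M. Returns M''."""
    M1 = add(M, scale(eps, first(theta)))
    theta1 = add(theta, M1)
    return add(M1, scale(eps, second(theta1)))

def predicted(theta, M, eps):
    """eps^2 [v1, v2] + eps (M . grad)(v2 - v1), all evaluated at theta."""
    bracket = sub(matvec(J2(theta), v1(theta)), matvec(J1(theta), v2(theta)))
    drift = sub(matvec(J2(theta), M), matvec(J1(theta), M))
    return add(scale(eps ** 2, bracket), scale(eps, drift))

def rel_error(eps, theta=(0.3, 0.7), m=(1.0, -0.5)):
    M = scale(eps, m)  # assumption: momentum is on the order of a single step
    actual = sub(two_steps(theta, M, eps, v1, v2), two_steps(theta, M, eps, v2, v1))
    pred = predicted(theta, M, eps)
    return math.hypot(*sub(actual, pred)) / math.hypot(*pred)

for eps in (1e-2, 1e-3, 1e-4):
    print(eps, rel_error(eps))
```

With M = ε·m both predicted terms are O(ε^2) while the neglected second-order Taylor terms are O(ε^3), so the relative error should shrink roughly in proportion to ε.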
One other, less serious, inconsistency is BatchNorm. In train mode it normalizes with statistics computed over the whole batch, and we are comparing individual examples, so I computed the Lie brackets with the BatchNorm layers in eval mode, not train mode.
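A minimal sketch of why per-example brackets need eval mode, for a single scalar feature and ignoring BatchNorm's learned scale and shift: train-mode normalization uses the current batch's mean and variance, so the same input gets a different output depending on its batch companions, while eval mode uses fixed running statistics and is a deterministic per-example function. The running mean and variance below are arbitrary illustrative values.

```python
import math

def bn_train(batch, eps=1e-5):
    """Train-mode BatchNorm on one scalar feature: normalize with the batch's own statistics."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)  # biased batch variance
    return [(x - mean) / math.sqrt(var + eps) for x in batch]

def bn_eval(batch, running_mean, running_var, eps=1e-5):
    """Eval-mode BatchNorm on one scalar feature: normalize with fixed running statistics."""
    return [(x - running_mean) / math.sqrt(running_var + eps) for x in batch]

# The same example x = 2.0 placed in two different batches.
print(bn_train([2.0, 0.0])[0], bn_train([2.0, 4.0])[0])  # about +1.0 vs -1.0
print(bn_eval([2.0, 0.0], 0.5, 2.0)[0], bn_eval([2.0, 4.0], 0.5, 2.0)[0])  # identical
```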
I don't know whether any of this is relevant to Muon; even if it is, it would likely be very messy to compute.