I’ll do my best to answer here.

> Do you expect instability between successive macrobatch gradients? That is, why are you comparing microgradients within a single batch, adding a whole bunch of serialization headaches, rather than comparing with the macrogradient of the previous step?

>> I do. If you take a sufficiently large step, the path of steepest descent will surely change sometimes; if it doesn’t, you should just double or triple your step size. So when you compare against the previous step, you can’t tell whether a high cosine distance comes from a change in the curvature of the loss surface or from gradient variance. Most large runs already split gradient computation across nodes, so if you are already doing that, just apply GAF instead of averaging.

> Given your test setup of noisy labels, isn't the sequential accumulation of microgradients dangerous? Suppose we take this to the limit of a minibatch size of 1, and the first microgradient happens to come from an example with an incorrect label. If I understand this correctly, gradient filtering would seek to add all the gradients that are consistent with the bad example, rejecting the gradients that belong to good examples.

>> Yes, but “consistent with the bad example” is nearly impossible. In late stages of training, even without noisy labels, the gradient directions are already orthogonal or worse. If you flip the labels of any two samples and use a minibatch of size 1, their gradients will be negatively correlated with each other, so with GAF you will practically always skip them. In standard SGD, however, you will ALWAYS average them in until you’ve completely memorized the noisy samples.

> The filtered gradients are used via SGD with momentum (although equation (6) looks like momentum-free SGD). Have you seen / do you expect different results when supplying the filtered gradients to Adam/AdamW or other, more sophisticated optimizers?

>> No, it won’t change the results. I have been using GAF to train LLMs for my next paper and it holds. In fact, in much more expressive and larger models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker receives noisy distance/velocity measurements at some frequency, and it is only as good as the signal coming into it. If a bird flies in front of the radar, you will inject a huge OOD signal into the estimate unless you do some sort of validation gating (e.g. ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.

> Your thresholding is binary, either accepting or rejecting an entire microgradient. Have you tested soft thresholding? Is there an information-theoretic way to explain this effect?

>> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated, then averaging them will result in overfitting, but I couldn’t get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn’t expect taking any part of a memorization direction to be a good idea.

> In figure 7, why does GAF with a large threshold result in a lower validation accuracy than the baseline? In GAF terms, the baseline accepts every microgradient, so I'd expect GAF to converge to the baseline result as the threshold increases. What does the figure-7-but-0%-error curve look like?

>> Good callout. Yes, that wasn’t intuitive to me either. You are correct that when τ hits 2 it converges to the baseline, as expected. But at 1.05 it actually does worse than the baseline in the presence of 5% noise. So as you increase τ above 1 (which I never recommend doing), it starts to underperform the baseline in the presence of noise, and by 2 it matches. But in practice I have found τ in [0.92, 0.999] to be the sweet spot. I wouldn’t go outside that range.
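For readers following along, here is a minimal sketch of the sequential filtering discussed above. It assumes flattened microgradients as NumPy arrays, cosine distance defined as 1 - cosine similarity (0 = same direction, 2 = opposite), and that each incoming microgradient is compared against the running average of those accepted so far. The function names, and the rule of skipping the step when nothing agrees with the first microgradient, are my own assumptions for illustration, not the paper's reference implementation.

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """1 - cosine similarity: 0 = same direction, 1 = orthogonal, 2 = opposite."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 1.0  # a zero vector carries no direction; treat it as orthogonal
    return 1.0 - float(np.dot(a, b) / denom)


def gaf_accumulate(micro_grads, tau=0.97):
    """Sequentially fuse flattened microgradients (hypothetical helper).

    Each new microgradient is compared with the running average of the ones
    accepted so far and is kept only if its cosine distance is <= tau.
    Returns (fused_gradient, n_accepted); fused_gradient is None when no
    microgradient agreed with the first one, meaning "skip this step".
    """
    fused = np.array(micro_grads[0], dtype=float)  # copy; inputs stay untouched
    n_accepted = 1
    for g in micro_grads[1:]:
        if cosine_distance(fused, g) <= tau:
            n_accepted += 1
            fused += (g - fused) / n_accepted  # incremental mean of accepted grads
    if n_accepted < 2:
        return None, n_accepted
    return fused, n_accepted
```

With tau=0.97, two exactly opposite microgradients (cosine distance 2, like the flipped-label case above) give (None, 1), so the step is skipped, whereas plain averaging would return the zero vector and still count the step.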
I agree for plain SGD, but optimizers with momentum rely on some consistency of gradients between optimizer steps. On the other hand, momentum-based optimizers also try to maximize the effective learning rate subject to that momentum, so it could go the other way as well.
> No, it won’t change the results. I have been using GAF to train LLMs for my next paper and it holds. In fact, in much more expressive and larger models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker receives noisy distance/velocity measurements at some frequency, and it is only as good as the signal coming into it. If a bird flies in front of the radar, you will inject a huge OOD signal into the estimate unless you do some sort of validation gating (e.g. ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.
> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated, then averaging them will result in overfitting, but I couldn’t get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn’t expect taking any part of a memorization direction to be a good idea.
Maybe the answer lies in asking what's optimized by averaging.
For learning, we're interested in the intractable problem of computing the gradient of the loss with respect to the parameters, taken over the whole data distribution. In practice, we only compute that gradient on samples drawn from the data distribution, which gives stochastic gradient descent. SGD with momentum makes the additional assumption that the steepest-descent path has relatively low curvature, so a running average of previous batches' gradients is still informative.
Overall, this approach is optimal if you imagine that the computed sample gradients are the true gradient corrupted with independent, mean-zero Gaussian noise: averaging over many samples is then the best way to suppress that noise.
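To make "best" concrete under exactly that model: assuming the n sample gradients g_i are the true gradient plus i.i.d. noise N(0, σ²I) (the symbols g_i, σ, n are introduced here for illustration), the maximum-likelihood estimate of the true gradient is the sample mean, and its variance shrinks as 1/n.

```latex
\hat{g}_{\mathrm{ML}}
  = \arg\max_{\mu} \sum_{i=1}^{n} \log \mathcal{N}\!\left(g_i;\, \mu,\, \sigma^2 I\right)
  = \arg\min_{\mu} \sum_{i=1}^{n} \lVert g_i - \mu \rVert_2^2
  = \frac{1}{n} \sum_{i=1}^{n} g_i,
\qquad
\operatorname{Cov}\!\left(\hat{g}_{\mathrm{ML}}\right) = \frac{\sigma^2}{n} I .
```

The flip side is that replacing a single g_i with an arbitrary vector g_bad moves this mean by (g_bad - g_i)/n, which is unbounded.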
Your work identifies and rejects outlier gradients. In a very hand-wavy way, this is kind of like a median filter, and a median filter is great at rejecting shot noise. I speculate that this is why your technique is particularly good for your examples with corrupted labels, since that corruption process replaces single samples with something completely uninformative.
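A quick numerical illustration of the median-filter intuition, using a coordinate-wise median as a stand-in for outlier rejection. This is not GAF, and the true gradient, noise scale, sample count, and corruption scale are arbitrary assumptions: with only Gaussian noise both estimators land close to the truth, but replacing a few samples with large, uninformative vectors wrecks the mean while the median barely moves.

```python
import numpy as np

TRUE_GRAD = np.array([1.0, -2.0, 0.5])  # assumed "true" gradient for the demo


def sample_grads(rng, n=32, n_corrupt=0, noise=0.3, shot_scale=50.0):
    """n per-sample gradients: TRUE_GRAD plus mean-zero Gaussian noise, with the
    first n_corrupt rows replaced by large random vectors (shot noise)."""
    grads = TRUE_GRAD + noise * rng.standard_normal((n, TRUE_GRAD.size))
    grads[:n_corrupt] = shot_scale * rng.standard_normal((n_corrupt, TRUE_GRAD.size))
    return grads


def estimate_errors(grads):
    """L2 error of the mean and of the coordinate-wise median vs TRUE_GRAD."""
    mean_err = np.linalg.norm(grads.mean(axis=0) - TRUE_GRAD)
    median_err = np.linalg.norm(np.median(grads, axis=0) - TRUE_GRAD)
    return float(mean_err), float(median_err)


rng = np.random.default_rng(0)
for k in (0, 3):
    mean_err, median_err = estimate_errors(sample_grads(rng, n_corrupt=k))
    print(f"corrupted={k}: mean error {mean_err:.3f}, median error {median_err:.3f}")
```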
This is why I also wonder about soft thresholding. A soft threshold could be interpreted as an intermediate rather than binary belief about whether a sample is informative, or it could be interpreted as belief that a sample has more than zero but less than the typical amount of information.
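One way to make the soft-threshold idea concrete, purely as a hypothetical variant and not something from the paper: replace the hard accept/reject with a weight that ramps linearly from 1 at a lower cosine distance tau_lo down to 0 at an upper distance tau_hi, and take a weighted running average. The names soft_weight, soft_accumulate, tau_lo, and tau_hi are my own; setting tau_lo == tau_hi recovers the hard "accept if distance <= tau" rule.

```python
import numpy as np


def cosine_distance(a, b):
    """1 - cosine similarity; a zero vector is treated as orthogonal."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return 1.0 if denom == 0.0 else 1.0 - float(np.dot(a, b) / denom)


def soft_weight(d, tau_lo=0.9, tau_hi=1.0):
    """Hypothetical weight: 1 at or below tau_lo, 0 at or above tau_hi,
    linear in between."""
    if d <= tau_lo:
        return 1.0
    if d >= tau_hi:
        return 0.0
    return (tau_hi - d) / (tau_hi - tau_lo)


def soft_accumulate(micro_grads, tau_lo=0.9, tau_hi=1.0):
    """Weighted running mean: each microgradient contributes in proportion to
    its agreement with the current fused gradient. Returns (fused, total_weight)."""
    fused = np.array(micro_grads[0], dtype=float)
    total_w = 1.0
    for g in micro_grads[1:]:
        w = soft_weight(cosine_distance(fused, g), tau_lo, tau_hi)
        if w > 0.0:
            total_w += w
            fused += w * (g - fused) / total_w  # incremental weighted mean
    return fused, total_w
```

The weight can be read either way the comment suggests: as a partial belief that the sample is informative, or as a sample carrying less than the typical amount of information.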
> But in practice I have found τ in [0.92, 0.999] to be the sweet spot. I wouldn’t go outside that range.
If it would be easy to add (that is, if you still have the data on hand), might I suggest adding a subpanel to figure 7 noting the fraction of minibatches that are accepted/rejected with each threshold? If you're hypothetically rejecting 80% of minibatches at the optimum threshold, it'd hint that your method is finding the golden kernel of most representative data to learn from; in contrast if you're hypothetically rejecting just a couple of percent then it'd hint that your method is more narrowly finding the corrupted samples. Either range (or anything in between) would be interesting.
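If the per-microgradient cosine distances were logged during training, the proposed subpanel is cheap to compute. A minimal sketch, assuming a hypothetical log format that maps each τ to the distances computed during the run trained with that τ. Distances from one run shouldn't be reused for another τ, since the running average they were measured against depends on which microgradients were accepted.

```python
import numpy as np


def acceptance_by_threshold(logged):
    """Map each tau to the fraction of microgradients accepted in its run,
    i.e. the fraction of that run's logged cosine distances that are <= tau."""
    return {
        tau: float(np.mean(np.asarray(dists, dtype=float) <= tau))
        for tau, dists in sorted(logged.items())
    }


# Hypothetical logs: tau -> cosine distances seen during that tau's run.
logs = {0.95: [0.5, 0.96, 0.9, 1.2], 0.999: [0.5, 0.99, 1.0]}
for tau, frac in acceptance_by_threshold(logs).items():
    print(f"tau={tau}: accepted {frac:.0%}, rejected {1 - frac:.0%}")
```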