Hacker News
by dkislyuk 47 days ago
Something that really helped me grasp the foundational relevance of the softmax was justifying from first principles why e^x shows up in the numerator of the preferred mapping function (1). The stated problem of mapping raw inputs/scores/logits to a probability distribution can be solved by many arbitrary functions, and the usual justification given for the softmax is that "it has nice derivatives", which is empirically useful but not satisfying.

The sketch of the justification goes something like this. We first need a function that maps each value in (-inf, inf) to a distinct positive value, and then we need to normalize the resulting values. Setting aside the normalizing step, we want an f(x) with the following properties:

1. It should be strictly positive, so that we can normalize it into a (0, 1) probability.

2. It should preserve the relative ordering of the logits so that they can be interpreted as scores. Thus $f(x)$ should be strictly monotonically increasing.

3. It should be continuous and differentiable everywhere, since we are interested in learning through this function via backpropagation.

4. It should be shift-invariant with respect to the input, as we don't want the model to have to learn some preferred logit-space where there is a stronger learning signal. For example, applying softmax to the values `(-1, 1, 3, 5)` yields the same result as applying it to `(9, 11, 13, 15)`. This property can also be restated as a "scale invariance of probability ratios": the ratio $f(x+c)/f(x)$ depends only on $c$, not on $x$. One useful interpretation of this property is that the learning domain or "gradient-learning surface" is stable, and high-magnitude initializations won't impede the learning process.
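
The shift-invariance example above can be checked with a short plain-Python sketch (the `softmax` helper is written here for illustration, not taken from any library). It also shows a practical consequence of property 4: subtracting the largest logit before exponentiating leaves the output unchanged, which is the usual trick for keeping `math.exp` from overflowing on large logits.

```python
import math

def softmax(logits):
    """Softmax with the max logit subtracted first.

    By shift invariance this does not change the result, and it keeps every
    exponent <= 0, so math.exp cannot overflow.
    """
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

a = softmax([-1, 1, 3, 5])
b = softmax([9, 11, 13, 15])
print(a == b)  # True: both reduce to exponents (-6, -4, -2, 0)

# Large logits still work, since only their differences matter.
print(softmax([1000.0, 1001.0]))  # about [0.269, 0.731]

# For f = exp, the ratio f(x + c) / f(x) depends only on c, not on x.
c = 2.0
print([math.exp(x + c) / math.exp(x) for x in (-5.0, 0.0, 5.0)])  # each about e^2
```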

Taken at face value, these properties define e^x uniquely, up to a positive constant factor (which cancels during normalization) and the choice of base: $f(x) = A e^{kx}$ with $A > 0$ and $k > 0$, where $k$ acts as an inverse temperature. The last property is actually pretty debatable, because in machine learning we do have a "preferred logit-space", namely values close to zero, for numerical stability. But there are other ways to enforce this in a post-hoc manner (e.g. weight initialization, normalization layers, etc.).
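
The step from property 4 to the exponential can be sketched as follows, using only positivity, continuity, and the condition that $f(x+c)/f(x)$ depends on $c$ alone (the names $g$ and $h$ are introduced here for illustration):

$$
\frac{f(x+c)}{f(x)} = g(c) \;\Rightarrow\; g(c) = \frac{f(c)}{f(0)} \quad (\text{set } x = 0)
$$

$$
h(x) := \frac{f(x)}{f(0)} \;\Rightarrow\; h(x+c) = h(x)\,h(c) \;\Rightarrow\; \ln h(x+c) = \ln h(x) + \ln h(c)
$$

$$
\ln h \text{ continuous and additive} \;\Rightarrow\; \ln h(x) = kx \;\Rightarrow\; f(x) = f(0)\,e^{kx}, \quad k > 0 \text{ by monotonicity}
$$

$$
\frac{f(x_i)}{\sum_j f(x_j)} = \frac{e^{k x_i}}{\sum_j e^{k x_j}}
$$

The constant $f(0)$ cancels in the normalization, and $k$ only rescales the logits; $k = 1$ gives the standard softmax.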

Another property that supports e^x, and thus the softmax, is IIA (independence of irrelevant alternatives), which states that the odds for two classes, p_i / p_j, depend only on the logits/inputs for i and j, so an irrelevant class k has no impact. For example, for Softmax([5, 7, 1]) and Softmax([5, 7, 10]), the odds for the first two values (p_1/p_2 = e^(5-7) = e^-2) are the same in both distributions, regardless of the third value. On its own, IIA only forces the form p_i ∝ g(x_i) for some positive g; combined with shift invariance, it pins g down to the exponential.
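
A quick numerical check of the IIA example, in plain Python (the helpers `softmax` and `softplus_normalize` are illustrative, not from any library). Any rule of the form $p_i = g(x_i) / \sum_j g(x_j)$ with positive $g$ satisfies IIA, because the normalizer cancels in $p_i / p_j$. As a hypothetical alternative, the sketch also normalizes softplus, $g(x) = \log(1 + e^x)$: IIA holds there too, but its odds change when the logits are shifted, whereas with $e^x$ the odds depend only on $x_i - x_j$.

```python
import math

def softmax(logits):
    # Subtract the max logit for numerical stability (allowed by shift invariance).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softplus_normalize(logits):
    # Hypothetical alternative mapping: normalize g(x) = log(1 + e^x) instead of e^x.
    g = [math.log1p(math.exp(x)) for x in logits]
    total = sum(g)
    return [v / total for v in g]

# IIA for softmax: changing the third logit leaves the odds of the first two unchanged.
p = softmax([5, 7, 1])
q = softmax([5, 7, 10])
odds_p = p[0] / p[1]
odds_q = q[0] / q[1]
print(odds_p, odds_q, math.exp(5 - 7))  # all three are approximately e^-2 (about 0.135)

# Softplus normalization also satisfies IIA...
r = softplus_normalize([5, 7, 1])
s = softplus_normalize([5, 7, 10])
print(r[0] / r[1], s[0] / s[1])  # equal up to rounding

# ...but its odds are not shift-invariant: moving both logits down by 10 changes them.
t = softplus_normalize([-5, -3, 1])
print(r[0] / r[1], t[0] / t[1])  # roughly 0.72 vs 0.14
```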

Finally, if the "desired properties" approach is not satisfying, a more theoretical route for justifying the form of the softmax uses the framework of maximum entropy (E. T. Jaynes published this in 1957 to justify the Boltzmann distribution).
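
A sketch of that route, assuming the only constraint besides normalization is a fixed expected score $\mu = \sum_i p_i x_i$ (the analogue of a fixed mean energy in Jaynes' setting), with Lagrange multipliers $\lambda$ and $\nu$:

$$
\max_{p}\; -\sum_i p_i \log p_i \quad \text{s.t.} \quad \sum_i p_i = 1, \qquad \sum_i p_i x_i = \mu
$$

$$
\frac{\partial}{\partial p_i}\Big[-\sum_k p_k \log p_k + \lambda\Big(\sum_k p_k x_k - \mu\Big) + \nu\Big(\sum_k p_k - 1\Big)\Big] = -\log p_i - 1 + \lambda x_i + \nu = 0
$$

$$
\Rightarrow\; p_i = e^{\nu - 1}\, e^{\lambda x_i} = \frac{e^{\lambda x_i}}{\sum_j e^{\lambda x_j}}
$$

Entropy is concave and the constraints are linear, so this stationary point is the maximum. The multiplier $\lambda$ is fixed by $\mu$ and plays the role of an inverse temperature; with scores $x_i = -E_i$ and $\lambda = \beta$ this is the Boltzmann distribution, and $\lambda = 1$ gives the standard softmax.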

TL;DR: softmax is not the only function that maps unnormalized values to a probability distribution, but it can be justified through axiomatic properties.

(1) One could say that the exponential comes from the Boltzmann distribution, but then the same question applies to the Boltzmann distribution itself.


The reason for exp(x) is that its derivative is exp(x), which makes it possible to express the gradient of s(x) in terms of s(x), or both in terms of exp(x). This simplifies the computation of the backward pass.
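
A minimal pure-Python sketch of that point (function names are illustrative): because $\frac{d}{dx} e^x = e^x$, the softmax Jacobian can be written using only the outputs $s$, as $\partial s_i / \partial x_j = s_i(\delta_{ij} - s_j)$, so the backward pass reduces to an inner product and needs no extra exponentials. A central finite-difference comparison is included only as a sanity check.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(s):
    """ds_i/dx_j = s_i * (delta_ij - s_j), built only from the outputs s."""
    n = len(s)
    return [[s[i] * ((1.0 if i == j else 0.0) - s[j]) for j in range(n)] for i in range(n)]

def softmax_backward(s, grad_out):
    """Vector-Jacobian product: dL/dx_j = s_j * (g_j - sum_i g_i * s_i)."""
    dot = sum(g * p for g, p in zip(grad_out, s))
    return [p * (g - dot) for g, p in zip(grad_out, s)]

def numeric_jacobian(x, h=1e-6):
    """Central finite differences of softmax, for comparison only."""
    n = len(x)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp = list(x)
        xp[j] += h
        xm = list(x)
        xm[j] -= h
        sp, sm = softmax(xp), softmax(xm)
        for i in range(n):
            J[i][j] = (sp[i] - sm[i]) / (2 * h)
    return J

x = [0.5, -1.0, 2.0]
s = softmax(x)
analytic = softmax_jacobian(s)
numeric = numeric_jacobian(x)
max_err = max(abs(a - b) for ra, rb in zip(analytic, numeric) for a, b in zip(ra, rb))
print(max_err)  # tiny, well below 1e-6
```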
I agree that "it has nice derivatives" is a great empirical reason to use a specific function in ML, but it doesn't prove that it's the best function to use. And even if a derivative looks more complex, that doesn't necessarily mean it is more expensive to compute, so that can't be the only criterion for selecting a function.

Luckily, there are more axiomatic reasons for why softmax is the preferred way to map inputs to a probability distribution.

> The stated problem of mapping raw inputs/scores/logits to a probability distribution can be solved by a bunch of arbitrary functions, and the usual justification given for a softmax is "it has nice derivatives" which is empirically useful but not satisfying.

Often there isn't any more to it than that. For example, the entire justification for least-squares error measurement is that it has convenient derivatives.

The central limit theorem is an extremely powerful justification. That doesn't mean it's considered every time least squares is used, but least squares absolutely can be strongly justified (to the degree that other error measures are only needed in relatively small samples of the feature space, where errors have not yet converged to a Gaussian).
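
The step connecting the central limit theorem to least squares is maximum likelihood. Assuming the errors of a model $f(x; \theta)$ are i.i.d. Gaussian with a fixed variance $\sigma^2$ (the assumption the CLT motivates when each error is a sum of many small independent effects), the negative log-likelihood of $n$ observations is:

$$
-\log \prod_{i=1}^{n} \frac{1}{\sqrt{2\pi}\,\sigma} \exp\!\left(-\frac{(y_i - f(x_i; \theta))^2}{2\sigma^2}\right) = \frac{1}{2\sigma^2} \sum_{i=1}^{n} \big(y_i - f(x_i; \theta)\big)^2 + \frac{n}{2} \log(2\pi\sigma^2)
$$

The second term does not depend on $\theta$, so minimizing the negative log-likelihood over $\theta$ is exactly minimizing the sum of squared errors.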