by guest1539 3816 days ago
What part doesn't scale?

I'm a little late to the party, but hopefully this'll explain.

Basically, you have to be careful about what it means to scale or not scale. If all you want is a derivative with respect to a single variable, forward mode scales just fine, great in fact. However, if you want the gradient, or the derivative with respect to every variable, then the forward mode does not scale well at all with respect to the number of variables. Specifically, assume we have m variables. Calculating the derivative of an expression with respect to 1 variable costs 2 times a function evaluation, 2 * eval. In order to see this, it's easiest to note that we don't need an expression tree for forward mode AD like the article uses. Really, we can get away with just a tuple that contains the function evaluation as the first element and the partial derivative as the second element. Then, all of the rules are basically the same as the article, but we're always doing one operation on the first element, whatever the function is, and a different operation on the second element for the partial derivative. This is twice the work, so 2 * eval. Since we have m variables and need one pass per variable, the full gradient costs 2 * m * eval. And, yes, memory layouts, fewer optimizations for algebraic data types compared to floats, etc. mean that it's actually slower, but, honestly, it's pretty fast.
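To make the tuple idea concrete, here is a minimal sketch in Python. The function f(x, y) = x*y + sin(x) and all helper names (add, mul, sin, forward_gradient) are made up for illustration and are not from the article. Each operation does one thing to the value and another to the partial derivative, and the gradient takes one pass per variable.

```python
import math

# A "dual number" here is just a tuple: (value, partial derivative).

def add(a, b):
    return (a[0] + b[0], a[1] + b[1])

def mul(a, b):
    # Product rule on the second element.
    return (a[0] * b[0], a[1] * b[0] + a[0] * b[1])

def sin(a):
    # Chain rule on the second element.
    return (math.sin(a[0]), math.cos(a[0]) * a[1])

def f(x, y):
    # Hypothetical example function: f(x, y) = x*y + sin(x)
    return add(mul(x, y), sin(x))

def forward_gradient(func, point):
    """Run one forward pass per variable, so m passes for m variables."""
    grad = []
    for i in range(len(point)):
        # Seed the i-th variable with derivative 1, all others with 0.
        args = [(v, 1.0 if j == i else 0.0) for j, v in enumerate(point)]
        grad.append(func(*args)[1])
    return grad

grad = forward_gradient(f, [2.0, 3.0])
print(grad)  # [y + cos(x), x] evaluated at (2, 3)
```

The loop inside forward_gradient is where the factor of m comes from: every extra variable means another full pass through f.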

The reverse mode is different because it turns out that it can calculate the entire gradient, or all m partial derivatives, with 4 * eval cost. Note, this is independent of the number of variables. Proving this is a pain, so I can't give a good explanation here. Realistically, source code transformation tools perform around 10-20 * eval. Operator overloading tools perform around 20-30 * eval, so it's slower in practice, but pretty damn good.

Now, unlike the forward mode, where we really only need a tuple to carry information, the reverse mode does require an expression tree. In order to understand why, it helps to note that the forward mode is really a directional (Gâteaux) derivative and the reverse mode is the total (Fréchet) derivative. This affects how the chain rule manifests. Specifically, the forward mode repeatedly applies two rules

(f o g)'(x) dx = f'(g(x)) g'(x) dx

(f o (g,h))'(x) dx = f'(g(x),h(x)) (g'(x)dx,h'(x)dx)

Basically, in the function evaluation, we do some operation g before f. In order to figure out the derivative, we also do the g derivative operation before the f derivative operation. The first rule is for unary operations like negation and the second rule is for binary operations like addition. Anyway, the reverse mode takes the Hilbert adjoint of this. Specifically:

(f o g)'(x)^* = g'(x)^* f'(g(x))^*

(f o (g,h))'(x)^* = [g'(x)^* h'(x)^* ]f'(g(x),h(x))^*

We care about the adjoint because of a trick from the Riesz representation theorem. Specifically,

f'(x)dx =

(f'(x)dx)1 =

<f'(x)dx,1> =

<dx,f'(x)^* 1> =

<dx,grad f(x)>

where <.,.> denotes the inner product. Anyway, basically the gradient of f is the adjoint of the total derivative of f applied to 1. Therefore, if we knew the adjoint of a computation applied to 1, we'd get the gradient. In other words, we can rewrite the chain rule above as

grad (f o g)(x) = g'(x)^* grad f(g(x))

grad (f o (g,h))(x) = [g'(x)^* h'(x)^* ]grad f(g(x),h(x))
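As a quick sanity check of the first rule, here is a small sketch in finite dimensions, where the adjoint of a derivative is just the transpose of its Jacobian. The functions g(x) = (x0*x1, x0 + x1) and f(u) = u0^2 + 3*u1 are made up for illustration, with their derivatives written out by hand.

```python
def g(x):
    # Hypothetical inner function g: R^2 -> R^2
    return [x[0] * x[1], x[0] + x[1]]

def g_jacobian(x):
    # g'(x) written out by hand; row r holds the partials of output r.
    return [[x[1], x[0]],
            [1.0, 1.0]]

def grad_f(u):
    # Gradient of the hypothetical outer function f(u) = u0^2 + 3*u1
    return [2.0 * u[0], 3.0]

def apply_adjoint(jac, v):
    # In R^n the adjoint is the transpose, so this computes jac^T v.
    rows, cols = len(jac), len(jac[0])
    return [sum(jac[r][c] * v[r] for r in range(rows)) for c in range(cols)]

x = [2.0, 3.0]

# grad (f o g)(x) = g'(x)^* grad f(g(x))
grad_composed = apply_adjoint(g_jacobian(x), grad_f(g(x)))

# The same gradient, differentiating (x0*x1)^2 + 3*(x0 + x1) by hand.
grad_direct = [2.0 * x[0] * x[1] ** 2 + 3.0,
               2.0 * x[0] ** 2 * x[1] + 3.0]

print(grad_composed, grad_direct)
```

Note the order of operations: grad f is evaluated at g(x) first, and only then does the adjoint of g'(x) get applied.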

That's the core of reverse mode AD. Note, many, if not most, descriptions of reverse mode AD talk about doing the chain rule in reverse and then they add dual variables, etc. That may be a description that's helpful for some, but not for me. In truth, it's just a bunch of adjoints applied to one and knowing the Riesz representation trick.

Now, the reverse mode AD does require an expression tree to be kept. The reason for this is that the computation above did g before f. However, if we look at the chain rule we have

grad (f o g)(x) = g'(x)^* grad f(g(x))

This means that in order to calculate the gradient of the composition, we need to know the gradient of f first even though we did the evaluation of g first. However, we need to know the evaluation of g in order to calculate the gradient of f. The way we resolve this is that we evaluate the functions in order, but keep an expression tree of what we did. This gives all of the g(x), f(g(x)), etc. Then, we run over that expression tree backward to calculate all of the gradients. Because we run over the expression tree backwards, we call this the reverse mode.
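Here is a rough sketch of that record-then-run-backward idea, assuming a flat list of recorded operations stands in for the expression tree. The Tape class, its method names, and the example function x*y + sin(x) are all hypothetical, not taken from the article or any particular AD tool.

```python
import math

class Tape:
    """Records each operation as (output index, [(input index, local partial)])."""

    def __init__(self):
        self.values = []  # every intermediate value, in evaluation order
        self.ops = []     # what we did, so we can walk it backward later

    def variable(self, value):
        self.values.append(value)
        return len(self.values) - 1

    def _record(self, value, partials):
        idx = self.variable(value)
        self.ops.append((idx, partials))
        return idx

    def add(self, a, b):
        return self._record(self.values[a] + self.values[b],
                            [(a, 1.0), (b, 1.0)])

    def mul(self, a, b):
        return self._record(self.values[a] * self.values[b],
                            [(a, self.values[b]), (b, self.values[a])])

    def sin(self, a):
        return self._record(math.sin(self.values[a]),
                            [(a, math.cos(self.values[a]))])

    def gradient(self, output):
        # One shared vector of adjoints, seeded with 1 at the output.
        adjoint = [0.0] * len(self.values)
        adjoint[output] = 1.0
        # Walk the recorded operations in reverse order.
        for idx, partials in reversed(self.ops):
            for src, partial in partials:
                adjoint[src] += partial * adjoint[idx]
        return adjoint

tape = Tape()
x = tape.variable(2.0)
y = tape.variable(3.0)
out = tape.add(tape.mul(x, y), tape.sin(x))  # forward pass, recorded

adjoint = tape.gradient(out)                 # single backward pass
grad = [adjoint[x], adjoint[y]]
print(grad)  # [y + cos(x), x] evaluated at (2, 3)
```

One forward pass plus one backward pass gives both partials at once, no matter how many variables there are. The price is that the tape keeps every intermediate value and operation alive until the backward pass is done.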

How we run over the expression tree backwards is important and tricky to do right. One way to sort of see why everything fits in 4 * eval cost: the trick is not to create multiple vectors to store the gradient when running over the tree, but to have 1 vector and to update it with new derivative information when required. Basically, we're just inserting information in the right spots, which can be done efficiently. In practice, storing the expression tree in memory can be really expensive. For example, imagine a for-loop that runs for 10 billion iterations. That's a really long expression tree to hold in memory. Now, source code transformation tools are really clever and don't actually store all of those expressions in memory, but just run the for-loop backward, which is why they're more efficient. Operator overloading techniques (algebraic data types) can technically optimize this as well by doing some interesting caching techniques. However, the overall idea is that it can be expensive and there are lots of ways to do this wrong, but also lots of places to do things right and be creative.
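A toy illustration of the memory difference, under a big simplifying assumption: the loop body is linear (x = a*x + b), so its local derivative is just a and the backward loop needs no stored intermediate values. Real source transformation tools have to store or recompute values for nonlinear bodies. Both function names are hypothetical.

```python
def taped_loop(x0, a, b, n):
    """Operator-overloading style: record one local partial per iteration."""
    tape = []
    x = x0
    for _ in range(n):
        x = a * x + b
        tape.append(a)  # d(a*x + b)/dx, stored for the backward sweep
    adjoint = 1.0
    for partial in reversed(tape):
        adjoint *= partial
    return x, adjoint, len(tape)

def reversed_loop(x0, a, b, n):
    """Source-transformation style: a second loop that runs backward."""
    x = x0
    for _ in range(n):
        x = a * x + b
    adjoint = 1.0
    for _ in range(n):  # same trip count, nothing recorded
        adjoint *= a
    return x, adjoint

value_t, deriv_t, stored = taped_loop(1.0, 0.5, 2.0, 10)
value_r, deriv_r = reversed_loop(1.0, 0.5, 2.0, 10)
print(deriv_t, deriv_r, stored)
```

Both versions compute the same derivative of the final x with respect to x0, but the taped one holds n entries in memory while the reversed loop holds none. With 10 billion iterations that difference is the whole game.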

As an aside to a comment left above, back propagation is indeed just reverse mode AD combined with a nonglobally convergent version of steepest descent. I've never seen a paper that worked this out, but it's something that's known within the AD community. Someone, someday, should really write that down.

Anyway, that's probably a much too long response to your simple question. In short, forward mode doesn't scale when calculating gradients because the cost is 2 * m * eval whereas the reverse mode can do it in 4 * eval. For a single variable, or an entire directional derivative, the forward mode scales fine and in fact works better than the reverse mode for this case.

Edit: This formatting is killing me. Hopefully, it all looks fine now.