Yoga is purely functional, so it's possible to use backpropagation to compute gradients efficiently. It actually compiles two versions of each function; call them foo and foo.grad. foo.grad takes the same arguments as foo, plus a gradient for each output argument, and computes a gradient for each input argument. The algorithm is simple: traverse the expression tree in the usual order you'd use for emitting code, and remember that order. Then traverse it in reverse, propagating gradients as you go. The tedious bit is writing the gradient rule for every built-in op. For an operation like + it's simple: each argument gets the same gradient as the result:

    void ExprAdd::backprop(YogaContext const &ctx, GradientExprSet &grads, ExprNode *g)
    {
      // d(a+b)/da = d(a+b)/db = 1, so both arguments receive g unchanged
      grads.addGradient(ctx, args[0], g);
      grads.addGradient(ctx, args[1], g);
    }
For something like divide it's a bit grosser:

    void ExprDiv::backprop(YogaContext const &ctx, GradientExprSet &grads, ExprNode *g)
    {
      // https://en.wikipedia.org/wiki/Quotient_rule
      // For a/b: d/da = 1/b, d/db = -a/b^2
      grads.addGradient(ctx, args[0], ctx.mkExpr<ExprDiv>(g, args[1]));
      grads.addGradient(ctx, args[1], ctx.mkExpr<ExprNeg>(
        ctx.mkExpr<ExprMul>(g,
          ctx.mkExpr<ExprDiv>(
            args[0],
            ctx.mkExpr<ExprPow>(
              args[1],
              ctx.mkExpr<ExprConstDouble>(2.0))))));
    }
More at https://gitlab.com/umbrellaresearch/yoga2/-/blob/master/jit/... if you're curious.