Yoga is purely functional, so it's possible to use backpropagation to compute gradients efficiently.
Have you considered using an off-the-shelf differentiable programming implementation? Or are the requirements for real-time applications too demanding for existing software?
It actually compiles 2 versions of each function: call them foo and foo.grad. foo.grad takes the same arguments as foo, and also a gradient for each output argument. It then computes gradients for each input argument.
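To make the foo / foo.grad contract concrete, here is a hand-written pair for a made-up function foo(a, b) = a * b + a. This is only an illustration, not code Yoga generates. Python doesn't allow a dot in a function name, so the gradient version is called `foo_grad` here.

```python
# Hypothetical example function (not from Yoga): foo(a, b) = a * b + a
def foo(a: float, b: float) -> float:
    return a * b + a


def foo_grad(a: float, b: float, d_out: float) -> tuple[float, float]:
    """Takes the same arguments as foo, plus the gradient of foo's output.

    Returns the gradient for each input argument.
    """
    # d(a*b + a)/da = b + 1 and d(a*b + a)/db = a,
    # each scaled by the incoming output gradient (chain rule).
    d_a = d_out * (b + 1.0)
    d_b = d_out * a
    return d_a, d_b
```

With a = 3, b = 4 and an output gradient of 1, `foo_grad` returns (5.0, 3.0).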
The algorithm is simple: traverse the expression tree in the same order you'd use for emitting code (arguments before the operations that use them), and record that order. Then walk the recorded order in reverse, propagating gradients from each result back to its arguments.
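Here's a rough sketch of that forward-record / reverse-propagate idea as a tiny runtime "tape." It assumes a much simpler setup than Yoga's: scalar values and only `+` and `*`. The `Node` and `Tape` names are hypothetical, and where Yoga compiles foo.grad ahead of time, this interprets the recorded order at runtime.

```python
from dataclasses import dataclass


@dataclass
class Node:
    op: str              # "input", "add" or "mul" (hypothetical op set)
    args: tuple = ()     # indices of the argument nodes on the tape
    value: float = 0.0   # result computed during the forward pass


class Tape:
    def __init__(self) -> None:
        # Nodes are appended in evaluation (code-emission) order,
        # so every argument appears before the ops that use it.
        self.nodes: list[Node] = []

    def _push(self, op: str, args: tuple, value: float) -> int:
        self.nodes.append(Node(op, args, value))
        return len(self.nodes) - 1

    def input(self, x: float) -> int:
        return self._push("input", (), x)

    def add(self, a: int, b: int) -> int:
        return self._push("add", (a, b), self.nodes[a].value + self.nodes[b].value)

    def mul(self, a: int, b: int) -> int:
        return self._push("mul", (a, b), self.nodes[a].value * self.nodes[b].value)

    def backward(self, out: int, d_out: float = 1.0) -> list[float]:
        grads = [0.0] * len(self.nodes)
        grads[out] = d_out
        # Reverse order guarantees a node's gradient is complete
        # (all consumers have contributed) before it is propagated.
        for i in reversed(range(len(self.nodes))):
            node, g = self.nodes[i], grads[i]
            if node.op == "add":
                a, b = node.args
                grads[a] += g  # += so that reused values accumulate
                grads[b] += g
            elif node.op == "mul":
                a, b = node.args
                grads[a] += g * self.nodes[b].value
                grads[b] += g * self.nodes[a].value
        return grads
```

For example, recording `y = a * b + a` with a = 3, b = 4 and calling `backward(y)` gives a gradient of 5.0 for `a` and 3.0 for `b`. Accumulating with `+=` is what makes a value that's used twice (like `a` here) come out right.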
The tedious bit is writing the gradients for every built-in op. For an operation like + it's simple: each argument gets the same gradient as the result.
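As a sketch of what a per-op rule looks like, assuming scalar arguments (the `add_grad` name and signature are made up for illustration and aren't Yoga's):

```python
def add_grad(a: float, b: float, d_r: float) -> tuple[float, float]:
    """Gradient rule for r = a + b (hypothetical helper)."""
    # dr/da = dr/db = 1, so each argument receives the result's
    # gradient unchanged. The argument values aren't even needed.
    return d_r, d_r
```

So `add_grad(1.0, 2.0, 0.5)` just hands 0.5 back to both arguments.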
For something like divide it's a bit grosser. More at https://gitlab.com/umbrellaresearch/yoga2/-/blob/master/jit/... if you're curious.
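For a sense of why divide is grosser, here is the scalar rule written out under the same assumptions as a plain `+` rule (the `div_grad` name is hypothetical). For r = a / b, the derivatives are dr/da = 1/b and dr/db = -a/b², so the rule needs both argument values, not just the incoming gradient.

```python
def div_grad(a: float, b: float, d_r: float) -> tuple[float, float]:
    """Gradient rule for r = a / b (hypothetical helper)."""
    # dr/da = 1 / b
    d_a = d_r / b
    # dr/db = -a / b**2 (equivalently -r / b, if the forward result is kept around)
    d_b = -d_r * a / (b * b)
    return d_a, d_b
```

For a = 6, b = 3 and an incoming gradient of 1, this gives roughly 0.333 for `a` and -0.667 for `b`. Like the forward op, it's undefined at b = 0.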