by ffriend 3132 days ago
Some time ago I implemented a library [1] similar to this tool. The tricky part is that derivatives quickly exceed 2 dimensions, e.g. the derivative of a vector output w.r.t. an input matrix is a 3D tensor (if `y = f(X)`, you need the derivative of each `y[i]` w.r.t. each `X[m,n]`), and we don't have a convenient matrix notation for it. Also, such tensors are often very sparse (e.g. for element-wise `y = log(x)` the derivative is a matrix where only the main diagonal is non-zero, holding the values `dy[i]/dx[i]`).
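
To make both points concrete, here is a small pure-Python sketch (not taken from the library). It assumes a hypothetical `f(X) = X * v` for a fixed vector `v` as the vector-valued function; the names `log_jacobian` and `matvec_jacobian` are made up for illustration.

```python
def log_jacobian(x):
    """Jacobian of element-wise y = log(x): J[i][j] = dy[i]/dx[j]."""
    n = len(x)
    # y[i] depends on x[i] alone, so only the main diagonal is non-zero.
    return [[1.0 / x[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def matvec_jacobian(X, v):
    """Derivative of y = X * v w.r.t. X: T[i][m][n] = dy[i]/dX[m][n] (3D)."""
    rows, cols = len(X), len(X[0])
    # y[i] = sum_n X[i][n] * v[n], so dy[i]/dX[m][n] = v[n] if i == m else 0.
    return [[[v[n] if i == m else 0.0 for n in range(cols)]
             for m in range(rows)]
            for i in range(rows)]


J = log_jacobian([1.0, 2.0, 4.0])
T = matvec_jacobian([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [10.0, 20.0])
```

Here `J` is 3x3 with non-zeros only on the diagonal, and `T` has shape 3x3x2, with most of its entries equal to zero.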

The way I dealt with it was to first translate the vectorized expression to so-called Einstein notation [2] - an indexed expression with implicit sums over repeated indices. E.g. the matrix product `Z = X * Y` may be written as:

    Z[i,j] = X[i,k] * Y[k,j]   # implicitly sum over k
 
It worked pretty well and I was able to get results in Einstein notation for element-wise functions, matrix multiplication and even convolutions.
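
As a rough sketch of what the indexed form means, the matrix product above can be spelled out with explicit loops, where the repeated index `k` becomes the summation loop. The derivative rule in `dZ_dX` is the standard one for a matrix product, `dZ[i,j]/dX[m,n] = delta(i,m) * Y[n,j]`; it is written here for illustration and is not claimed to be the exact rule the library uses. Function names are hypothetical.

```python
def matmul_einstein(X, Y):
    """Z[i,j] = X[i,k] * Y[k,j], with the implicit sum over k made explicit."""
    I, K, J = len(X), len(Y), len(Y[0])
    Z = [[0.0] * J for _ in range(I)]
    for i in range(I):
        for j in range(J):
            for k in range(K):  # repeated index: summed over
                Z[i][j] += X[i][k] * Y[k][j]
    return Z


def dZ_dX(X, Y):
    """D[i][j][m][n] = dZ[i,j]/dX[m,n] = delta(i,m) * Y[n,j], a 4D tensor."""
    I, K, J = len(X), len(Y), len(Y[0])
    return [[[[Y[n][j] if i == m else 0.0 for n in range(K)]
              for m in range(I)]
             for j in range(J)]
            for i in range(I)]


X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[5.0, 6.0], [7.0, 8.0]]
Z = matmul_einstein(X, Y)
D = dZ_dX(X, Y)
```

Even for 2x2 inputs the derivative already has four indices, and half of its entries are structural zeros.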

Unfortunately, the only way to calculate such expressions efficiently is to convert them back to vectorized notation, which is not always possible (e.g. because of sparse structure) and is very error-prone.

The good news is that if the result of the whole expression is a scalar, all the derivatives will have the same number of dimensions as corresponding inputs. E.g. in:

    y = sum(W * X + b)

if `W` is a matrix, then `dy/dW` is also a matrix (without the sum it would be a 3D tensor). This is the reason why the backpropagation algorithm (and symbolic/automatic differentiation in general) works in machine learning: the loss is a scalar, so every gradient has the same shape as the input it belongs to. So I finally ended up with another library [3], which can only deal with scalar outputs, but is much more stable.
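
A minimal sketch of the scalar-output case, assuming `X` is a vector (written `x` below) so that `W * x + b` is a vector and the sum makes it a scalar. The analytic gradient `dy/dW[i][k] = x[k]` follows directly from the definition; the finite-difference helper is only there to check it. None of this is the API of the libraries above.

```python
def forward(W, x, b):
    """y = sum(W * x + b) for a matrix W and vectors x, b; returns a scalar."""
    return sum(sum(W[i][k] * x[k] for k in range(len(x))) + b[i]
               for i in range(len(W)))


def grad_W(W, x):
    """Analytic gradient dy/dW[i][k] = x[k]; same shape as W."""
    return [[x[k] for k in range(len(x))] for _ in range(len(W))]


def numeric_grad_W(W, x, b, eps=1e-6):
    """Central finite differences over each W[i][k], used as a check."""
    G = []
    for i in range(len(W)):
        row = []
        for k in range(len(W[0])):
            Wp = [r[:] for r in W]
            Wm = [r[:] for r in W]
            Wp[i][k] += eps
            Wm[i][k] -= eps
            row.append((forward(Wp, x, b) - forward(Wm, x, b)) / (2 * eps))
        G.append(row)
    return G


W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
x = [0.5, -1.0, 2.0]
b = [0.1, 0.2]
G = grad_W(W, x)
G_num = numeric_grad_W(W, x, b)
```

`G` is a 2x3 matrix, the same shape as `W`, and agrees with `G_num` up to rounding error.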

A theoretical description of the method used in the first library can be found in [4] (pages 1338-1343; caution - 76M), while the set of rules I derived is in [5].

[1]: https://github.com/dfdx/XDiff.jl

[2]: https://en.wikipedia.org/wiki/Einstein_notation

[3]: https://github.com/dfdx/XGrad.jl

[4]: http://docs.mipro-proceedings.com/proceedings/mipro_2017_pro...

[5]: https://github.com/dfdx/XDiff.jl/blob/master/src/trules.jl