Hacker News
by PartiallyTyped 1529 days ago
The nice thing about differentiable programming is that we can use optimizers other than gradient descent, some of which (e.g. Newton's method) offer quadratic convergence instead of linear!
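To make "quadratic instead of linear" concrete, here is a minimal JAX sketch on a hypothetical one-dimensional function f(x) = e^x - x (minimizer x* = 0), chosen only because its derivatives are simple. Gradient descent with a fixed step size (0.5, an arbitrary choice) shrinks the error by a roughly constant factor per step, while Newton's method, which divides by the second derivative f''(x), roughly squares the error once it is close to the minimizer.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical 1-D test function: f'(x) = e^x - 1, f''(x) = e^x > 0, minimizer x* = 0
    return jnp.exp(x) - x

df = jax.grad(f)    # first derivative
d2f = jax.grad(df)  # second derivative (the 1-D "Hessian")

def gd_errors(x, lr=0.5, steps=5):
    errs = []
    for _ in range(steps):
        x = x - lr * df(x)          # fixed-step gradient descent
        errs.append(abs(float(x)))  # |x - x*| with x* = 0
    return errs

def newton_errors(x, steps=5):
    errs = []
    for _ in range(steps):
        x = x - df(x) / d2f(x)      # Newton step rescales by the curvature f''(x)
        errs.append(abs(float(x)))
    return errs

gd = gd_errors(1.0)
nt = newton_errors(1.0)
```

Printing gd and nt should show the gradient-descent error roughly halving each step, while the Newton error roughly squares each step (until float32 precision runs out).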

Yes, exactly! This is huge. Hessian-based optimization is really easy with JAX; I haven't tried it in Julia, though.
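A minimal sketch of what "easy with JAX" can look like: Newton's method on a small L2-regularized logistic regression, using jax.grad, jax.hessian and jnp.linalg.solve inside a jax.jit-compiled step. The loss, the synthetic data and the regularization strength lam = 0.1 are hypothetical choices for illustration; the dense Hessian here is only 3x3.

```python
import jax
import jax.numpy as jnp

def loss(w, X, y, lam=0.1):
    # Hypothetical L2-regularized logistic regression, labels y in {0, 1}.
    # softplus(z) - y*z is binary cross-entropy in a numerically stable form.
    z = X @ w
    return jnp.mean(jax.nn.softplus(z) - y * z) + 0.5 * lam * jnp.sum(w ** 2)

@jax.jit
def newton_step(w, X, y):
    g = jax.grad(loss)(w, X, y)        # gradient, shape (d,)
    H = jax.hessian(loss)(w, X, y)     # dense Hessian, shape (d, d)
    return w - jnp.linalg.solve(H, g)  # solve H d = g rather than forming H^{-1}

# Hypothetical synthetic data: 200 samples, 3 features, noisy linear labels
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(k1, (200, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = (X @ w_true + jax.random.normal(k2, (200,)) > 0).astype(jnp.float32)

w = jnp.zeros(3)
for _ in range(10):
    w = newton_step(w, X, y)
```

Solving the linear system avoids an explicit inverse, but it still materializes the full d x d Hessian, which is only practical for small d.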
Here's Hessian-Free Newton-Krylov on neural ODEs with Julia: https://diffeqflux.sciml.ai/dev/examples/second_order_adjoin... . It's just standard tutorial stuff at this point.
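The linked example is in Julia and targets neural ODEs. The following is not that code, but a minimal JAX sketch of the same general idea (Hessian-free Newton-Krylov, specifically Newton-CG) on a hypothetical 3-dimensional convex objective. The Hessian is never built: Hessian-vector products come from forward-over-reverse differentiation (jax.jvp of jax.grad), and jax.scipy.sparse.linalg.cg solves H d = -g using only those products. The matrix A, vector b, quartic term and iteration counts are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical strictly convex test objective: SPD quadratic plus a quartic term
A = jnp.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 5.0]])
b = jnp.array([1.0, -2.0, 3.0])

def f(x):
    return 0.5 * x @ A @ x - b @ x + 0.25 * jnp.sum(x ** 4)

def hvp(x, v):
    # Forward-over-reverse: directional derivative of the gradient along v, i.e. H(x) v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

@jax.jit
def newton_cg_step(x):
    g = jax.grad(f)(x)
    # CG touches the Hessian only through matrix-vector products
    d, _ = cg(lambda v: hvp(x, v), -g, maxiter=20)
    return x + d

x = jnp.zeros(3)
for _ in range(10):
    x = newton_cg_step(x)
```

CG assumes the Hessian is positive definite, which holds for this objective (A is SPD and the quartic term only adds curvature) but not in general for nonconvex losses.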
And it's very fast once you JIT-compile the procedure! I'm considering writing an article on this and posting it here, because I've seen enormous improvements over non-jitted code, and that was without even using jax.vmap.
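A small sketch of combining jax.jit with jax.vmap for second-order work: jax.hessian gives the Hessian at one point, jax.vmap maps it over a batch of points, and jax.jit compiles the batched computation once. The cubic objective is a hypothetical toy whose Hessian is diag(6x), so the result is easy to check; this does not reproduce or measure the speedups described above.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy objective; its Hessian is diag(6 * x)
    return jnp.sum(x ** 3)

# hessian: one point -> (3, 3); vmap: batch of points -> (batch, 3, 3); jit: compile once
batched_hessian = jax.jit(jax.vmap(jax.hessian(f)))

xs = jnp.arange(12.0).reshape(4, 3)  # 4 points in R^3
Hs = batched_hessian(xs)             # shape (4, 3, 3)
```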
There's a comparison of JAX with PyTorch for Hessian calculation here!

https://www.assemblyai.com/blog/why-you-should-or-shouldnt-b...

Would definitely be interested in an article like that if you decide to write it

Why can't we use this quadratic convergence in deep learning?
Well, quadratic convergence usually requires the Hessian (or an approximation of it), which is difficult to get in deep learning because of memory constraints and the difficulty of computing second-order derivatives.

Computing the derivatives is not very difficult with e.g. JAX, but then you're back to the memory issue. The Hessian is a square matrix with one row and one column per parameter, so in deep learning, if we have a million parameters, the Hessian is a million-by-million matrix with a trillion entries...
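The back-of-the-envelope arithmetic, assuming float32 storage (4 bytes per entry; this is an assumption, and float64 doubles everything):

```python
n_params = 1_000_000
bytes_per_float32 = 4

hessian_entries = n_params ** 2                      # 10**12 entries
hessian_bytes = hessian_entries * bytes_per_float32  # 4 * 10**12 bytes, about 4 TB
vector_bytes = n_params * bytes_per_float32          # one gradient or Hessian-vector product: about 4 MB
```

So a dense Hessian for a million parameters needs about 4 TB, versus about 4 MB for a single parameter-sized vector.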

Not only does it have 1 trillion elements, but you also have to invert it!
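For reference, the Newton update is x_{k+1} = x_k - H_k^{-1} ∇f(x_k). In practice the inverse is not formed explicitly; one solves a linear system instead, but a dense factorization (e.g. Cholesky or LU) still costs on the order of n^3 floating-point operations, roughly 10^18 for n = 10^6:

```latex
H_k = \nabla^2 f(x_k), \qquad H_k \, d_k = -\nabla f(x_k), \qquad x_{k+1} = x_k + d_k
\text{cost of a dense solve} = \mathcal{O}(n^3), \qquad n = 10^6 \;\Rightarrow\; n^3 = 10^{18}
```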
Indeed! BFGS (and its variants) approximate the inverse, but they have other issues that make them prohibitively expensive.
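A minimal sketch of how BFGS builds up an approximation to the inverse Hessian from gradient differences alone, written in JAX for a hypothetical 3-dimensional convex objective (A, b and the quartic term are illustrative assumptions, as are the Armijo constant 1e-4 and the iteration limits). The update is the standard BFGS inverse-Hessian formula combined with a backtracking line search.

```python
import jax
import jax.numpy as jnp

# Hypothetical smooth, strictly convex test objective
A = jnp.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 5.0]])
b = jnp.array([1.0, -2.0, 3.0])

def f(x):
    return 0.5 * x @ A @ x - b @ x + 0.25 * jnp.sum(x ** 4)

def bfgs_inverse_update(Hinv, s, y):
    # Hinv_new = (I - rho s y^T) Hinv (I - rho y s^T) + rho s s^T, with rho = 1 / (y^T s)
    rho = 1.0 / (y @ s)
    V = jnp.eye(s.shape[0]) - rho * jnp.outer(s, y)
    return V @ Hinv @ V.T + rho * jnp.outer(s, s)

def bfgs(f, x, iters=50):
    grad_f = jax.grad(f)
    g = grad_f(x)
    Hinv = jnp.eye(x.shape[0])          # initial guess for H^{-1}
    for _ in range(iters):
        if jnp.linalg.norm(g) < 1e-6:
            break
        p = -Hinv @ g                    # quasi-Newton direction; the Hessian is never formed
        fx, t = f(x), 1.0
        for _ in range(30):              # backtracking (Armijo) line search
            if f(x + t * p) <= fx + 1e-4 * t * (g @ p):
                break
            t *= 0.5
        s = t * p
        g_new = grad_f(x + s)
        y = g_new - g
        if y @ s > 1e-10:                # curvature condition keeps Hinv positive definite
            Hinv = bfgs_inverse_update(Hinv, s, y)
        x, g = x + s, g_new
    return x

x_star = bfgs(f, jnp.zeros(3))
```

Note that Hinv is still a dense n x n matrix, so the memory problem from the comments above remains; limited-memory variants such as L-BFGS avoid it by storing only a few recent (s, y) pairs.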
To add, one could think of schemes like "momentum" and its cousins as attempts to estimate something in the spirit of the inverse Hessian using various hacks and heuristics.
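A minimal sketch of the intuition, using heavy-ball (Polyak) momentum on a hypothetical ill-conditioned quadratic with curvatures 1 and 100. Neither method computes the Hessian, but with the classical step size and momentum tuning for quadratics (which does assume the extreme curvatures L and mu are known), momentum makes far more progress along the low-curvature direction than plain gradient descent in the same 100 steps. This illustrates the heuristic; it is not an actual inverse-Hessian estimate.

```python
import jax
import jax.numpy as jnp

# Hypothetical ill-conditioned quadratic: curvatures 1 and 100 (condition number 100)
curv = jnp.array([1.0, 100.0])

def f(x):
    return 0.5 * jnp.sum(curv * x ** 2)

grad_f = jax.grad(f)

def gradient_descent(x, lr, steps):
    for _ in range(steps):
        x = x - lr * grad_f(x)
    return x

def heavy_ball(x, lr, beta, steps):
    # Polyak momentum: v is the previous displacement, damped by beta
    v = jnp.zeros_like(x)
    for _ in range(steps):
        v = beta * v - lr * grad_f(x)
        x = x + v
    return x

L, mu = 100.0, 1.0
x0 = jnp.array([1.0, 1.0])
x_gd = gradient_descent(x0, lr=1.0 / L, steps=100)

# Classical tuning for quadratics with curvature in [mu, L]
lr_hb = 4.0 / (L ** 0.5 + mu ** 0.5) ** 2
beta_hb = ((L ** 0.5 - mu ** 0.5) / (L ** 0.5 + mu ** 0.5)) ** 2
x_hb = heavy_ball(x0, lr_hb, beta_hb, steps=100)
```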