by ddragon 2425 days ago
I think it's easier to visualize by considering how PyTorch (or TensorFlow eager) does it. PyTorch supports loops, conditionals, and other native, non-differentiable Python constructs, but what happens is that for each forward/backward pass pair you run the graph-building function again. Only the functions that the library overloads with execution-graph ("tape") recording steps affect what ends up in each final tape (and they are all differentiable), while the non-differentiable native parts (like loops) only determine how those differentiable pieces are combined. You can even have a function that asks for user input every time it builds the tape (an impure function); as long as the resulting graph is differentiable on each run, automatic differentiation still works. Also, unlike symbolic differentiation, you can use a surrogate gradient that approximates the function's behavior well enough where it isn't differentiable (for example max pooling in convolutional networks, where the gradient is passed only to the input that was the maximum).
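To make the define-by-run idea concrete, here is a minimal sketch of tape-based reverse-mode autodiff in plain Python. It is not PyTorch's implementation or API: `Tape`, `Var`, `vmax`, `backward` and `f` are hypothetical names made up for illustration. Ordinary Python `for`/`if` decide which overloaded operations get recorded, a fresh tape is built on every forward pass, and `vmax` routes the gradient only to the larger input, the same trick max pooling uses where max is not differentiable.

```python
class Tape:
    """Records every overloaded op executed during one forward pass."""
    def __init__(self):
        self.entries = []  # (output Var, [(input Var, local gradient), ...])


class Var:
    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def _record(self, value, parents):
        out = Var(value, self.tape)
        self.tape.entries.append((out, parents))
        return out

    def __add__(self, other):
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])


def vmax(a, b):
    # max is not differentiable at a == b; like max pooling, send the
    # whole gradient to the (first) maximal input.
    winner = a if a.value >= b.value else b
    return a._record(winner.value, [(winner, 1.0)])


def backward(out):
    out.grad = 1.0
    # Walk the tape in reverse recording order, applying the chain rule.
    for node, parents in reversed(out.tape.entries):
        for parent, local in parents:
            parent.grad += local * node.grad


def f(x_val, n_steps):
    tape = Tape()                 # a new tape for every forward pass
    x = Var(x_val, tape)
    y = x
    for _ in range(n_steps):      # native Python loop, not recorded
        if y.value > 10.0:        # native Python conditional, not recorded
            y = y + x
        else:
            y = y * x
    y = vmax(y, Var(5.0, tape))
    backward(y)
    return y.value, x.grad


print(f(2.0, 3))  # → (16.0, 32.0): tape is x*x*x*x, gradient 4x**3
print(f(3.0, 3))  # → (30.0, 28.0): last step branches, tape is x**3 + x
print(f(0.5, 1))  # → (5.0, 0.0): max picks the constant, no gradient to x
```

The same Python function produces a different tape depending on its input, but each tape on its own is an ordinary differentiable chain.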

The original TensorFlow works the same way, but instead of rerunning the graph-building code every time, it embeds the non-differentiable control flow in the graph itself (with ops such as tf.cond and tf.while_loop). That way it can more efficiently create the correct differentiable tape for each run based on its input, without needing the host language to build a new graph every time. Source-to-source differentiation works exactly the same way, except that instead of using a DSL (like the TensorFlow graph API) and compiling it, it uses the host language and compiler directly, so you don't effectively need two languages. This is the approach taken by Julia's Zygote and Swift for TensorFlow.
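A rough sketch of the static-graph alternative, again in plain Python and again hypothetical: the tuple-based node format and the `placeholder`/`cond`/`run` helpers are invented here and only mimic the idea behind ops like tf.cond, not TensorFlow's API. The graph, including its branch, is built once as data; each `run` evaluates it on one input, tapes only the ops in the branch actually taken, and sweeps that tape backwards.

```python
# Graph nodes are plain tuples; control flow is a node, not a host-language if.
def placeholder():
    return ("x",)

def const(c):
    return ("const", c)

def add(a, b):
    return ("add", a, b)

def mul(a, b):
    return ("mul", a, b)

def cond(pred_node, threshold, then_node, else_node):
    # At run time, take then_node if value(pred_node) > threshold.
    return ("cond", pred_node, threshold, then_node, else_node)


def run(graph, x_val):
    """Evaluate the static graph on one input, taping only executed ops,
    then return (output value, d output / d x)."""
    tape = []     # (out_id, [(in_id, local_grad), ...])
    values = {}   # id -> value
    counter = [0]

    def new_id(value):
        counter[0] += 1
        values[counter[0]] = value
        return counter[0]

    x_id = new_id(x_val)

    def ev(node):
        kind = node[0]
        if kind == "x":
            return x_id
        if kind == "const":
            return new_id(node[1])
        if kind in ("add", "mul"):
            a, b = ev(node[1]), ev(node[2])
            va, vb = values[a], values[b]
            if kind == "add":
                out = new_id(va + vb)
                tape.append((out, [(a, 1.0), (b, 1.0)]))
            else:
                out = new_id(va * vb)
                tape.append((out, [(a, vb), (b, va)]))
            return out
        if kind == "cond":
            p = ev(node[1])
            # Only the selected branch is evaluated, so only it is taped.
            return ev(node[3] if values[p] > node[2] else node[4])
        raise ValueError(f"unknown node {kind!r}")

    out_id = ev(graph)
    grads = {out_id: 1.0}
    for out, parents in reversed(tape):
        g = grads.get(out, 0.0)
        for inp, local in parents:
            grads[inp] = grads.get(inp, 0.0) + local * g
    return values[out_id], grads.get(x_id, 0.0)


# Built once: y = x*x if x > 1 else 3*x - 2
x = placeholder()
g = cond(x, 1.0, mul(x, x), add(mul(const(3.0), x), const(-2.0)))

print(run(g, 2.0))  # → (4.0, 4.0)
print(run(g, 0.5))  # → (-0.5, 3.0)
```

Unlike the define-by-run version, the host language never rebuilds `g`; it only interprets it, and the branch decision lives inside the graph.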

The only alternative to this piecewise differentiation that I know of would be to create soft versions of discrete operators, for example replacing step functions with sigmoids and case/switch/elsif operators with softmax selectors. None of those libraries does this (it would not be easy to make it converge, since the graph would be much more complex at each backward pass). In that case, though, you could have one single graph that includes every branch.
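The soft alternative can be sketched like this (plain Python; all function names and the `temperature` parameter are assumptions for illustration, not from any library). The hard switch has zero gradient with respect to the threshold almost everywhere, while the sigmoid version always evaluates both branches and gives a nonzero gradient; `soft_case` does the same for a k-way case/switch using a softmax.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hard_switch(x, threshold, then_val, else_val):
    # Piecewise: derivative w.r.t. threshold is 0 almost everywhere.
    return then_val if x > threshold else else_val


def soft_switch(x, threshold, then_val, else_val, temperature=1.0):
    # Step function replaced by a sigmoid: both branches always contribute.
    s = sigmoid((x - threshold) / temperature)
    return s * then_val + (1.0 - s) * else_val


def soft_switch_dthreshold(x, threshold, then_val, else_val, temperature=1.0):
    # Analytic d(soft_switch)/d(threshold); nonzero for any finite input.
    s = sigmoid((x - threshold) / temperature)
    return -(then_val - else_val) * s * (1.0 - s) / temperature


def soft_case(selector_logits, branch_values):
    # A k-way case/switch replaced by a softmax-weighted mix of all branches.
    weights = softmax(selector_logits)
    return sum(w * v for w, v in zip(weights, branch_values))


print(soft_switch(0.0, 0.0, 10.0, 2.0))             # → 6.0
print(soft_switch_dthreshold(0.0, 0.0, 10.0, 2.0))  # → -2.0
print(soft_case([0.0, 0.0], [1.0, 3.0]))            # → 2.0
```

As `temperature` goes to 0 the sigmoid approaches the hard step and the gradient concentrates near the threshold. The price is that every branch is computed on every pass.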