by wsmoses
1955 days ago
Reverse-mode AD can always end up in situations where it needs to store original values (i.e. network state). One advantage of a whole-program approach to AD, rather than differentiating individual operators, is that you can often avoid caching values unnecessarily. For example, if an input hasn't been modified (and still exists) by the time its value is needed in the reverse pass, you don't need to cache it; you can simply use the original input without a copy.

And yes, PyTorch/TF also perform a (limited) form of AD rather than numerical differentiation (though I think there may be an option for numerical?).

I wouldn't really position a tool like Enzyme as a competitor to PyTorch/TF (they may have better domain-specific knowledge, after all), but rather as a really nice complement. Enzyme can take derivatives of arbitrary functions in any LLVM-based language, rather than only the DSL of operators supported by PyTorch/TF. In fact, we built a plugin for PyTorch/TF that uses Enzyme to import custom foreign code as a differentiable layer!
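To make the caching point concrete, here is a minimal hand-written sketch in plain Python (not Enzyme's actual output or API; the function names `square_sum_inplace` and `square_sum_readonly` are made up for illustration). It differentiates sum(x_i^2) twice: once where the forward pass overwrites its input, so the originals must be cached, and once where the input is left untouched, so the reverse pass can just read it.

```python
def square_sum_inplace(x):
    """Forward pass that overwrites x, so the originals must be cached."""
    saved = list(x)  # copy: x is about to be destroyed
    for i in range(len(x)):
        x[i] = x[i] * x[i]
    out = sum(x)

    def backward(dout):
        # d(sum x_i^2)/dx_i = 2 * x_i, using the cached originals
        return [2.0 * xi * dout for xi in saved]

    return out, backward


def square_sum_readonly(x):
    """Forward pass that leaves x untouched, so no copy is needed."""
    out = sum(xi * xi for xi in x)

    def backward(dout):
        # Reads the caller's x directly; only valid if x has not been
        # modified between the forward and the reverse pass.
        return [2.0 * xi * dout for xi in x]

    return out, backward


x = [1.0, 2.0, 3.0]
out, backward = square_sum_readonly(x)
grad = backward(1.0)
```

Both versions return the same gradient; the read-only one just skips the copy. Deciding which case applies needs visibility into what happens to `x` between the forward and reverse pass, which is where a whole-program view helps.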
I was under the impression that the big ML frameworks (and surely JAX with jit) also optimize over the complete compute graph.
I didn't want to make this discussion too TF/PyTorch-focused (I'm not even an ML researcher). But your optimization claims make it sound as if the other AD frameworks do no optimization at all, which is not the case.
I was also thinking about derivatives of functions that do something iterative internally, like a matrix decomposition (combined with a linear solve and/or matrix inversion). While a "high-level" AD tracer can identify an efficient derivative for these operations, wouldn't your LLVM introspection only be able to compute the derivative through all the internal steps of the matrix decomposition?
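A scalar stand-in for what I mean (a sketch only, with a Newton square root standing in for the matrix decomposition; the function names are made up, and the step-by-step version propagates a forward-mode tangent by hand rather than doing reverse mode):

```python
import math


def newton_sqrt_with_tangent(a, steps=20):
    """Differentiate *through* every Newton step of x <- 0.5 * (x + a / x)."""
    x, dx = 1.0, 0.0  # start at x0 = 1, which does not depend on a
    for _ in range(steps):
        # Chain rule on the update, using the old x and old dx
        dx = 0.5 * (dx + 1.0 / x - a * dx / (x * x))
        x = 0.5 * (x + a / x)
    return x, dx


def sqrt_with_analytic_derivative(a):
    """High-level rule: d sqrt(a) / da = 1 / (2 * sqrt(a)), no iteration needed."""
    r = math.sqrt(a)
    return r, 0.5 / r


x, dx = newton_sqrt_with_tangent(9.0)
r, dr = sqrt_with_analytic_derivative(9.0)
```

Both give the same derivative once the iteration has converged, but the first does derivative work in every internal step, while the second uses one closed-form rule for the operation as a whole.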