Hacker News
by ChrisRackauckas 1792 days ago
"Maybe they let you declare some subgraph as 'dynamic' to avoid static optimizations?" What you just described is TensorFlow Eager, and it is also why Eager has some performance issues (but more flexibility!). XLA makes some pretty strong assumptions, and I don't think that should change. TensorFlow's ability to automatically generate good, parallelized production code stems from the restrictions it imposes. So I wouldn't even try for a "one true AD to rule them all", since making things more flexible reduces the number of compiler optimizations that can be performed automatically.
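A small JAX analogue of that trade-off (an illustration only, not TensorFlow Eager itself): run eagerly, a function may branch on a value with a plain Python `if`, but `jax.jit` traces it into a static program for XLA and rejects that branch unless it is rewritten as a traced conditional with `jax.lax.cond`. This is a sketch assuming a recent JAX install; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def relu_python_if(x):
    # Value-dependent Python control flow: fine when run eagerly, but under
    # jax.jit `x` is an abstract tracer, so `x > 0` has no concrete bool value.
    if x > 0:
        return x
    return jnp.zeros_like(x)

def relu_lax_cond(x):
    # Same logic expressed as a traced conditional, which XLA can compile.
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)

x = jnp.float32(3.0)

eager_result = relu_python_if(x)           # op-by-op execution: works
compiled_pos = jax.jit(relu_lax_cond)(x)   # compiled, positive branch taken
compiled_neg = jax.jit(relu_lax_cond)(-x)  # compiled, negative branch taken

try:
    jax.jit(relu_python_if)(x)
    python_if_compiled = True
except Exception:  # a concretization error; the exact class varies by JAX version
    python_if_compiled = False
```

The restriction is what lets the compiler see the whole program, including both branches, ahead of time.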

To get the more flexible form, you would really want to target a full programming language's IR. Trying to use the IR of a fully dynamic programming language (Python, R, etc.) directly would be pretty insane, because it would be hard to enforce rules and get performance. So a language whose front end sits on top of an optimizing compiler (LLVM) would probably make the most sense. Zygote and Diffractor use Julia's IR, but there are other ways to do this as well. Enzyme (https://github.com/wsmoses/Enzyme.jl) works on the LLVM IR directly, doing source-to-source translation there. Some dialect of LLVM (provided by MLIR) might be an interesting place to write a more ML-focused, flexible AD system. Swift for TensorFlow used the Swift IR. This mindset starts to show why those tools were chosen.
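For contrast with the tools named above, JAX (which comes up later in this thread) differentiates its own restricted, traced IR, the jaxpr, rather than a general-purpose language IR. A minimal sketch, assuming a recent JAX install, prints that IR before and after the `jax.grad` transformation; the function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A tiny numeric function; JAX traces it into a jaxpr
    return jnp.sin(x) * x

x = jnp.float32(2.0)

# The program JAX actually sees: a flat sequence of primitive ops,
# with no Python control flow left in it
forward_ir = jax.make_jaxpr(f)(x)

# jax.grad is a program transformation: it yields *another* jaxpr
# that computes the derivative
grad_ir = jax.make_jaxpr(jax.grad(f))(x)

print(forward_ir)
print(grad_ir)
```

Because the jaxpr only contains what tracing captured, it is easy to optimize but cannot express everything a full language IR like Julia's or LLVM's can.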

Makes sense. I don't use TF Eager, but I do use JAX, and JAX lets you arbitrarily compose JITed and non-JITed code, which made me think that might be a viable pattern. I guess I wondered whether there might be something like "nonstatic_jit(foo)" that would do Julia-style compiling of function foo, alongside "jit(foo)", which compiles foo to optimized XLA ops. Probably impractical. Thanks.
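The composition pattern mentioned above can be sketched as follows, assuming a recent JAX install: a jitted Newton step (a hypothetical example, not from the thread) is driven by an ordinary Python loop whose stopping test reads a concrete value back from the device. This does not implement the hypothetical `nonstatic_jit`; it only shows jitted and non-jitted code mixed freely.

```python
import jax
import jax.numpy as jnp

@jax.jit
def newton_step(x, a):
    # Compiled by XLA: fixed shapes, no Python control flow inside
    return 0.5 * (x + a / x)

def sqrt_until_converged(a, tol=1e-6, max_iter=50):
    # Plain Python drives the loop; float(...) pulls a concrete value back
    # to the host, which a Python `if` inside a jitted function could not use
    x = jnp.asarray(a, dtype=jnp.float32)
    for i in range(max_iter):
        x_new = newton_step(x, a)
        if abs(float(x_new - x)) < tol:
            return x_new, i + 1
        x = x_new
    return x, max_iter

root, iters = sqrt_until_converged(2.0)
```

A fully compiled version would express the loop with `jax.lax.while_loop`, keeping everything inside XLA at the cost of writing the stopping condition as a traced function.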
This is entirely possible. You have direct access to XLA, after all: https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python...
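The linked notebook (its URL is truncated here) works with XLA's Python client, and that lower-level interface is not reproduced. As a sketch of a related public route, assuming a recent JAX version, the staged `lower()` / `compile()` API on a jitted function exposes the module JAX hands to XLA and the resulting executable; `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.ones((4,), dtype=jnp.float32)

lowered = jax.jit(f).lower(x)    # trace and lower, without running anything
module_text = lowered.as_text()  # the module given to XLA (StableHLO in recent versions)
compiled = lowered.compile()     # ask XLA to compile it
result = compiled(x)             # call the compiled executable directly

print(module_text)
```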

And XLA has dynamic shape semantics (currently unused by JAX) via SetDimensionSize: https://www.tensorflow.org/xla/operation_semantics#setdimens...
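Since, per the comment, JAX did not use that dynamic-shape mechanism, a jitted function is traced and compiled once per distinct input shape. A minimal sketch, assuming a recent JAX install, counts traces to show this, together with one common workaround: padding inputs to a fixed size and masking. Padding is a swapped-in technique, not SetDimensionSize, and the helpers `plain_sum`, `masked_sum`, and `pad_to` are hypothetical.

```python
import jax
import jax.numpy as jnp

traces = {"plain": 0, "masked": 0}

@jax.jit
def plain_sum(x):
    traces["plain"] += 1  # Python side effect: runs only when JAX (re)traces
    return jnp.sum(x)

@jax.jit
def masked_sum(x_padded, length):
    traces["masked"] += 1
    # Treat entries at index >= length as padding and ignore them
    mask = jnp.arange(x_padded.shape[0]) < length
    return jnp.sum(jnp.where(mask, x_padded, 0.0))

def pad_to(x, size):
    # Zero-pad a 1-D array up to a fixed bucket size
    return jnp.pad(x, (0, size - x.shape[0]))

# Two different input shapes -> two traces (and compilations) of plain_sum
plain_sum(jnp.ones(3))
plain_sum(jnp.ones(5))

# Both inputs padded to one bucket -> a single trace of masked_sum
a = masked_sum(pad_to(jnp.ones(3), 8), 3)
b = masked_sum(pad_to(jnp.ones(5), 8), 5)
```

Padding trades some wasted computation for fewer recompilations; true dynamic shapes would avoid both, which is what SetDimensionSize is for on the XLA side.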