| I do find the pattern interesting and powerful. But at the same time, something feels off about it (just conceptually, not trying to knock your money-making endeavor, godspeed). Some of the issues that all of these hit is: - No printf debugging. Sometimes you want things to be eager so you can immediately see what's happening. If you print and what you see is <RPCResultTracingObject> that's not very helpful. But that's what you'll get when you're in a "tracing" context, i.e. you're treating the code as data at that point, so you just see the code as data. One way of getting around this is to make the tracing completely lazy, so no tracing context at all, but instead you just chain as you go, and something like `print(thing)` or `thing.execute()` actually then ships everything off. This seems like how much of Cap'n Web works except for the part where they embed the DSL, and then you're in a fundamentally different context. - No "natural" control flow in the DSL/tracing context. You have to use special if/while/for/etc so that the object/context "sees" them. Though that's only the case if the control flow is data-dependent; if it's based on config values that's fine, as long as the context builder is aware. - No side effects in the DSL/tracing context because that's not a real "running" context, it's only run once to build the AST and then never run again. Of the various flavors of this I've seen, it's the ML usage I think that's pushed it the furthest out of necessity (for example, jax.jit https://docs.jax.dev/en/latest/_autosummary/jax.jit.html, note the "static*" arguments). Is this all just necessary complexity? Or is it because we're missing something, not quite seeing it right? |
Python does let you mess around with the AST, but there is no static typing, and let's just say that the ML ecosystem will <witty example of extreme act> before they adopt static typing. Without types, the AST alone can't tell you which operations touch tensors, so it's not possible to build these graphs without doing this kind of hacky nonsense.
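To make the AST point concrete, here is a small sketch using only the standard-library `ast` module. The function `scale` and its argument names are made up for illustration. It shows that the operators are easy to find, but nothing in the tree says whether `x` and `w` are tensors or plain numbers.

```python
import ast

# Hypothetical function, used only as parse input.
SOURCE = """
def scale(x, w):
    return x * w + 1
"""

tree = ast.parse(SOURCE)
func = tree.body[0]

# Every binary operation in the function is visible in the AST...
op_names = sorted(
    type(node.op).__name__
    for node in ast.walk(tree)
    if isinstance(node, ast.BinOp)
)

# ...but the arguments carry no type information at all.
annotations = [arg.annotation for arg in func.args.args]

print(op_names)
print(annotations)
```

So a purely AST-based graph builder would have to guess, or run the code, to learn what `x * w` actually means.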
For another example, torch.compile() works at the Python bytecode level. It basically swaps out CPython's frame evaluation function (via the PEP 523 hook) for all torch.compile-decorated functions. Inside that, it checks for any operations, e.g. BINARY_MULTIPLY (BINARY_OP on Python 3.11+), involving torch tensors and records them. Any `if` conditions along the path get translated to guards on the resulting graph. Later, when a guard fails, it retraces the subgraph under the complementary condition (and any additional conditions), stores this as an alternative JIT path, and from then on muxes between the paths depending on which guards hold.
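A rough sketch of the guard idea in plain Python. This is not how torch.compile is implemented (it works on bytecode, not operator overloading); `Tracer` and `GuardedFunction` are hypothetical names, and it only handles single-argument functions that return a traced value. It just shows the shape of "record the branch taken as a guard, retrace when the guard fails, then dispatch between cached paths".

```python
class Tracer:
    """Wraps a concrete value and records how to recompute it."""

    def __init__(self, value, fn, guards):
        self.value = value    # concrete value seen during this trace
        self.fn = fn          # replays the recorded computation on a new input
        self.guards = guards  # shared list of (predicate, expected outcome)

    def _lift(self, other):
        if isinstance(other, Tracer):
            return other.value, other.fn
        return other, (lambda x: other)  # constants ignore the input

    def __mul__(self, other):
        o_val, o_fn = self._lift(other)
        fn = self.fn
        return Tracer(self.value * o_val, lambda x: fn(x) * o_fn(x), self.guards)

    def __gt__(self, other):
        o_val, o_fn = self._lift(other)
        fn = self.fn
        return Tracer(self.value > o_val, lambda x: fn(x) > o_fn(x), self.guards)

    def __bool__(self):
        # A plain `if` ends up here: remember which way the branch went.
        outcome = bool(self.value)
        self.guards.append((self.fn, outcome))
        return outcome


class GuardedFunction:
    def __init__(self, func):
        self.func = func
        self.paths = []         # cached (guards, replay function) pairs
        self.compile_count = 0  # how many times we had to trace

    def __call__(self, x):
        # Reuse the first cached path whose guards all hold for this input.
        for guards, replay in self.paths:
            if all(bool(pred(x)) == expected for pred, expected in guards):
                return replay(x)
        # No path matches: trace again with the new input and cache the result.
        guards = []
        out = self.func(Tracer(x, lambda v: v, guards))
        self.compile_count += 1
        self.paths.append((guards, out.fn))
        return out.value


def f(x):
    if x > 0:
        return x * 2
    return x * -1


g = GuardedFunction(f)
results = [g(3), g(5), g(-4), g(-1), g(7)]
print(results, g.compile_count)
```

Five calls, but only two traces: one for the `x > 0` path and one for its complement. Every later call is served from whichever cached path has passing guards.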
JAX works by making the function arguments proxies and recording the operations, like you mentioned. However, you cannot use a normal `if` on traced values; you use lax.cond(), lax.while_loop(), etc. As a result, it doesn't recompute the graph when different branches are encountered; it only computes the graph once (per set of input shapes and static arguments).
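For contrast with the guard approach, here is a toy version of the proxy style in plain Python. It is not JAX: `Node`, `cond`, and `evaluate` are hypothetical names, loosely modeled on the idea behind lax.cond(). Both branches get recorded into a single graph, so nothing is retraced later, and a plain `if` on a proxy is simply an error.

```python
class Node:
    """A proxy for a value that does not exist yet; it only records operations."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args

    def __mul__(self, other):
        return Node("mul", self, other)

    def __gt__(self, other):
        return Node("gt", self, other)

    def __bool__(self):
        # There is no concrete value to branch on at trace time.
        raise TypeError("plain `if` on a traced value; use cond() instead")

    def __repr__(self):
        return f"<Node {self.op}>"


def cond(pred, true_fn, false_fn, operand):
    # Trace BOTH branches so the one graph covers every outcome.
    return Node("cond", pred, true_fn(operand), false_fn(operand))


def evaluate(node, env):
    """Run a recorded graph against concrete inputs."""
    if not isinstance(node, Node):
        return node  # plain Python constant
    if node.op == "input":
        return env[node.args[0]]
    if node.op == "cond":
        pred, on_true, on_false = node.args
        return evaluate(on_true if evaluate(pred, env) else on_false, env)
    lhs, rhs = (evaluate(arg, env) for arg in node.args)
    return lhs * rhs if node.op == "mul" else lhs > rhs


trace_log = []


def f(x):
    trace_log.append("traced")  # side effect: happens at trace time only
    return cond(x > 0, lambda v: v * 2, lambda v: v * -1, x)


graph = f(Node("input", "x"))  # the function body runs exactly once
results = [evaluate(graph, {"x": v}) for v in (3, -4, 7)]
print(graph, results, trace_log)
```

This also reproduces two of the complaints from the parent comment: printing `graph` shows an opaque node rather than a number, and the side effect in `f` fires once during tracing, not once per evaluation.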
In a language such as C#, Rust, or a statically typed Lisp, you wouldn't need to do any of this monkey business. There's probably already a way in the Rust toolchain to hook in at the MIR stage and have your own backend convert that to some tensor IR.