by wsmoses 1965 days ago
Hi all, another author here and happy to answer any questions!

Some more relevant links for the curious

Github: https://github.com/wsmoses/Enzyme

Paper: https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b...

Long story short, Enzyme makes a few interesting contributions:

1) Low-level Automatic Differentiation (AD) IS possible and can be high performance

2) By working at the LLVM level, we get cross-language and cross-platform AD

3) Working at the LLVM level can actually yield additional speedups, since differentiation can be performed after optimization

4) We made a plugin for PyTorch/TF that uses Enzyme to import foreign code into those frameworks with ease!
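
As a rough illustration of point 3 (a toy in plain Python, not Enzyme and not LLVM): in reverse-mode AD the derivative work tracks the primal work that gets differentiated, so hoisting a loop-invariant computation before differentiating also shrinks the gradient computation. The `Var` tape, `normalize_naive`, and `normalize_hoisted` below are hypothetical names made up for this sketch.

```python
import math

class Var:
    """Scalar node in a tiny reverse-mode AD tape (illustrative only)."""
    count = 0  # number of nodes recorded, used as a crude cost measure

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (parent Var, local partial derivative)
        self.grad = 0.0
        Var.count += 1

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def __truediv__(self, other):
        return Var(self.value / other.value,
                   ((self, 1.0 / other.value),
                    (other, -self.value / other.value ** 2)))

def var_sqrt(v):
    r = math.sqrt(v.value)
    return Var(r, ((v, 0.5 / r),))

def backward(out):
    """Propagate d(out)/d(node) to every node, consumers before producers."""
    order, seen = [], set()
    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)
    visit(out)
    out.grad = 1.0
    for v in reversed(order):
        for parent, partial in v.parents:
            parent.grad += partial * v.grad

def mag(xs):
    """Euclidean norm, O(n) work."""
    total = xs[0] * xs[0]
    for x in xs[1:]:
        total = total + x * x
    return var_sqrt(total)

def normalize_naive(xs):
    # "Unoptimized" program: the loop-invariant mag(xs) is recomputed per element.
    return [x / mag(xs) for x in xs]

def normalize_hoisted(xs):
    # Same program after loop-invariant code motion: mag(xs) computed once.
    m = mag(xs)
    return [x / m for x in xs]

def run(normalize, values):
    """Differentiate sum(normalize(values)); return (gradient, nodes recorded)."""
    Var.count = 0
    xs = [Var(v) for v in values]
    outs = normalize(xs)
    total = outs[0]
    for o in outs[1:]:
        total = total + o
    backward(total)
    return [x.grad for x in xs], Var.count

grads_naive, ops_naive = run(normalize_naive, [1.0, 2.0, 2.0, 4.0])
grads_hoisted, ops_hoisted = run(normalize_hoisted, [1.0, 2.0, 2.0, 4.0])
print(ops_naive, ops_hoisted)
```

Both versions produce the same gradient, but the hoisted one records far fewer nodes, and the gap grows quadratically with the input length. A tool that differentiates before the optimizer runs would be stuck differentiating the naive form.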

4 comments

Hey, very interesting work!

CPython is built in C. Can you differentiate through that? I.e., would Python programs then also become differentiable, similar to JAX?

How much control do you have over the gradient? In some cases, it can be useful to explicitly define a custom gradient, or to stop the gradient, or to change the gradient, etc.

Can you define gradients on integral types (int, char)?

Regarding differentiating Python via CPython: theoretically yes, though in practice it is likely wiser to use something like Numba, which takes Python to LLVM directly, to avoid a bunch of abstraction overhead that would otherwise have to be differentiated through. Also, fun fact: JAX can be told to simply emit LLVM, and we've used that as an input for tests :)

You can explicitly define custom gradients by attaching metadata to the function you want to have the custom gradient (and Enzyme will use that even if it could differentiate the original function).
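
To make the idea of a custom gradient concrete, here is a minimal sketch in plain Python using forward-mode dual numbers. It is not Enzyme's mechanism (which, per the comment above, attaches metadata to the function); the `custom_derivative` decorator and the example functions are hypothetical stand-ins. The point it shows is that a registered rule wins even when the function body could have been differentiated automatically, which also covers "stop the gradient" and "change the gradient".

```python
class Dual:
    """Value plus tangent for a tiny forward-mode AD (illustrative only)."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        other = lift(other)
        return Dual(self.val + other.val, self.tan + other.tan)
    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)
    __rmul__ = __mul__

def lift(x):
    return x if isinstance(x, Dual) else Dual(float(x))

def custom_derivative(rule):
    """Hypothetical decorator: use rule(x) as df/dx instead of tracing fn."""
    def wrap(fn):
        def wrapped(x):
            if isinstance(x, Dual):
                # The custom rule takes priority over automatic differentiation.
                return Dual(fn(x.val), rule(x.val) * x.tan)
            return fn(x)
        return wrapped
    return wrap

def derivative(f, x):
    return f(Dual(x, 1.0)).tan

def cube(x):
    return x * x * x                      # differentiated automatically

@custom_derivative(lambda x: 0.0)
def stop_gradient(x):
    return x                              # value passes through, gradient does not

@custom_derivative(lambda x: 1.0)
def square_straight_through(x):
    return x * x                          # true derivative is 2x; the rule overrides it

print(derivative(cube, 2.0))
print(derivative(lambda x: x * stop_gradient(x), 3.0))
print(derivative(square_straight_through, 5.0))
```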

Integral types: mayyybe, depending on what exactly you mean. I can imagine using custom gradient definitions to specify how an integral type can be used in a differentiable way (say, representing a fixed-point number). We don't support differentiating integral types by approximating them as continuous values, if that's what you're asking. There's no reason why we couldn't add this (besides perhaps bit tricks being annoying to differentiate), but we haven't come across a use case.

Yea, I had this very rough (maybe crazy) idea in mind:

Once you can differentiate through CPython, and let's say you can also differentiate integral types via some approximation, and you have some bug in some Python code, and a failing test case in Python, you can use the output (e.g. exception of the failing test) as an error signal and backpropagate to the Python program code. The Python program code is represented as a chunk of bytes. If there is some meaningful gradient, it could point you to possible source code locations where the bug might be.

Probably the gradient will be quite meaningless though, and that's why the idea does not really work in practice. But I think for some simple examples, it still might work.

Since the code has many possible branches, getting a good approximate gradient would require visiting several of them, maybe via some Monte Carlo (MC) sampling or so.

Hello,

Thank you for sharing and releasing usable code! Do you know if this would work for GPU-based applications? TensorFlow models that are trained on a GPU, for example?

For GPUs, there are a couple of different things that you might want to do.

You can use existing tools within LLVM to automatically generate GPU code out of existing code, and this works perfectly fine, even running Enzyme first to synthesize the derivative.

You can also consider taking an existing GPU kernel and then automatically differentiating it. We currently support a limited set of cases for this (certain CUDA instructions, shared memory, etc.), and are working on expanding that set as well as on performance improvements. AD of existing general GPU kernels is interesting [and more challenging] since racy reads in your original code become racy writes in the gradient, so extra care must be taken to make sure they don't conflict. To my knowledge GPU AD on general programs (i.e. not one specific code) really hasn't been done before, so it's a fun research problem to work on (and if someone knows of existing tools for this please email me at wmoses at mit dot edu).
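
A small sequential sketch in Python of why reads turn into writes (this only simulates the effect; it is not GPU code, and "last writer wins" is a simplified model of what an unsynchronized race might do). In a gather `y[i] = x[idx[i]]`, several threads may harmlessly read the same `x[j]`. In the reverse pass each of those reads becomes a contribution `dx[idx[i]] += dy[i]`, so the same threads now write to one location and the updates must be combined, e.g. with atomics or a reduction. All function names here are hypothetical.

```python
def gather(x, idx):
    """Primal: y[i] = x[idx[i]]; many 'threads' i may read the same x[j]."""
    return [x[j] for j in idx]

def gather_grad_accumulate(dy, idx, n):
    """Reverse pass done right: every read contributes dx[idx[i]] += dy[i]."""
    dx = [0.0] * n
    for i, j in enumerate(idx):
        dx[j] += dy[i]   # on a GPU this would need an atomic add or a reduction
    return dx

def gather_grad_lost_updates(dy, idx, n):
    """Simplified model of conflicting unsynchronized writes: last writer wins."""
    dx = [0.0] * n
    for i, j in enumerate(idx):
        dx[j] = dy[i]    # earlier contributions to dx[j] are overwritten
    return dx

x = [10.0, 20.0, 30.0]
idx = [0, 0, 2, 0]            # x[0] is read by three different 'threads'
dy = [1.0, 1.0, 1.0, 1.0]     # upstream gradient of each output

print(gather(x, idx))                              # [10.0, 10.0, 30.0, 10.0]
print(gather_grad_accumulate(dy, idx, len(x)))     # [3.0, 0.0, 1.0]
print(gather_grad_lost_updates(dy, idx, len(x)))   # [1.0, 0.0, 1.0]
```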

Very interesting. I especially like the second item.

What happens if the function you want to differentiate calls multiple other functions, in multiple other compilation units?

(I haven't read the paper yet but definitely will)

Enzyme needs to be able to access the IR of any potentially active functions (calls that it deduced could impact the gradient) to be able to differentiate them.
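
For intuition about "potentially active", here is a minimal sketch of the textbook notion of activity from the AD literature: a value is active if it both depends on a differentiated input ("varied") and influences a differentiated output ("useful"). This is an assumption-laden toy over a hand-written dependency graph, not Enzyme's actual analysis, and the value names are hypothetical.

```python
def reachable(start, edges):
    """Return every node reachable from the nodes in `start` via `edges`."""
    seen, stack = set(start), list(start)
    while stack:
        node = stack.pop()
        for nxt in edges.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def active_values(uses, active_inputs, outputs):
    """uses maps each value to the values computed from it."""
    # Invert the graph so we can walk from outputs back to their operands.
    used_by = {}
    for src, dsts in uses.items():
        for dst in dsts:
            used_by.setdefault(dst, []).append(src)
    varied = reachable(active_inputs, uses)   # depends on an active input
    useful = reachable(outputs, used_by)      # influences an output
    return varied & useful

# Hypothetical program:
#   t = scale(x); u = log_size(n); v = debug_norm(x); y = combine(t, u)
uses = {"x": ["t", "v"], "n": ["u"], "t": ["y"], "u": ["y"]}
active = active_values(uses, active_inputs=["x"], outputs=["y"])
print(sorted(active))   # ['t', 'x', 'y']
```

In this toy, `u` is not varied (it depends only on the size `n`) and `v` is not useful (it never reaches `y`), so only the calls producing `t` and `y` would need to be differentiated.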

If all of the code you care about is in one compilation unit, you're immediately good to go.

Multiple compilation units can be handled in a couple of ways, depending on how much effort you want to put into setting it up (and we're working on making this easier).

The easiest way is to compile with Link-Time Optimization (LTO) and have Enzyme run during LTO, which ensures it has access to bitcode for all potentially differentiated functions.

The slightly more difficult approach is to have Enzyme emit derivatives ahead of time, rather than lazily, for any functions you may call in an active way (and incidentally this is where Enzyme's rather aggressive activity analysis is super useful). Enzyme supports custom derivatives, in which an LLVM function declaration can carry metadata that marks its derivative function, so it can then be told to use the "custom" derivatives it generated while compiling other compilation units. This obviously requires more setup, so I'm usually lazy and use LTO, but this can definitely be made easier as a workflow.

Thanks. LTO definitely looks like the more natural option.

What are the limitations? When will it fail?

We go into more details in the Limitations section of the paper, but in short Enzyme requires the following properties:

* IR of active functions must be accessible when Enzyme is called (e.g. cannot differentiate dlopen'd functions)

* Enzyme must be able to deduce the types of operations being performed (see the paper's section on interprocedural type analysis for details on why)

* Support for exceptions is limited (compiling with -fno-exceptions, using the equivalent in a different language, or running LLVM's exception lowering pass removes them).

* Support for parallel code (CPU/GPU) is ongoing [and see the prior comment on GPU parallelism for details]