by evnc 784 days ago
When I see an embedded DSL passed around as strings like this, I can't help but think "this could be its own programming language."

Then it could have syntax highlighting, auto complete, and so on. The type system for such a language could possibly include verifying shapes at compile time.

What would a language comprised of .ein source files for manipulating tensors, which compiles down to the same low level ops, look like?
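To make the shape-verification idea concrete, here is a minimal sketch of the kind of check such a type system could run. It is plain Python and runs at call time rather than compile time, and the function name `check_einsum` and the NumPy-style `"ik,kj->ij"` spec format are assumptions made for illustration. The point is only that the check needs the index spec and the shapes, never the tensor data.

```python
def check_einsum(spec, *shapes):
    """Validate a spec like "ik,kj->ij" against operand shapes.

    Returns the output shape, or raises ValueError on a mismatch.
    Hypothetical helper: not part of any library mentioned in the thread.
    """
    inputs, output = spec.split("->")
    terms = inputs.split(",")
    if len(terms) != len(shapes):
        raise ValueError(f"{len(terms)} operands in spec, {len(shapes)} shapes given")

    sizes = {}  # index letter -> extent seen so far
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term '{term}' does not match rank of shape {shape}")
        for letter, extent in zip(term, shape):
            # The first occurrence fixes the extent; later ones must agree.
            if sizes.setdefault(letter, extent) != extent:
                raise ValueError(
                    f"index '{letter}' is {sizes[letter]} in one operand "
                    f"and {extent} in another"
                )

    for letter in output:
        if letter not in sizes:
            raise ValueError(f"output index '{letter}' appears in no input")
    return tuple(sizes[letter] for letter in output)


print(check_einsum("ik,kj->ij", (2, 3), (3, 4)))  # (2, 4)
```

A dedicated language could run the same logic in its type checker, so that a mismatched `k` extent is reported before anything executes.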


No need for .ein source files. We just need a programming language that allows the definition of embedded DSLs without shoving them into one-line strings. A language like Common Lisp.

Here's einsum in 200 lines of Common Lisp. All einsum expressions are statically analyzed, checked for errors, and AOT compiled to machine code: https://github.com/quil-lang/magicl/blob/master/src/high-lev...

This is also how it works in Julia, where macros digest notation for einsum-like operations before compile-time. In fact the linked file's explanatory comment:

    (einsum (A i j) (B i k) (C k j))
    results in the updates
      A[i,j] = \sum_k B[i,k]C[k,j],
    which is equivalent to matrix multiplication.
very nearly contains the syntax used by all the Julia packages (where @ marks a macro), which is

    @tensor A[i,j] = B[i,k] * C[k,j]
(using https://github.com/Jutho/TensorOperations.jl, but see also OMEinsum, Finch, and my Tullio, TensorCast.)
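For readers who have not used either notation, the semantics of `A[i,j] = B[i,k] * C[k,j]` can be spelled out as a rough sketch in plain Python. This is not how magicl or TensorOperations.jl implement it (both expand the expression into specialized code ahead of time); the helper `einsum_update` and its arguments are hypothetical and only cover two rank-2 operands stored as nested lists. Indices that appear on the right but not on the left, here `k`, are summed over.

```python
from itertools import product


def einsum_update(out_idx, lhs_idx, rhs_idx, B, C, sizes):
    """Compute A[out_idx] = sum of B[lhs_idx] * C[rhs_idx] over all indices.

    Illustrative only: rank-2 operands and result, nested-list storage.
    """
    names = sorted(sizes)  # every index name, e.g. ['i', 'j', 'k']
    A = [[0] * sizes[out_idx[1]] for _ in range(sizes[out_idx[0]])]
    # Visit every assignment of values to index names; indices missing
    # from out_idx are summed because they accumulate into the same cell.
    for values in product(*(range(sizes[n]) for n in names)):
        env = dict(zip(names, values))
        b = B[env[lhs_idx[0]]][env[lhs_idx[1]]]
        c = C[env[rhs_idx[0]]][env[rhs_idx[1]]]
        A[env[out_idx[0]]][env[out_idx[1]]] += b * c
    return A


B = [[1, 2], [3, 4]]
C = [[5, 6], [7, 8]]
A = einsum_update("ij", "ik", "kj", B, C, {"i": 2, "j": 2, "k": 2})
print(A)  # [[19, 22], [43, 50]]
```

A macro-based implementation gets to inspect the index names before the program runs, which is what allows the static checks and ahead-of-time compilation described above.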
I wrote a library in C++ (I know, probably a non-starter for most reading this) that I think does most of what you want, as well as some other requests in this thread (generalized to more than just multiply-add): https://github.com/dsharlet/array?tab=readme-ov-file#einstei....

A matrix multiply written with this looks like this:

    enum { i = 2, j = 0, k = 1 };
    auto C = make_ein_sum<float, i, j>(ein<i, k>(A) * ein<k, j>(B));
Where A and B are 2D arrays. This is strongly typed all the way through, so you get a lot of feedback at compile time, and the type of C is known to be a 2D array at compile time. It is possible to make C++ template errors reasonable with enable_if and the like; this works well-ish on clang, but not so well in GCC (YMMV).

This library is a lot less automated than most other einsum implementations. You have to explicitly control the loop ordering (in the example above, the `j` loop is innermost because it is loop 0). If you build a good loop order for your shapes, the compiler will probably autovectorize your inner loop, and you'll get pretty good performance. Control over the loop ordering is in general a useful tool, but it's probably a lot lower level than most users want.
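As a rough illustration of what that loop ordering means, here is the nest that `enum { i = 2, j = 0, k = 1 }` appears to describe, written in plain Python rather than with the library. It assumes that a higher number means a more outer loop (the comment only states that loop 0 is innermost) and that the arrays are row-major nested lists; the function name `matmul_loop_order` is made up for this sketch.

```python
def matmul_loop_order(A, B):
    """C[i][j] = sum_k A[i][k] * B[k][j], nested i (outer), k, j (inner)."""
    n_i, n_k, n_j = len(A), len(B), len(B[0])
    C = [[0.0] * n_j for _ in range(n_i)]
    for i in range(n_i):          # loop 2: outermost
        row_c = C[i]
        for k in range(n_k):      # loop 1
            a_ik = A[i][k]        # constant across the inner loop
            row_b = B[k]
            for j in range(n_j):  # loop 0: innermost, walks one row of B and C
                row_c[j] += a_ik * row_b[j]
    return C


print(matmul_loop_order([[1, 2], [3, 4]], [[5, 6], [7, 8]]))
# [[19.0, 22.0], [43.0, 50.0]]
```

With `j` innermost, each pass of the inner loop reads and writes consecutive elements of a single row, which is the access pattern a C++ compiler can typically autovectorize. Swapping `k` and `j` gives the same result but turns the inner loop into a strided walk down a column of B.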

I played around with the idea of a language motivated by this same thought process last year: https://github.com/lukehoban/ten.

* Succinct syntax and operators tailored to AI model definition

* Fully statically typed tensors, including generic functions over tensor dimension and batch dimensions (...)

* First-class hyper-parameters, model parameters and model arguments for explicit model specification

* Einops-style reshaping and reductions - tensor dimensions are explicit not implicit

Hey, this is neat! Thanks for sharing. I may be interested in collaborating with you on this when I have some free time.
Einsums are the regexes of tensor programming. Should be avoided at all costs IMO. Ideally we should be able to write native loops that get auto-vectorized into einsums for which there is a CUDA/PTX emitting factory. But for some reason neither PyTorch nor JAX/TF took this route and now we are here.

Some of the einsum expressions I have seen for grouped multi-head/multi-query attention are mind-boggling, and they get shipped to prod.

JAX kind of did take this route, no? The main issue is that it’s going to be hard/impossible to compile Python loops to GPU kernels. It’s also maybe not the most ergonomic solution, which is why there is shorthand like einsum. Einsum can be much more clear than a loop because what it can do is so much more limited.
JAX tries to be a functional language that has a Python front end. The problem is if you are outside Google and don't really understand the XLA compiler, then you are screwed.
I like einsum, it's concise and self-explanatory. Way better than multiple lines of nested for loops would be.
I agree that nested loops are verbose. But einsum is unstructured and not extensible, hard to debug in pieces, hard to document, etc.
That sounds like overkill for something that is already overkill for most applications.