Lots of modern models have very late-binding variables that are hard to precompile for (sentence length in NMT, for example). That means you're going to need to do some form of specialization at runtime, so a JIT makes sense.
One of the core operations of the transformer network[1] is a (LxL) x (LxE) matrix multiply (where L is the sentence length and E is the network width). Can you be more specific about how you would get good performance without specializing on L?
L can be as small as 1 or larger than 512. For small L it makes sense to apply different optimizations than for large L. A loop-based GEMM doesn’t help with that.
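To make the point about specializing on L concrete, here is a rough sketch in plain Python. Everything in it is illustrative rather than taken from any real framework: `specialize`, `_unrolled_kernel`, `matmul_loops` and the `SMALL_L` threshold of 8 are made-up names and values. For small L it generates a kernel with the reduction over L fully unrolled, and for large L it falls back to a generic loop nest, caching one kernel per observed L.

```python
from functools import lru_cache

SMALL_L = 8  # assumed cutoff for "small" sentences; a real system would tune this


def matmul_loops(a, b):
    """Generic triple-loop GEMM: works for any shape, specialized for none."""
    rows, inner, cols = len(a), len(b), len(b[0])
    out = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for k in range(inner):
            a_ik = a[i][k]
            for j in range(cols):
                out[i][j] += a_ik * b[k][j]
    return out


def _unrolled_kernel(L):
    """Generate source with the reduction over k unrolled for this exact L."""
    terms = " + ".join(f"row[{k}] * values[{k}][j]" for k in range(L))
    src = (
        "def kernel(attn, values):\n"
        "    width = len(values[0])\n"
        f"    return [[{terms} for j in range(width)] for row in attn]\n"
    )
    namespace = {}
    exec(src, namespace)  # stand-in for a real JIT's code generation step
    return namespace["kernel"]


@lru_cache(maxsize=None)
def specialize(L):
    """Return a kernel for an (L x L) x (L x E) multiply, built once per L."""
    if L <= SMALL_L:
        return _unrolled_kernel(L)
    return matmul_loops


attn = [[1, 2], [3, 4]]              # L = 2
values = [[5, 6, 7], [8, 9, 10]]     # E = 3
result = specialize(len(attn))(attn, values)
```

The cache is the part that matters: the generation cost is paid once per distinct L, which is only possible when L is known, i.e. at runtime.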
Well, the success of neural nets over the past few years has come through harnessing massive processing power.
The problem is that a lot of the programming can be low-level and ad hoc. I think the idea of the various DSLs is to allow the model to be compactly specified while having the programs go as fast as possible. A JIT may be one way to accomplish this.
Relying on a JIT often means being able to write things that preclude the use of static compilers, which in turn means that accelerated hardware like ours cannot be used efficiently.
The kind of optimizations a static compiler might apply can be done by a JIT as well, with the added benefit of actually knowing what kind of workload is going to run. Most of deep learning is applying comparatively small computation graphs to very large arrays of numbers in parallel, so the overhead of compilation is only a small portion of the overall computation time. A smart JIT that decides on the optimal tiling pattern for the array dimensions observed at runtime and rewrites loops accordingly can easily pay for itself.
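As a toy illustration of picking a tiling pattern from the dimensions seen at runtime, something like the following would do. The candidate tile widths and the "least padding, then largest tile" heuristic are assumptions made up for this sketch, and `choose_tile`, `tile_ranges` and `tiled_matmul` are hypothetical names, not any particular compiler's API.

```python
def choose_tile(dim, candidates=(64, 32, 16, 8)):
    """Pick the tile width that pads `dim` the least, preferring larger tiles on ties."""
    def padding(tile):
        padded = -(-dim // tile) * tile  # round dim up to a multiple of tile
        return padded - dim
    return min(candidates, key=lambda tile: (padding(tile), -tile))


def tile_ranges(dim, tile):
    """Yield (start, stop) blocks covering range(dim) in steps of `tile`."""
    for start in range(0, dim, tile):
        yield start, min(start + tile, dim)


def tiled_matmul(a, b):
    """Matrix multiply whose column loop is blocked by a tile chosen at call time."""
    rows, inner, cols = len(a), len(b), len(b[0])
    tile = choose_tile(cols)  # decided from the shape actually observed
    out = [[0.0] * cols for _ in range(rows)]
    for j0, j1 in tile_ranges(cols, tile):
        for i in range(rows):
            for k in range(inner):
                a_ik = a[i][k]
                for j in range(j0, j1):
                    out[i][j] += a_ik * b[k][j]
    return out
```

A static compiler would have to commit to one tile width up front or emit every variant; here the choice is deferred until `cols` is known.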
As a counterpoint (and not necessarily the one the GP is referring to), if you were compiling for, say, an FPGA, the overhead of compilation would be very significant.
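One way to put numbers on "pays for itself" versus "very significant" is a simple break-even count. The function below is only back-of-the-envelope arithmetic, and every timing fed to it is a hypothetical figure chosen for illustration, not a measurement of any real JIT or FPGA toolchain.

```python
def break_even_calls(compile_us, generic_us, specialized_us):
    """Number of calls after which compiling a specialized kernel has paid off.

    All arguments are integer microseconds. Returns None if the specialized
    kernel is not faster, since the compile cost is then never recovered.
    """
    saving = generic_us - specialized_us
    if saving <= 0:
        return None
    return -(-compile_us // saving)  # ceiling division on integers


# Hypothetical numbers: a 0.5 s JIT compile that saves 2 ms per call.
jit_calls = break_even_calls(500_000, 10_000, 8_000)          # 250 calls

# Hypothetical numbers: a one-hour FPGA build with the same per-call saving.
fpga_calls = break_even_calls(3_600_000_000, 10_000, 8_000)   # 1,800,000 calls
```

With the same per-call saving, the break-even point scales linearly with compile time, which is why a compile measured in hours changes the picture.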