Hello! Thanks for your question. First of all, there are three layers of abstraction within Caten:

1. `caten/apis`: High-Level Graph Interface
2. `caten/air`: Low-Level Graph Interface
3. `caten/codegen`: AIR Graph => Kernel Generator

The inputs to the compiler are plain Common Lisp classes (similar to PyTorch modules). For example, we could define a module `SinCos` that computes sin(cos(x)):

```lisp
(defclass SinCos (Func) nil
  (:documentation "The func SinCos computes sin(cos(x))"))

;; Forward creates a lazy tensor for the next computation.
;; The `st` macro spares you from building that tensor by hand;
;; "A[~] -> A[~]" says the output has the same shape as the input.
(defmethod forward ((op SinCos) &rest tensors)
  (st "A[~] -> A[~]" (tensors)))

;; Backward is optional (skipped here).
(defmethod backward ((op SinCos) &optional prev-grad)
  (declare (ignore prev-grad))
  nil)

;; Lower describes `SinCos` in terms of low-level primitives.
(defmethod lower ((op SinCos) &rest inputs)
  (let ((x (car inputs)))
    (with-context
      (a (%sin (%add x (%fconst (/ pi 2)))))
      (b (%sin a)))))
```
The class and its `forward`/`backward` methods live in the high-level `apis` layer, while `lower` is the lower-level step: it rewrites `SinCos` into primitive operations (`%sin`, `%add`, `%fconst`) before code generation. Next, the framework generates an Abstract VM (AVM) representation:

```
#S(AVM :GRAPH Graph[seen=NIL, outputs=(STC6466_1)] {
  <ALLOCATE : TID6464 <- (shape=(1), stride=(1)) where :dtype=FLOAT32>
  <Node[BUFFER] ALLOCATE(NID6480) : SID6479* <- ()>
  <Node[BINARYOPS] ADD(NID6484) : BID6483* <- (TID6464, LID6481)>
  <Node[UNARYOPS] SIN(NID6486) : UID6485* <- (BID6483)>
  <Node[UNARYOPS] SIN(NID6488) : UID6487* <- (UID6485)>
  <Node[SPECIAL/VM] PAUSE/BACKWARD(NID6501) : STC6466_1* <- (UID6487)>
})
```
Then, the computation graph is translated into schedule items:

```
FastGraph[outputs=(val_6)] {
  { Allocate } : [ val_0 <- (1) ]
  { KERNEL } : [ val_5 <- val_1, val_0 :name=FUSED_SIN_SIN_ADD_LOAD6511]
}
```
Finally, the code generation step produces the following C code:

```c
void fused_sin_sin_add_load6511(float* val_5, const float* restrict val_0);
void fused_sin_sin_add_load6511(float* val_5, const float* restrict val_0) {
  val_5[0] = sin(sin((val_0[0] + 1.5707964)));
}
```
This C code is then compiled with a C compiler and executed. So, to answer your question: the compiler takes Common Lisp code as input and generates C functions from it.