by ddragon
2502 days ago
Do you have an example of a tensor library that keeps track of shapes and detects mismatches at compile time? I had the impression that even in static languages having tensors with the exact shape as a parameter would stress the compiler, forcing it to compile many versions of every function for every possible size combination, and the output of a function could very well have a non deterministic or multiple possible shapes (for example branching on runtime information). So they compromise and make only the dimensionality a parameter, which would not catch your example either until the runtime bounds checks.
> I had the impression that even in static languages having tensors with the exact shape as a parameter would stress the compiler, forcing it to compile many versions of every function for every possible size combination, and the output of a function could very well have a non deterministic or multiple possible shapes (for example branching on runtime information).
I was a bit lazy in my original comment - you're right. What I really think should be implemented (and is already starting to be, in PyTorch and in a library named NamedTensor, albeit non-statically) is essentially "typed axes."
For instance, if I had a sequence of locations in time, I could describe the tensor as:
(3 : DistanceAxis, 32 : TimeAxis, 32 : BatchAxis).
Sure, the number of dimensions could vary, and you're right that in that case the approach implied by my first comment would suffer a combinatorial explosion. But if I accidentally contract a TimeAxis with a BatchAxis, that can be caught pretty easily before I even run the code. In normal PyTorch, such a contraction would succeed - and it would succeed silently.
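To make that concrete, here's a rough pure-Python sketch of the idea. It is not PyTorch's or NamedTensor's actual API: `Axis`, `contract_untyped`, `contract_typed` and `AxisMismatchError` are names I made up for illustration, it only tracks shapes (no data), and the check runs when the function is called rather than statically. A static checker could apply the same rule before the code runs.

```python
from typing import NamedTuple, Tuple


class Axis(NamedTuple):
    """One tensor axis: what it means (kind) and how long it is (size)."""
    kind: str
    size: int


Shape = Tuple[Axis, ...]


class AxisMismatchError(TypeError):
    """Raised when two axes of different kinds are contracted."""


def contract_untyped(left: Shape, i: int, right: Shape, j: int) -> Tuple[int, ...]:
    """Shape-only check: any two axes of equal size may be contracted."""
    if left[i].size != right[j].size:
        raise ValueError("axis sizes differ")
    # The result keeps every axis except the two that were contracted.
    kept_left = tuple(ax.size for k, ax in enumerate(left) if k != i)
    kept_right = tuple(ax.size for k, ax in enumerate(right) if k != j)
    return kept_left + kept_right


def contract_typed(left: Shape, i: int, right: Shape, j: int) -> Shape:
    """Typed-axis check: the two axes must also be of the same kind."""
    if left[i].kind != right[j].kind:
        raise AxisMismatchError(
            f"cannot contract {left[i].kind} with {right[j].kind}"
        )
    if left[i].size != right[j].size:
        raise ValueError("axis sizes differ")
    return left[:i] + left[i + 1:] + right[:j] + right[j + 1:]


# (3 : DistanceAxis, 32 : TimeAxis, 32 : BatchAxis)
locations: Shape = (
    Axis("DistanceAxis", 3),
    Axis("TimeAxis", 32),
    Axis("BatchAxis", 32),
)

# Sizes match (32 == 32), so the size-only check lets time-vs-batch through.
silent_result = contract_untyped(locations, 1, locations, 2)

# The typed check rejects the same call because the axis kinds differ.
try:
    contract_typed(locations, 1, locations, 2)
    caught = False
except AxisMismatchError:
    caught = True
```

Since TimeAxis and BatchAxis both happen to have size 32 here, a size check alone has nothing to complain about. Only the axis kind distinguishes the two.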