by Evidlo 605 days ago
I have a non-ML question.

In vanilla PyTorch I have the following expression:

    t.sum(values[inds] * weights)
If 'inds' is int8, I get "IndexError: tensors used as indices must be long, int, byte or bool tensors".

Is this still true if I use torchao?
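For reference, here is a minimal sketch of the usual vanilla-PyTorch workaround. It assumes `t` in the snippet above is `torch`, and the tensors are made-up examples. The idea is to keep `inds` stored as int8 and widen it to int64 only at indexing time.

```python
import torch

torch.manual_seed(0)
values = torch.randn(10)
weights = torch.randn(4)
# Hypothetical example indices, stored compactly as 1-byte int8.
inds = torch.tensor([0, 3, 3, 9], dtype=torch.int8)

# Indexing with the int8 tensor directly is what raises the IndexError above.
# Widening to int64 ("long") at lookup time is accepted. Note that .long()
# allocates a temporary int64 copy of inds (8 bytes per element).
total = torch.sum(values[inds.long()] * weights)
print(total)
```

This keeps the stored indices small, but it does not avoid the int64 temporary during the lookup itself.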

The issue here is that memory in PyTorch is byte-addressable, and that's a limitation we can't solve without making a lot more changes to PyTorch. But in your specific case, if you'd like to pack more data into `values`, you can use a combination of clever bit shifting, `torch.cat`, and other bit-twiddling PyTorch ops. It's a trick we use quite heavily in torchao.
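A minimal sketch of the general bit-packing idea described above. This is not torchao's actual code. Two 4-bit unsigned integers share one uint8 byte: they are written with shifts and ORs, then read back with shifts, masks, and `torch.cat`. The helper names `pack_uint4`/`unpack_uint4` are hypothetical. The layout, with the first half of the input in the high nibbles and the second half in the low nibbles, is also my own assumption for illustration.

```python
import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    """Pack a 1-D tensor of values in [0, 15] into half as many uint8 bytes.

    Hypothetical layout: first half of x -> high nibbles, second half -> low
    nibbles, so unpacking is a single torch.cat.
    """
    if x.numel() % 2 != 0:
        raise ValueError("expected an even number of elements")
    if x.numel() and (x.min() < 0 or x.max() > 15):
        raise ValueError("values must fit in 4 bits (0..15)")
    x = x.to(torch.uint8)
    n = x.numel() // 2
    return (x[:n] << 4) | x[n:]

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_uint4; returns the original values as uint8."""
    high = (packed >> 4) & 0xF  # upper nibble of each byte
    low = packed & 0xF          # lower nibble of each byte
    return torch.cat((high, low))

x = torch.tensor([1, 15, 0, 7, 8, 2])
packed = pack_uint4(x)                  # 3 bytes instead of 6 (or 48 as int64)
restored = unpack_uint4(packed).long()
print(packed.tolist(), restored.tolist())  # [23, 248, 2] [1, 15, 0, 7, 8, 2]
```

To use this for indexing, you would still unpack and widen right before `values[...]`. So it shrinks what you store, not the temporary created at lookup time.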
Aren't int8s byte-aligned, though? I thought this restriction was originally motivated by the maintenance overhead of having to support more dtypes.
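A quick check of the premise in this follow-up. `element_size()` reports both int8 and uint8 (the "byte" dtype the error message accepts) at one byte per element, versus 4 for int32 ("int") and 8 for int64 ("long"). This only confirms the storage sizes. It does not by itself explain why int8 indices are rejected.

```python
import torch

# Bytes per element for a few integer dtypes. int8 and uint8 are both exactly
# one byte, so int8 indices would not require sub-byte addressing.
sizes = {
    dtype: torch.empty(0, dtype=dtype).element_size()
    for dtype in (torch.int8, torch.uint8, torch.int32, torch.int64)
}
for dtype, size in sizes.items():
    print(dtype, size)
```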