by f38zf5vdt 612 days ago
There is a downside: every attention layer effectively has to compute attention twice (two runs of scaled_dot_product_attention). As scaled_dot_product_attention is usually one of the most expensive operations in training and inference of a model, it seems like networks using this may be substantially slower, and perhaps should be considered against larger networks with more attention layers.

https://github.com/microsoft/unilm/blob/master/Diff-Transfor...
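
To make the "attention twice" point concrete, here is a minimal pure-Python sketch of the subtraction as I read the paper's formula: two softmax attention maps, the second scaled by a factor and subtracted from the first, then applied to the values. The function names and the fixed `lam` argument are my own; the real implementation (learned lambda, masking, per-head normalization, batching) is not reproduced here.

```python
import math

def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(q, k):
    """Row-wise softmax(q @ k^T / sqrt(d)) for lists of vectors."""
    scale = 1.0 / math.sqrt(len(q[0]))
    return [
        softmax([scale * sum(a * b for a, b in zip(qi, kj)) for kj in k])
        for qi in q
    ]

def diff_attention(q1, k1, q2, k2, v, lam):
    """(softmax(q1 k1^T) - lam * softmax(q2 k2^T)) @ v, unbatched, unmasked."""
    a1 = attention_weights(q1, k1)  # first attention map
    a2 = attention_weights(q2, k2)  # second attention map: the extra cost
    diff = [[x - lam * y for x, y in zip(r1, r2)] for r1, r2 in zip(a1, a2)]
    # Weighted sum of the value vectors for every query position.
    return [
        [sum(w * vj[c] for w, vj in zip(row, v)) for c in range(len(v[0]))]
        for row in diff
    ]
```

With `lam = 0` this reduces to ordinary softmax attention, which shows where the extra work comes from: the second call to `attention_weights`.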

2 comments

Interesting. This is one of those areas where edge inference needs might be different from data center needs: getting 11B-quality results from a 7B model at the cost of 30% more inference time is probably a clear yes for anyone doing local inference. And let’s remember that memory bandwidth is a huge factor as well; a model that is 30% smaller should cut memory-bound time costs by about 30%. Anyway, I’m interested in trying this out.
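
A back-of-the-envelope version of that trade-off, as a sketch only. It assumes token generation is purely memory-bandwidth bound (every weight read once per token), 2-byte weights, a made-up 100 GB/s of bandwidth, and that the 30% overhead applies as a simple multiplier. None of these numbers come from the paper.

```python
def decode_seconds_per_token(params_billion, bytes_per_param=2.0,
                             bandwidth_gb_s=100.0, overhead=0.0):
    """Rough time per token if reading the weights is the only cost."""
    weight_bytes = params_billion * 1e9 * bytes_per_param
    base = weight_bytes / (bandwidth_gb_s * 1e9)
    return base * (1.0 + overhead)

# Hypothetical comparison: plain 11B model vs 7B model with 30% extra time.
t_11b = decode_seconds_per_token(11)
t_7b = decode_seconds_per_token(7, overhead=0.30)

# Fraction of weight traffic saved by going from 11B to 7B.
size_reduction = 1 - 7 / 11
```

Under these assumptions the 7B model still comes out ahead per token (about 0.18 s versus 0.22 s), and the size reduction is closer to 36% than 30%.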

I wonder if the specific setup might be extra effective for coding-tuned models as well - you get one coding transformer and one ‘bad habits/chat/other non-coding stuff’ negative transformer.

In Big-O notation, O(2n) = O(n). Two times slower is actually not that much. If this slowdown results in better inference in the same number of training rounds or better-tuned weights with fewer redundant features, that can be a very worthwhile sacrifice.

It's also a complex optimization problem, not just a matter of compute. Twice the parameters takes more than twice the time to tune and twice the working memory to train and use. There are also plenty of model training scenarios where data throughput from the dataset into memory and back out is the final bottleneck.

So, though I agree it is indeed a downside, I think it's a worthwhile sacrifice if the results they show are reproducible.

Glad to see your ideas here. Could you clarify a point for me? The W matrix in the paper is d_model x 2d. Does this mean a differential attention model doubles the W matrix of a standard attention model, which is d_model x d? E.g., suppose llama3 has W of 8192 x 1024; does the diffattn model of the same architecture have W of 8192 x (1024 x 2)?
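
The arithmetic behind that question, as a small sketch. The 8192 x 1024 shape is taken from the question and not checked against any real llama3 config. The head counts in the second half are hypothetical, and the idea that a differential model uses half as many heads to keep the parameter count matched is my reading of the paper, so treat it as an assumption.

```python
def matrix_elements(rows, cols):
    """Number of entries in a rows x cols weight matrix."""
    return rows * cols

d_model, d = 8192, 1024  # shapes from the question, not verified

standard_w = matrix_elements(d_model, d)      # d_model x d
doubled_w = matrix_elements(d_model, 2 * d)   # d_model x 2d

# Hypothetical head counts: if the differential model uses half as many
# heads (assumption), the per-layer total stays the same.
standard_total = 8 * standard_w
diff_total = 4 * doubled_w
```

So a single d_model x 2d matrix is twice the size, but whether the whole layer grows depends on how many heads it has.
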
The O for any transformer is always quadratic
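
To put the big-O exchange in one place: a rough cost model, assuming the attention maps dominate and writing n for sequence length and c for a constant that absorbs the head dimension (both symbols are mine, not the paper's).

```latex
C_{\text{std}}(n) \approx c\,n^{2},
\qquad
C_{\text{diff}}(n) \approx 2c\,n^{2}
\quad\Longrightarrow\quad
C_{\text{std}}(n),\; C_{\text{diff}}(n) \in O(n^{2})
```

Under that model, computing two attention maps changes the constant factor but not the quadratic growth in sequence length.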