by cs702
916 days ago
This looks really nice. Thank you for sharing it on HN!

In case you didn't know, you can parallelize the slow Python loop in selective_scan that computes all the x's:

  x = torch.zeros((b, d_in, n))
  for i in range(l):
      x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
      ⋮

with only two calls to the PyTorch API. See the examples here: https://github.com/glassroom/heinsen_sequence/blob/main/READ... .[a]

You can then compute all the y's with one einsum, instead of l sequential einsums.

---

[a] Previous discussion on HN: https://news.ycombinator.com/item?id=38556669
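
To make the suggestion above concrete, here is a minimal sketch of one way to do it. It is not the linked repository's code. It assumes deltaA and deltaB_u have shape (b, d_in, l, n) and strictly positive entries, so real logarithms are defined; negative values would need extra handling that this sketch does not attempt. The function names and the (b, l, n) shape for C are assumptions made for illustration.

```python
import torch


def sequential_scan(deltaA, deltaB_u):
    """Reference loop: x_i = deltaA[:, :, i] * x_{i-1} + deltaB_u[:, :, i], starting from zeros."""
    b, d_in, l, n = deltaA.shape
    x = torch.zeros((b, d_in, n), dtype=deltaA.dtype)
    xs = []
    for i in range(l):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        xs.append(x)
    return torch.stack(xs, dim=2)  # (b, d_in, l, n)


def parallel_scan(deltaA, deltaB_u):
    """Same recurrence computed in log space (assumes strictly positive inputs).

    With a_star = cumsum(log deltaA), the recurrence unrolls to
    log x_i = a_star_i + log(sum_{k <= i} exp(log deltaB_u_k - a_star_k)),
    so the scan itself needs only cumsum and logcumsumexp.
    """
    a_star = torch.cumsum(torch.log(deltaA), dim=2)
    log_x = a_star + torch.logcumsumexp(torch.log(deltaB_u) - a_star, dim=2)
    return torch.exp(log_x)  # (b, d_in, l, n)


def all_outputs(xs, C):
    """All y's in one einsum; C is assumed to have shape (b, l, n)."""
    return torch.einsum('bdln,bln->bld', xs, C)  # (b, l, d_in)
```

For long sequences the term log deltaB_u - a_star can grow large, so it is worth comparing against the sequential loop on real data before relying on this.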