by cs702 916 days ago
This looks really nice. Thank you for sharing it on HN!

In case you didn't know, you can parallelize the slow Python loop in selective_scan that computes all the x's:

  x = torch.zeros((b, d_in, n))
  for i in range(l):
      x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
      ⋮ 
with only two calls to the PyTorch API. See the examples here: https://github.com/glassroom/heinsen_sequence/blob/main/READ... .[a]
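
To make the idea concrete, here is a minimal sketch of one way to replace the loop with cumulative sums. It is a simplified, numerically naive variant, not the code from the linked repository: it assumes deltaA is strictly positive, that both tensors have shape (b, d_in, l, n) as in the loop above, and that l is short enough for exp(-a_star) not to overflow. The function names are made up for illustration.

```python
import torch

def sequential_scan(deltaA, deltaB_u):
    # Reference version: the loop from the comment, keeping every x.
    b, d_in, l, n = deltaA.shape
    x = torch.zeros((b, d_in, n), dtype=deltaA.dtype)
    xs = []
    for i in range(l):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        xs.append(x)
    return torch.stack(xs, dim=2)  # (b, d_in, l, n)

def parallel_scan_naive(deltaA, deltaB_u):
    # Hypothetical helper. Unrolling the recurrence from x = 0 gives
    #   x[t] = sum_{s<=t} (prod_{r=s+1..t} deltaA[r]) * deltaB_u[s]
    #        = exp(a_star[t]) * sum_{s<=t} deltaB_u[s] * exp(-a_star[s]),
    # where a_star[t] = sum_{r<=t} log(deltaA[r]). Requires deltaA > 0.
    a_star = torch.cumsum(torch.log(deltaA), dim=2)
    return torch.exp(a_star) * torch.cumsum(deltaB_u * torch.exp(-a_star), dim=2)

# Small check on random inputs (float64 keeps the comparison tight).
torch.manual_seed(0)
deltaA = torch.rand(2, 3, 5, 4, dtype=torch.float64) * 0.5 + 0.5  # values in [0.5, 1)
deltaB_u = torch.randn(2, 3, 5, 4, dtype=torch.float64)
xs_seq = sequential_scan(deltaA, deltaB_u)
xs_par = parallel_scan_naive(deltaA, deltaB_u)
```

The two cumsum calls do all the work along the sequence dimension. For long sequences the exp(-a_star) term can blow up, so a numerically careful version should be preferred in practice.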

You can then compute all the y's with one einsum, instead of l sequential einsums.
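
A rough sketch of that last step, assuming the stacked states xs have shape (b, d_in, l, n) and that there is a per-step output projection C of shape (b, l, n). The name C and both shapes are assumptions for illustration, as are the function names.

```python
import torch

def ys_sequential(xs, C):
    # One einsum per time step, as in a step-by-step loop.
    l = xs.shape[2]
    ys = [torch.einsum('bdn,bn->bd', xs[:, :, i], C[:, i]) for i in range(l)]
    return torch.stack(ys, dim=1)  # (b, l, d_in)

def ys_parallel(xs, C):
    # A single einsum that contracts n for every time step at once.
    return torch.einsum('bdln,bln->bld', xs, C)  # (b, l, d_in)

torch.manual_seed(0)
xs = torch.randn(2, 3, 5, 4, dtype=torch.float64)  # assumed (b, d_in, l, n)
C = torch.randn(2, 5, 4, dtype=torch.float64)      # assumed (b, l, n)
y_seq = ys_sequential(xs, C)
y_par = ys_parallel(xs, C)
```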

---

[a] Previous discussion on HN: https://news.ycombinator.com/item?id=38556669


OP's code is much easier to understand, though, which is the main (only) purpose of their code.
Can't argue with that! :-)

For what it's worth, you can keep both, and make parallel vs sequential execution an option, with a boolean flag.

You can also leave the sequential code as a comment explaining what the parallel code does.

Or, if slow execution doesn't bother you, leave it as is.

You're replying to somebody who was arguing for readability being its virtue and you're proposing ... adding options and alternate code paths? :)
Touché. I just updated my comment :-)
Via a boolean parameter, no less.