by antognini
399 days ago
Once you have the matrix implementation in Step 2 (Implementation 3), it's rather straightforward to extend your N-body simulator to run on a GPU with JAX: you can just add `import jax.numpy as jnp` and replace all the `np.`s with `jnp.`s. For a few-body system (e.g., the Solar System) this probably won't provide any speedup, but once you get to ~100 bodies you should start to see substantial speedups from running the simulator on a GPU.
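A minimal sketch of what a vectorized ("matrix") force computation might look like, assuming direct-summation gravity with G = 1 and a small softening length `eps`. The function names, the softening, and the `xp` array-module parameter are illustrative assumptions, not code from the tutorial being discussed. Passing `xp=jax.numpy` together with `jnp` arrays is the np-to-jnp swap the comment describes.

```python
import numpy as np


def accelerations(pos, mass, xp=np, G=1.0, eps=1e-3):
    """Pairwise gravitational accelerations for N bodies (sketch).

    pos:  (N, 3) positions; mass: (N,) masses.
    xp:   array module, numpy by default; jax.numpy should work the same way.
    eps:  softening length (an assumption) that keeps the i == j term finite.
    """
    # dx[i, j] = pos[j] - pos[i], shape (N, N, 3): every pair at once
    dx = pos[None, :, :] - pos[:, None, :]
    # softened squared distances, shape (N, N)
    r2 = xp.sum(dx ** 2, axis=-1) + eps ** 2
    inv_r3 = r2 ** -1.5
    # a_i = G * sum_j m_j * dx_ij / |r_ij|^3 ; the diagonal contributes 0 since dx_ii = 0
    return G * xp.einsum("ij,ijk->ik", mass[None, :] * inv_r3, dx)


def leapfrog_step(pos, vel, mass, dt, xp=np):
    """One kick-drift-kick leapfrog step; returns new arrays, mutates nothing."""
    vel = vel + 0.5 * dt * accelerations(pos, mass, xp)
    pos = pos + dt * vel
    vel = vel + 0.5 * dt * accelerations(pos, mass, xp)
    return pos, vel
```

Nothing above updates an array in place, which matters because JAX arrays are immutable: a loop-based version doing `pos[i] += ...` would have to be rewritten with `pos.at[i].add(...)`. Note also that the `(N, N, 3)` intermediate grows as N², which is what keeps a GPU busy at ~100+ bodies and is pure overhead for a handful of bodies. With JAX installed, something like `jax.jit(functools.partial(leapfrog_step, xp=jnp))` should compile the step, though that is untested here.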
For the programmer, yes, it is easy enough.
But there is a lot of complexity hidden behind that change. If you care about how your tools work, that might be a problem (I'm not judging).