by jxjnskkzxxhx
426 days ago
I've used JAX quite a bit and it's so much better than TF/PyTorch. Now, for the life of me, I still haven't been able to understand what a TPU is. Is it Google's marketing term for a GPU? Or is it something different entirely?

So GPUs have ~120 small systolic arrays, one per SM (aka a tensor core), plus passable off-chip bandwidth (i.e., 16 lanes of PCIe).
TPUs, by contrast, have one honking big systolic array, plus large amounts of off-chip bandwidth.
This roughly translates to GPUs being better if you're doing a bunch of different small-ish things in parallel, but TPUs are better if you're doing lots of large matrix multiplies.
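
If "systolic array" is the unfamiliar part, here is a toy sketch in plain Python of an output-stationary one, where each cell of the grid owns one output element and operands arrive on a skewed schedule. The function name `systolic_matmul` and the cycle model are my own simplifications for illustration, not how any real TPU or tensor core is wired.

```python
def systolic_matmul(A, B):
    """Toy model of an output-stationary systolic array computing A @ B.

    Assumption: one grid cell per output element, with rows of A flowing
    right and columns of B flowing down, each skewed by one cycle per cell.
    """
    n, k, m = len(A), len(B), len(B[0])
    acc = [[0] * m for _ in range(n)]  # one accumulator per grid cell
    cycles = n + m + k - 2             # cycles until the last operand pair meets
    for t in range(cycles):
        for i in range(n):
            for j in range(m):
                s = t - i - j          # index of the operand pair reaching cell (i, j)
                if 0 <= s < k:
                    acc[i][j] += A[i][s] * B[s][j]
    return acc, cycles


result, cycles = systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
print(result, cycles)  # [[19, 22], [43, 50]] 4
```

In this model each cell does useful work for k of the n + m + k - 2 cycles, and the rest is spent filling and draining the grid. A long inner dimension amortises that overhead, which is one way to see why a big array likes big matrix multiplies.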