by thepasswordis 1842 days ago
Is it possible to run this on something other than Google's cloud platform?

I'm running it comfortably on my 3090, although it's a really snug fit for the VRAM, and that's with a number of fixes from https://github.com/AeroScripts/mesh-transformer-jax that significantly reduce its memory use.
Out of curiosity, how fast are your inferences with this setup?
With the defaults of per_replica_batch=1, seq=2048 and gen_len=512, a completion takes about 20 seconds.

I'm not sure yet what settings I'll end up with if I decide to play with this more. per_replica_batch=3, seq=1024, gen_len=64 would give an experience roughly similar to the AI Dungeon that I'm used to, though less clever than the Dragon model, and a bit slower at about 10 seconds per batch.
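
For a rough comparison of those two configurations, the timings can be turned into generated tokens per second. This is only a back-of-the-envelope sketch: the helper name `tokens_per_second` is made up for illustration, it assumes a single replica (the one 3090), and it treats the "about 20 seconds" and "about 10 seconds" figures as exact while ignoring the time spent reading the seq-length prompt.

```python
def tokens_per_second(per_replica_batch, gen_len, seconds):
    """Generated tokens per wall-clock second, summed over the whole batch.

    Assumes one replica, so the batch size equals per_replica_batch.
    """
    return per_replica_batch * gen_len / seconds


# Defaults from the comment above: one completion of 512 tokens in ~20 s.
default_rate = tokens_per_second(per_replica_batch=1, gen_len=512, seconds=20)

# AI Dungeon-like settings: three completions of 64 tokens each in ~10 s.
dungeon_rate = tokens_per_second(per_replica_batch=3, gen_len=64, seconds=10)

print(f"defaults:        {default_rate:.1f} tokens/s")
print(f"AI Dungeon-like: {dungeon_rate:.1f} tokens/s")
```

By this estimate the shorter settings return sooner per batch but generate somewhat fewer tokens per second overall.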

Its design is pretty specific to Google's TPU.

If you ran it elsewhere, you'd likely need to tweak the design or suffer quite a big performance penalty.

Nope, probably not. It's using JAX, which works on both GPUs and TPUs.
All the constants in the design will be tuned to fit perfectly in TPU hardware dimensions.