Hacker News
by Ari_Rahikkala 1836 days ago
I'm running it comfortably on my 3090, although it's a really snug fit for the VRAM, and that's with a number of fixes from https://github.com/AeroScripts/mesh-transformer-jax that significantly reduce its memory use.

Out of curiosity, how fast are your inferences with this setup?
With the defaults of per_replica_batch=1, seq=2048 and gen_len=512, a completion takes about 20 seconds.

I'm not sure yet what settings I'll end up with if I decide to play with this more. per_replica_batch=3, seq=1024, gen_len=64 would give an experience roughly similar to the AI Dungeon that I'm used to, though less clever than the Dragon model, and a bit slower at about 10 seconds per batch.
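
For a rough comparison of the two configurations above, here is a back-of-the-envelope throughput calculation. It is only a sketch: the helper name `tokens_per_second` is made up for illustration, the timings are the approximate figures quoted above, and it assumes a single replica where every sequence in the batch generates the full gen_len tokens.

```python
def tokens_per_second(per_replica_batch: int, gen_len: int, seconds: float) -> float:
    """Generated tokens per second, summed over all sequences in the batch.

    Assumes one replica and that each sequence produces exactly gen_len tokens.
    """
    return per_replica_batch * gen_len / seconds


# Defaults: per_replica_batch=1, gen_len=512, about 20 seconds per completion.
default_rate = tokens_per_second(per_replica_batch=1, gen_len=512, seconds=20.0)

# AI Dungeon-like settings: per_replica_batch=3, gen_len=64, about 10 seconds per batch.
dungeon_rate = tokens_per_second(per_replica_batch=3, gen_len=64, seconds=10.0)

print(f"defaults:         {default_rate:.1f} tokens/s")
print(f"AI Dungeon-like:  {dungeon_rate:.1f} tokens/s total, "
      f"{dungeon_rate / 3:.1f} tokens/s per sequence")
```

On these numbers the shorter, batched setting produces fewer tokens per second overall, which fits the "a bit slower" impression, though the quoted times are approximate and seq differs between the two runs, so treat the comparison as indicative only.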