Hacker News
by sega_sai 230 days ago
I switched from PyTorch to JAX just before Triton appeared. Does anyone know how JAX compares to this autotuning machinery in PyTorch? I know JAX does JIT compilation, but I don't have a good intuition for whether JIT is better than this type of autotuning.
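JIT and autotuning are not really alternatives. `jax.jit` controls *when* compilation happens: once per input shape/dtype and per value of any static argument. Which kernel configuration gets used is left to the compiler backend. Autotuning in the Triton sense is a separate step: compile several candidate configurations, benchmark each one, and keep the fastest. Here is a minimal sketch of that loop built on `jax.jit`. It assumes a hypothetical tunable function `blocked_sum` whose `block` parameter is the knob being tuned. `autotune` is also a hypothetical helper, not a PyTorch or JAX API.

```python
import time

import jax
import jax.numpy as jnp


def blocked_sum(x, block):
    # Hypothetical tunable function: sum a 1-D array by viewing it as rows
    # of `block` elements. `block` changes the compiled program, so it is
    # treated as a static argument under jax.jit.
    return x.reshape(-1, block).sum(axis=1).sum()


def autotune(fn, x, candidates, reps=5):
    """Benchmark each candidate `block` value and return the fastest one."""
    compiled = jax.jit(fn, static_argnums=1)  # one compilation per static value
    timings = {}
    for block in candidates:
        compiled(x, block).block_until_ready()  # warm-up: triggers compilation
        start = time.perf_counter()
        for _ in range(reps):
            compiled(x, block).block_until_ready()  # wait for async dispatch
        timings[block] = (time.perf_counter() - start) / reps
    best = min(timings, key=timings.get)
    return best, timings
```

Usage: `best, timings = autotune(blocked_sum, jnp.arange(4096, dtype=jnp.float32), (64, 128, 256, 512))`. Each distinct `block` value triggers its own compilation, so the warm-up call is kept out of the timed region. A production autotuner would also cache the winner per input shape so the search runs only once.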

Pallas is the Triton equivalent in JAX land. If you search for Pallas, you'll find some old autotuning prototypes, such as this one: https://github.com/jax-ml/jax-triton/pull/108
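Here is a minimal Pallas sketch that shows what the Triton-style programming model looks like. It implements an elementwise add, split into blocks with `grid` and `BlockSpec`. `add_kernel` and `add` are illustrative names. `block` is the kind of parameter an autotuner would sweep. `interpret=True` runs the kernel in interpreter mode so it also works on CPU; you would drop it to compile for a GPU or TPU. The `BlockSpec` arguments are passed by keyword to avoid relying on their positional order. The sketch assumes the input length is a multiple of `block`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each program instance reads one block of x and y and writes one output block.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y, block=128):
    n = x.shape[0]
    spec = pl.BlockSpec(block_shape=(block,), index_map=lambda i: (i,))
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(n // block,),  # one program instance per block
        in_specs=[spec, spec],
        out_specs=spec,
        interpret=True,  # interpreter mode: runs without a GPU/TPU
    )(x, y)
```

Because `block` only changes the grid and block shapes, you could pass `add` to a benchmarking loop over candidate block sizes. Per the comment above, that is roughly what the linked autotuning prototypes explore.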