by kanaffa12345
1613 days ago
>jit compile Python

given who you are (googling your name) i'm surprised that you would say this. jax does not jit compile python in any sense of the word `Python`. jax is a tracing mechanism for a very particular set of "programs" specified using python; i put programs in quotes because it's not like you could even use it to trace through `if __name__ == "__main__"` since it doesn't know (and doesn't care) anything about python namespaces. it's right there in the first sentence of the description:

>JAX is Autograd and XLA

autograd for tracing and building the tape (wengert list) and xla for the backend (i.e., actual kernels). there is no sense in which jax will ever play a role in something like faster hash tables or more efficient loads/stores or virtual function calls. in fact it doesn't even jit in the conventional understanding of jit, since there is no machine code that gets generated anew based on code paths (it simply picks different kernels and such that have already been compiled). not that i fault you for this substitution since everyone in ML does this (pytorch claims to jit as well).
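
The tracing behaviour described in the comment above can be sketched with `jax.make_jaxpr`, which returns the recorded program for a function. This is only an illustration; the function `f` is a made-up example, not something from the thread. Ordinary Python statements such as `print` run once while tracing and do not appear in the recorded program; only the `jax.numpy` operations do.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Plain Python side effect: executes during tracing, is not recorded.
    print("tracing f")
    # Array operations: these are what the trace captures.
    return jnp.sin(x) * 2.0

# Trace f with an example input and get the recorded program (a jaxpr).
jaxpr = jax.make_jaxpr(f)(1.0)
jaxpr_text = str(jaxpr)
print(jaxpr_text)
```

The printed jaxpr lists the `sin` and multiply primitives and nothing about the `print` call.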
You miss my point: all of those efforts are about making slow Python code run faster. So claiming that 'these two things have nothing to do with each other' is wrong, because what they share is 'making Python code run faster'.
Some of that involves making CPython faster, some of it means moving execution into C (NumPy is mentioned in that PDF), and some involves jit and moving execution onto GPU or TPU (for example using XLA). The common part is 'making Python code run faster'. Some of that is automatic, some requires manual effort.
JAX can jit some Python functions, but it cannot efficiently jit everything. That is what I meant by decoration and 'some effort': for example, replacing `if` conditions on array values with np.where (that is, `jax.numpy.where`), etc. See also https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
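
A minimal sketch of the kind of rewrite meant here, assuming a toy ReLU-like function (the names `relu_branch` and `relu_where` are made up for illustration). A Python `if` on a traced value cannot be jitted as written, while the `jnp.where` version can.

```python
import jax
import jax.numpy as jnp

def relu_branch(x):
    # Python `if` on a traced value: jit cannot turn the tracer into a bool.
    if x > 0:
        return x
    return 0.0

@jax.jit
def relu_where(x):
    # Same logic expressed as an elementwise select, which traces fine.
    return jnp.where(x > 0, x, 0.0)

result = relu_where(jnp.array([-1.0, 2.0]))

try:
    jax.jit(relu_branch)(jnp.float32(1.0))
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    # Raised when tracing hits the `if x > 0` branch.
    branch_failed = True
```

Note that `jnp.where` evaluates both branches for every element, which is part of why not everything jits efficiently.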
My background is in physics simulation, and I advise the Brax team, which is basically building a physics engine written in Python that runs on accelerators; see https://github.com/google/brax. The entire physics step, including collision detection and the physics solver, is jit compiled.
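
To give a feel for what a jit-compiled physics step looks like, here is a toy sketch. It is not Brax code and does not use the Brax API; the constants `DT` and `GRAVITY`, the function `step`, and the single point mass bouncing on a ground plane are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

DT = 0.01        # time step in seconds (assumed)
GRAVITY = -9.81  # m/s^2

@jax.jit
def step(pos, vel):
    # Integrate: update velocity first, then position (semi-implicit Euler).
    vel = vel + GRAVITY * DT
    pos = pos + vel * DT
    # Toy "collision detection" against a ground plane at height 0.
    hit = pos < 0.0
    # Toy "solver": clamp to the ground and reflect the velocity.
    pos = jnp.where(hit, 0.0, pos)
    vel = jnp.where(hit, -vel, vel)
    return pos, vel

pos, vel = jnp.float32(1.0), jnp.float32(0.0)
for _ in range(100):
    pos, vel = step(pos, vel)
```

The collision response uses `jnp.where` rather than a Python `if`, so the whole step stays inside one compiled function.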