by time_to_smile
1652 days ago
> Every day spent learning JAX

To get started, JAX is just knowing Python and adding `grad`, `jit`, and `vmap` to the mix; it takes about 5 minutes to get going. To me this is the real power of JAX: it can be viewed as a few functions that make it easy to take any Python code you've written and work with its derivatives. This gives it tremendous flexibility in helping you solve problems.

As an example, I mostly do statistical work with it rather than NN-focused work. It took probably a few minutes to implement a GLM with custom priors over all the parameters and then use the Hessian for the Laplace approximation of parameter uncertainty. The proper way to solve this would have been to use PyMC, but this worked well enough for me, and building the model from scratch in JAX took less time than refreshing my memory of the PyMC API.
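For readers who want to see roughly what that workflow looks like, here is a minimal sketch, not the code from the comment. It assumes a Poisson GLM with a log link and independent Normal(0, scale) priors on the coefficients. The synthetic data, the names (`PRIOR_SCALE`, `neg_log_posterior`), and the plain Newton loop used to find the posterior mode are all made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data: Poisson counts with a log link.
key_x, key_y, key_d = jax.random.split(jax.random.PRNGKey(0), 3)
X = jnp.column_stack([jnp.ones(200), jax.random.normal(key_x, (200,))])
true_beta = jnp.array([0.5, -0.3])
y = jax.random.poisson(key_y, jnp.exp(X @ true_beta)).astype(jnp.float32)

# Assumed "custom priors": a different Normal(0, scale) for each coefficient.
PRIOR_SCALE = jnp.array([10.0, 1.0])

def neg_log_posterior(beta):
    eta = X @ beta
    log_lik = jnp.sum(y * eta - jnp.exp(eta))  # Poisson log-likelihood, up to a constant
    log_prior = -0.5 * jnp.sum((beta / PRIOR_SCALE) ** 2)
    return -(log_lik + log_prior)

# grad and hessian come straight from the plain Python function above.
grad_fn = jax.jit(jax.grad(neg_log_posterior))
hess_fn = jax.jit(jax.hessian(neg_log_posterior))

# Newton iterations to find the posterior mode (MAP estimate).
beta = jnp.zeros(2)
for _ in range(25):
    beta = beta - jnp.linalg.solve(hess_fn(beta), grad_fn(beta))

# Laplace approximation: posterior ~ Normal(beta, H^{-1}) at the mode.
cov = jnp.linalg.inv(hess_fn(beta))
std_err = jnp.sqrt(jnp.diag(cov))

# vmap maps the mean function over draws from the approximate posterior.
draws = jax.random.multivariate_normal(key_d, beta, cov, (1000,))
pred = jax.vmap(lambda b: jnp.exp(X @ b))(draws)  # shape (1000, 200)

print("MAP estimate:", beta)
print("Laplace std. errors:", std_err)
```

The model itself is only the few lines in `neg_log_posterior`. Changing the prior or the likelihood means editing that one function, and `grad`, `hessian`, and `vmap` follow along without further changes.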