by joaogui1
2099 days ago
In your first sentence you're conflating JAX and XLA.

XLA: Accelerated Linear Algebra. I guess it's kind of a backend/compiler that optimizes linear algebra/deep learning computations with some very interesting techniques, among them kernel fusion.

JAX: In some sense syntactic sugar over XLA, but a better way of describing it is composable transformations + NumPy + some SciPy. The composable transformations let you take derivatives (of single-, multi-, or vector-valued functions, including higher-order derivatives), JIT a function (which is then compiled to XLA), use two forms of parallelism (vmap and pmap), and more, all while staying compatible with one another and with TPUs, GPUs, and CPUs.
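To make "composable" concrete, here is a minimal sketch, assuming a recent JAX install running on whatever backend is available. The function `f` and the names `df`, `d2f`, `fast_d2f`, and `batched` are made up for illustration; only `jax.grad`, `jax.jit`, and `jax.vmap` are JAX's own. `pmap` is left out because it needs multiple devices.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function: f(x) = x^3 + 2x
    return x ** 3 + 2.0 * x

df = jax.grad(f)              # first derivative: 3x^2 + 2
d2f = jax.grad(df)            # higher-order derivative: 6x
fast_d2f = jax.jit(d2f)       # same function, compiled through XLA
batched = jax.vmap(fast_d2f)  # vectorized over the leading axis

xs = jnp.array([1.0, 2.0, 3.0])
print(df(2.0))      # 14.0
print(batched(xs))  # [ 6. 12. 18.]
```

Each transformation takes a function and returns a function, which is why they can be stacked in any order here.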
|
"TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now works on JAX! For those not familiar, JAX is a library for accelerated numerical computing based on composable function transformations.
We have ported a lot of TFP's most useful functionality to JAX while preserving the abstractions and APIs that many TFP users are now comfortable with."
TensorFlow is migrating a bunch of stuff to JAX. Even they use the word "library" for their own port. To a user like me, it looks like JAX is a library that TensorFlow uses, but the library the end user actually works with is TensorFlow.