by joaogui1 2099 days ago
In your first sentence you're confusing JAX with XLA.

XLA: Accelerated Linear Algebra. As far as I understand, it's a backend/compiler that optimizes linear algebra and deep learning computations using some very interesting techniques, kernel fusion among them.

JAX: In some sense syntactic sugar over XLA, but a better description is composable transformations + NumPy + some SciPy. The composable transformations let you take derivatives (of single-, multi-, or vector-valued functions, including higher-order derivatives), JIT-compile a function (which is then compiled to XLA), use two forms of parallelism (vmap and pmap), and more, all while staying compatible with one another and with TPUs, GPUs, and CPUs.
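
To make "composable" concrete, here is a minimal sketch (my own toy example, not from any of the linked articles): the function `f` and the input values are made up for illustration, while `jax.grad`, `jax.jit` and `jax.vmap` are the actual JAX transformations.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Toy scalar function chosen for illustration: f(x) = x^3 + 2x.
    return x ** 3 + 2.0 * x


# Transformations return ordinary functions, so they can be stacked freely.
df = jax.grad(f)                 # first derivative: 3x^2 + 2
d2f = jax.grad(jax.grad(f))      # higher-order derivative: 6x
fast_df = jax.jit(df)            # same derivative, compiled through XLA
batched_df = jax.vmap(fast_df)   # apply the compiled derivative over a batch

xs = jnp.array([0.0, 1.0, 2.0])
first = batched_df(xs)           # derivative evaluated at each element of xs
second = jax.vmap(d2f)(xs)       # second derivative at each element of xs
```

The same code should run unchanged on CPU, GPU or TPU, depending on which backend the installed JAX build picks up.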


I'm not misreading the articles around it. Check this out: https://www.tensorflow.org/probability/examples/TensorFlow_P...

"TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now works on JAX! For those not familiar, JAX is a library for accelerated numerical computing based on composable function transformations.

We have ported a lot of TFP's most useful functionality to JAX while preserving the abstractions and APIs that many TFP users are now comfortable with."

TensorFlow is migrating a bunch of stuff to JAX. Even they use the word "library" for their own port. To a user like me, it looks like JAX is a library that TensorFlow uses, but the library end users actually work with is TensorFlow.

Hi, tech lead for TFP here. The wording here was unclear -- sorry! We're fixing it presently.

We are not migrating away from TF; far from it!

The change here was to interoperate with TF and JAX (and NumPy!), by way of some rewrite trickery under the hood. Essentially, we wrote a translation layer that implements the TF API surface (or rather, the parts we actually use) in terms of NumPy & JAX primitives [1]. This lets us leave most TFP code intact, written in terms of the TF API, but interoperate with JAX by way of the API translation layer. (Actually we implemented NumPy support first, and mostly got JAX for "free", since JAX is largely API-compatible with NumPy.)
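
For readers who want a feel for the idea, here is a rough sketch of what such a translation layer could look like. It is an illustration under assumptions, not TFP's actual code: `make_tf_shim` and `log_sum_exp` are hypothetical names, and only a handful of TF-style functions (`reduce_sum`, `reduce_max`, `exp`, `math.log`) are mapped.

```python
import types

import jax
import jax.numpy as jnp
import numpy as np


def make_tf_shim(xp):
    """Build a hypothetical TF-flavoured namespace on top of module `xp`.

    `xp` is any NumPy-compatible module, e.g. `numpy` or `jax.numpy`.
    """
    shim = types.SimpleNamespace()
    shim.reduce_sum = lambda input_tensor, axis=None, keepdims=False: xp.sum(
        input_tensor, axis=axis, keepdims=keepdims)
    shim.reduce_max = lambda input_tensor, axis=None, keepdims=False: xp.max(
        input_tensor, axis=axis, keepdims=keepdims)
    shim.exp = xp.exp
    shim.math = types.SimpleNamespace(log=xp.log)
    return shim


def log_sum_exp(tf, x):
    # "Library" code written only against TF-style names; it never
    # touches NumPy or JAX directly.
    m = tf.reduce_max(x)
    return m + tf.math.log(tf.reduce_sum(tf.exp(x - m)))


values = [0.0, 1.0, 2.0]

# NumPy backend first ...
np_result = log_sum_exp(make_tf_shim(np), np.asarray(values))

# ... and the JAX backend comes almost for free, because jax.numpy
# mirrors the NumPy API. JAX transformations then apply to the same code.
jax_tf = make_tf_shim(jnp)
jax_result = log_sum_exp(jax_tf, jnp.asarray(values))
grad_result = jax.grad(lambda v: log_sum_exp(jax_tf, v))(jnp.asarray(values))
```

The point of the sketch is only the shape of the design: the library function stays untouched while the backend module is swapped underneath it.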

Sorry for any confusion!

We're pretty stoked about this work, so we're happy to answer any other questions you may have (also feel free to chime in on the GitHub tracker or email tfprobability@tensorflow.org).

[1] - https://github.com/tensorflow/probability/tree/master/tensor...

Hey, thanks for the clarification.

Here's what everybody is puzzled about: it looks like the layers going forward are JAX -> TensorFlow -> Keras.

And we are seeing people move to JAX directly. So this is ending up like a Flutter vs Kotlin issue (also within Google).

Do you envision JAX being the low-level layer, and the high-level TensorFlow/Keras interface being the most usable API?