Hacker News
by bodono 2099 days ago
I'm not sure this is really going to take off; it seems that most people who are abandoning TF are moving to Jax or PyTorch. My own experience with Jax is that it is much easier to use than TF, just an all-round more pleasant experience. It would be interesting to try this, but at this point I'm not really willing to learn 'yet another deep learning framework', and the extreme anti-user problems that TF had make me loath to give it another shot, even with a presumably better frontend. Moreover, I think that Python is just a better all-round ML/data science language at this point. Has anyone tried both Jax and this and would be willing to give us their thoughts on strengths and weaknesses of each?
5 comments

I'm skeptical of JAX. It feels good right now, but when the first TF beta version came out it was very much like that too - clean, simple, minimal, and just a better version of Theano. Then the "crossing the chasm" effort started and everyone at Google wanted to be part of it, making TF the big complex mess it is today. It's a great example of Conway's Law. I'm not convinced the same won't happen to JAX as it catches on.

PyTorch has already stood the test of time and proven that its development is led by a competent team.

I know where you're coming from, but TF in my opinion was very user-hostile even on arrival. I can't tell you how much hair-pulling I did over tf.conds, tf.while_loops and the whole gather / scatter paradigm for simple indexing into arrays. I really think the people working on it wanted users to write TF code in a certain, particular way and made it really difficult to use it in other ways. Just thinking back on that time still raises my blood pressure! So far Jax is much better and I'm cautiously optimistic they have learned lessons from TF.
I had the opposite experience. The early TF versions were difficult to use in that they required a lot of boilerplate code to do simple things, but at least there was no hidden complexity. I knew exactly what my code did and what was going on under the hood. When I use today's high-level opaque TF libraries I have no idea what's going on. It's much harder to debug subtle problems. The workflow went from "Damn, I need to write 200 lines of code to do this simple thing" to "I need to spend 1 hour looking through library documentation, gotchas, deprecation issues and TF-internal code to figure out which function to call with what parameters and check if it actually does exactly what I need" - I much prefer the former.

Having barriers to entry is not always a bad thing - it forces people to learn and understand concepts instead of blindly copying and pasting code from a Medium article and praying that it works.

But I agree with you that there are many different use cases. Those people who want to do high-level work (I have some images, just give me a classifier) shouldn't need to deal with that complexity. IMO the big mistake was trying to merge all these different use cases into one framework. Let's hope JAX doesn't go down the same route.

(googler)

Not quite sure why you picked those particular examples... JAX also requires usage of lax.cond, lax.while_loop, and ops.segment_sum. Only gather has been improved with slice notation support. IMO, TF has landed on a pretty nice solution to cond/while_loop via AutoGraph.
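For readers who have not used these operations, here is a minimal sketch of gather, scatter-add and segment_sum in JAX. It assumes a current JAX release (the `.at[...]` update syntax is newer than this thread), and the arrays are made up purely for illustration.

```python
import jax.numpy as jnp
from jax import ops

x = jnp.arange(10.0)              # [0., 1., ..., 9.]
idx = jnp.array([1, 3, 3])

# Gather: plain NumPy-style indexing with an integer array.
gathered = x[idx]

# Scatter-add: a functional update that returns a new array.
# Duplicate indices accumulate (index 3 is incremented twice).
scattered = jnp.zeros(5).at[idx].add(1.0)

# Segment sum: add up the entries of `data` that share a segment id.
data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 2])
sums = ops.segment_sum(data, segment_ids, num_segments=3)

print(gathered)   # [1. 3. 3.]
print(scattered)  # [0. 1. 0. 2. 0.]
print(sums)       # [3. 3. 4.]
```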

While JAX has those operations, you don't always need them; it depends on which transformations you want to apply (JIT or grad), and they have been working on making normal control structures compatible with all transformations.
You can't blame the TF people for things like while_loop. Those are inherited from Theano, and back then the dynamic graph idea wasn't obvious.

JAX is indeed a different situation as it has a more original design (although TF1 came with a huge improvement in compilation speed, so maybe there were innovations under the hood). But I don't know if I like it. The framework itself is quite neat, but last time I checked, the accompanying NN libraries had horrifying designs.

> tf.conds, tf.while_loops and the whole gather / scatter paradigm

I'm ill-informed - but isn't that exactly what lax is?

The difference is that in TF1 you had to use tf.cond, tf.while_loop etc for differentiable control flow. In JAX you can differentiate Python control flow directly, e.g.:

  In [1]: from jax import grad
  
  In [2]: def f(x):
     ...:     if x > 0:
     ...:         return 3. * x ** 2
     ...:     else:
     ...:         return 5. * x ** 3
     ...:
  
  In [3]: grad(f)(1.)
  Out[3]: DeviceArray(6., dtype=float32)
  
  In [4]: grad(f)(-1.)
  Out[4]: DeviceArray(15., dtype=float32)
In the above example, the control flow happens in Python, just as it would in PyTorch. (That's not surprising, since JAX grew out of the original Autograd [1]!)

Structured control flow functions like lax.cond, lax.scan, etc exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others, but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
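A rough sketch of that distinction, reusing the function from the session above. The names f_python and f_staged are made up for illustration; the sketch assumes a recent JAX release in which lax.cond takes the operand after the two branch functions.

```python
import jax
from jax import lax

def f_python(x):
    # Python-level branch: works with grad, but jit cannot trace it,
    # because `x > 0` has no concrete value while tracing.
    if x > 0:
        return 3.0 * x ** 2
    return 5.0 * x ** 3

def f_staged(x):
    # Same function with the branch expressed as lax.cond, so the
    # whole computation can be staged out to XLA by jit.
    return lax.cond(x > 0,
                    lambda v: 3.0 * v ** 2,
                    lambda v: 5.0 * v ** 3,
                    x)

grad_eager = jax.grad(f_python)
grad_jitted = jax.jit(jax.grad(f_staged))

print(grad_eager(1.0), grad_jitted(1.0))    # both 6.0
print(grad_eager(-1.0), grad_jitted(-1.0))  # both 15.0
```

Calling jax.jit(f_python) directly would raise a tracer error on the `if`, which is the kind of constraint you opt into when you ask for jit.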

Disclaimer: I work on JAX!

[1] https://github.com/hips/autograd [2] https://www.tensorflow.org/xla/operation_semantics

What would you say the main advantage of Jax is over Pytorch?
> I'm not convinced the same won't happen to JAX

And now there are already multiple NN libraries for JAX from Google...

There are a bunch of frameworks built on top of PyTorch too (fastai, Lightning, torchbearer, ignite...). I don't see why this should be a problem (or at least a problem for JAX but not for PyTorch).
IMO, this is not a fair comparison because Pytorch spans a larger amount of abstraction than jax (I don't quite know how to explain it other than "spans a larger amount of abstraction").

You can do much of the jax stuff in pytorch, you can't do the high level nn.LSTM stuff in jax, you have to use like flax or objax or something.

All I want is a way to statically type check tensor axes. Why can't I get a way to statically type check tensors?
There are a few efforts working in this space. As you can imagine, all of them are experimental:

- Dex: https://github.com/google-research/dex-lang/
- Hasktorch: https://github.com/hasktorch/hasktorch
- This initiative from the Python Typing-sig: https://docs.google.com/document/d/1oaG0V2ZE5BRDjd9N-Tr1N0IK...

This is not statically checked but it's a step in the right direction: https://pytorch.org/docs/stable/named_tensor.html
Yeah, I actually helped work on the inspo for that project https://github.com/harvardnlp/namedtensor .

From what I've been able to tell, (no shade to the Pytorch team which has many different priorities) work has been somewhat slow going on the port.

Further, this is dynamic type checking as you mentioned.

I see, interesting! Yeah statically checking this would be way more awesome still
Oh I just noticed that you're one of the people behind that recent GAN compression work! Really cool stuff and a big step up this year, I've been following the field for a lil bit.

Congrats!

Thanks a lot for the kind words!
The subtext is Google would love even more Google projects to be ml prerequisites.
I have just started hearing about Jax. But it seems to be a low-level library that Tensorflow uses, right?

The latest release of Tensorflow Probability uses JAX under the hood. So what do you mean when you say you're moving to JAX versus Tensorflow?

In your first sentence you're confusing JAX with XLA.

XLA: Accelerated Linear Algebra, I guess it's kind of a backend/compiler that optimizes Linear Algebra/Deep Learning calculations with some very interesting techniques, among them fusing kernels

JAX: In some sense syntactic sugar over XLA, but a better way of describing it is composable transformations + NumPy + some SciPy. The composable transformations allow you to take derivatives (be they of single-, multi- or vector-valued functions, and also higher-order derivatives), JIT a function (which is then compiled to XLA), use 2 forms of parallelism (vmap and pmap) and others, while being compatible with one another and with TPUs, GPUs and CPUs.
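To make "composable" a bit more concrete, here is a small sketch that stacks grad, vmap and jit on one toy function. The function f and the inputs are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy scalar function: f(x) = x^3
    return x ** 3

df = jax.grad(f)               # first derivative: 3 x^2
d2f = jax.grad(jax.grad(f))    # second derivative: 6 x

# vmap maps the scalar second derivative over a batch,
# and jit compiles the batched function through XLA.
batched_d2f = jax.jit(jax.vmap(d2f))

xs = jnp.array([1.0, 2.0, 3.0])
print(df(2.0))          # 12.0
print(batched_d2f(xs))  # [ 6. 12. 18.]
```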

I'm not mistaking the articles around it - check this out: https://www.tensorflow.org/probability/examples/TensorFlow_P...

"TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now works on JAX! For those not familiar, JAX is a library for accelerated numerical computing based on composable function transformations.

We have ported a lot of TFP's most useful functionality to JAX while preserving the abstractions and APIs that many TFP users are now comfortable with."

Tensorflow is migrating a bunch of stuff to JAX. Even they use the "library" word for their own porting. For a user like me, it looks like Jax is a library that tensorflow uses...but the end-user usable library is tensorflow.

Hi, tech lead for TFP here. The wording here was unclear -- sorry! We're fixing it presently.

We are not migrating away from TF; far from it!

The change here was to interoperate with TF and JAX (and numpy!), by way of some rewrite trickery under the hood. Essentially, we wrote a translation layer that implements the TF API surface (or, the parts we actually use) in terms of numpy & JAX primitives [1]. This lets us leave most TFP code intact, written in terms of the TF API, but interoperate with JAX by way of the API translation layer. (Actually we implemented numpy support first, and mostly got JAX for "free" since JAX is largely API-compatible with numpy).
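As a very rough illustration of the general idea (not TFP's actual code), a translation layer can expose TF-style function names while delegating to whichever array backend it is given. The class TFLikeShim and the function variance below are hypothetical names invented for this sketch; only reduce_mean and square mirror real TF function names.

```python
import numpy as np
import jax.numpy as jnp

class TFLikeShim:
    """Hypothetical shim: a tiny slice of a TF-style API over a NumPy-like backend."""

    def __init__(self, backend):
        self._b = backend  # e.g. numpy or jax.numpy

    def reduce_mean(self, x, axis=None):
        return self._b.mean(x, axis=axis)

    def square(self, x):
        return self._b.square(x)

def variance(tf, x):
    # "Library" code written once against TF-style names.
    return tf.reduce_mean(tf.square(x - tf.reduce_mean(x)))

data = [1.0, 2.0, 3.0, 4.0]
var_np = variance(TFLikeShim(np), np.array(data))     # NumPy backend
var_jax = variance(TFLikeShim(jnp), jnp.array(data))  # JAX backend
print(var_np, var_jax)  # both 1.25
```

Because jax.numpy follows the NumPy API closely, the same shim works for both backends here, which is in the spirit of the "got JAX for free" remark.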

Sorry for any confusion!

We're pretty stoked about this work, so happy to answer any other questions you may have (also feel free to chime in on the github tracker or email tfprobability@tensorflow.org)

[1] - https://github.com/tensorflow/probability/tree/master/tensor...

hey thanks for the clarification.

here's what everybody is puzzled about: it looks like the layers going forward are JAX -> Tensorflow -> Keras.

and we are seeing people moving to JAX directly. So this is ending up like a Flutter vs Kotlin issue (also within Google).

Do you envision JAX being low level, and the high-level Tensorflow Keras interface being the most usable API?

I don't think that the main goal of TF on Swift is to train models using Swift. I think it's mainly to deploy them in production on iPhones