The difference is that in TF1 you had to use tf.cond, tf.while_loop, etc. for differentiable control flow. In JAX you can differentiate Python control flow directly, e.g.:
In [1]: from jax import grad

In [2]: def f(x):
   ...:     if x > 0:
   ...:         return 3. * x ** 2
   ...:     else:
   ...:         return 5. * x ** 3
   ...:

In [3]: grad(f)(1.)
Out[3]: DeviceArray(6., dtype=float32)

In [4]: grad(f)(-1.)
Out[4]: DeviceArray(15., dtype=float32)
In the above example, the control flow happens in Python, just as it would in PyTorch. (That's not surprising, since JAX grew out of the original Autograd [1]!)
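It isn't limited to a single `if`, either. Here's a minimal sketch (`doubling_loop` is just a hypothetical example function, not part of JAX) that differentiates a data-dependent Python `while` loop with `jax.grad`:

```python
import jax

def doubling_loop(x):
    # Data-dependent Python while loop: how many iterations run depends on
    # the value of x. Under grad (without jit) x still carries a concrete
    # value, so the comparison `x < 10.` can be evaluated as a Python bool.
    while x < 10.:
        x = x * 2.
    return x

# Starting at 3.0 the loop runs twice (3 -> 6 -> 12), so for this input
# the function behaves like x * 4 and its derivative is 4.
print(doubling_loop(3.0))            # 12.0
print(jax.grad(doubling_loop)(3.0))  # 4.0
```

As with PyTorch, the derivative is taken along whatever path Python actually executed for that input: starting at 6.0 the loop runs only once, so the gradient there is 2.0.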
Structured control flow functions like lax.cond, lax.scan, etc. exist so that you can, for example, stage control flow out of Python and into an end-to-end compiled XLA computation with jax.jit. In other words, some JAX transformations place more constraints on your Python code than others (jax.jit traces with abstract values, so a Python `if` on a traced argument can't be decided at trace time), but you can just opt into the ones you want. (More generally, the lax module lets you program XLA HLO pretty directly [2].)
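A minimal sketch of what that looks like for the `f` above, assuming a reasonably recent JAX where `lax.cond(pred, true_fun, false_fun, *operands)` is available (the names `f_python` and `f_staged` are just illustrative). The Python-`if` version fails under `jax.jit`, while the `lax.cond` version can be jitted and differentiated together:

```python
import jax
from jax import lax

def f_python(x):
    # Same function as in the session above: Python-level branching.
    if x > 0:
        return 3. * x ** 2
    else:
        return 5. * x ** 3

def f_staged(x):
    # The same branch expressed with lax.cond, so it can be staged out to XLA.
    # Both branches must return values of the same shape and dtype.
    return lax.cond(x > 0,
                    lambda x: 3. * x ** 2,
                    lambda x: 5. * x ** 3,
                    x)

# grad and jit compose with the staged version.
fast_grad = jax.jit(jax.grad(f_staged))
print(fast_grad(1.))   # 6.0
print(fast_grad(-1.))  # 15.0

# The jaxpr contains a single `cond` primitive instead of a decision
# that Python made while tracing.
print(jax.make_jaxpr(f_staged)(1.))

# Under jit, x is an abstract tracer, so the Python `if` can't be evaluated.
try:
    jax.jit(f_python)(1.)
except jax.errors.ConcretizationTypeError as e:
    print("jit(f_python) failed:", type(e).__name__)
```

Plain `grad(f_python)` still works fine, so you only need to reach for `lax.cond` when you also want `jit`.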
Disclaimer: I work on JAX!
[1] https://github.com/hips/autograd
[2] https://www.tensorflow.org/xla/operation_semantics