by missingET
995 days ago
JAX handles higher-order differentiation much more nicely. PyTorch has functions to compute Hessians, and there are libraries that keep differentiability through optimizers, but things get tricky very fast once you step outside their standard use cases. In contrast, JAX can compute nth derivatives of things very easily.
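A minimal sketch of what that looks like in practice, assuming a recent `jax` install. Because `jax.grad` returns an ordinary Python function, you can just nest it as many times as you like. The helper names (`nth_derivative`, `inner_loss`, `unrolled_sgd`, `meta_loss`) and the toy objective are hypothetical, made up for illustration. The last part differentiates straight through a hand-unrolled gradient-descent loop with respect to the learning rate, which is roughly the "through the optimizer" case that needs extra libraries elsewhere.

```python
import jax
import jax.numpy as jnp


def nth_derivative(f, n):
    """Hypothetical helper: n-th derivative of a scalar function by nesting jax.grad."""
    for _ in range(n):
        f = jax.grad(f)  # each call wraps the previous derivative function
    return f


# Third derivative of sin is -cos, so at x = 0 this evaluates to -1
d3_sin = nth_derivative(jnp.sin, 3)


# Full Hessian of a vector-input function; here it is diag(6 * x)
def cubic_sum(x):
    return jnp.sum(x ** 3)


hess = jax.hessian(cubic_sum)(jnp.array([1.0, 2.0]))


# Hypothetical toy objective with its minimum at w = 3
def inner_loss(w):
    return (w - 3.0) ** 2


def unrolled_sgd(lr, w0=0.0, steps=5):
    """Plain gradient descent written as a Python loop; JAX traces every step."""
    w = w0
    for _ in range(steps):
        w = w - lr * jax.grad(inner_loss)(w)  # gradient inside the loop
    return w


def meta_loss(lr):
    # Loss after optimization, as a function of the learning rate
    return inner_loss(unrolled_sgd(lr))


# Gradient of the post-training loss w.r.t. the learning rate (second order overall)
dmeta_dlr = jax.grad(meta_loss)(0.1)

print(d3_sin(0.0))
print(hess)
print(dmeta_dlr)
```

For this toy problem the unrolled loop has a closed form, `meta_loss(lr) = 9 * (1 - 2 * lr) ** 10`, so the gradient at `lr = 0.1` should be `-180 * 0.8 ** 9`, roughly -24.16. That makes it a handy sanity check that the nested differentiation is doing what you expect.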