Hacker News new | ask | show | jobs
by missingET 995 days ago
Jax has a much nicer handling of higher order differentiation. PyTorch has functions to compute Hessians and there are libraries to keep differentiability through optimizers, but going out of their standard use-cases becomes tricky very fast. In contrast, JAX can compute nth-derivatives of things very easily.