by time_to_smile
1231 days ago
Backprop is just an implementation detail of automatic differentiation (it's reverse-mode autodiff), basically the bookkeeping for how you would apply the chain rule to your problem. JAX is able to differentiate arbitrary Python code (so long as it uses JAX for the numeric stuff) automatically, so the backprop is abstracted away. If you have the forward model written, all you have to do to train it is wrap it in whatever loss function you want, then use JAX's `grad` to get the gradient of that loss with respect to the model parameters, and feed that to your favorite gradient-based optimization algorithm to find the optimum. This is why JAX is so awesome. Differentiable programming means you only have to think about problems in terms of the forward pass, and then you can trivially get the derivative of that function without having to worry about the implementation details.
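To make that concrete, here's a rough sketch of the workflow. The `forward` model (a toy linear fit), the `loss` function, the synthetic data, and the learning rate are all made up for illustration; the only JAX API calls used are `jax.grad`, `jax.jit`, and `jax.numpy`. Plain gradient descent stands in for "your favorite optimizer".

```python
import jax
import jax.numpy as jnp


def forward(params, x):
    # Hypothetical forward model: a simple linear map w * x + b.
    w, b = params
    return w * x + b


def loss(params, x, y):
    # Wrap the forward model in a loss (mean squared error here).
    return jnp.mean((forward(params, x) - y) ** 2)


# jax.grad differentiates with respect to the first argument (params) by default;
# jax.jit just compiles the resulting function so the loop runs faster.
grad_loss = jax.jit(jax.grad(loss))

# Toy data generated from y = 2x + 1 (an assumption for this example).
x = jnp.linspace(-1.0, 1.0, 20)
y = 2.0 * x + 1.0

# Parameters are a tuple (w, b); the gradient comes back with the same structure.
params = (jnp.array(0.0), jnp.array(0.0))
learning_rate = 0.1

for _ in range(500):
    grads = grad_loss(params, x, y)
    # Plain gradient descent step on each parameter.
    params = tuple(p - learning_rate * g for p, g in zip(params, grads))

w, b = params
```

Note there's no backward pass written anywhere: only `forward` and `loss` are defined by hand, and `w` and `b` should end up close to the 2 and 1 used to generate the data.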