by kerasteam2
1070 days ago
At this time, there are no backend-agnostic APIs for implementing training steps or training loops. Each backend handles training very differently, so no shared abstraction can exist (especially for JAX). So when customizing fit(), you have to use backend-native APIs. If you want a model with a custom train_step that works across backends, you can do something like:

```python
import keras

def train_step(self, *args, **kwargs):
    # Pick the implementation that matches the active Keras backend.
    if keras.config.backend() == "tensorflow":
        return self._tf_train_step(*args, **kwargs)
    elif ...
```
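A slightly fuller sketch of the same dispatch idea, written as plain Python so it runs without Keras installed. Assumptions: `active_backend()` only approximates `keras.config.backend()` by reading the `KERAS_BACKEND` environment variable (Keras 3 defaults to "tensorflow" and also consults `~/.keras/keras.json`, which this ignores); `CrossBackendTrainStepMixin`, `_jax_train_step`, `_torch_train_step` and `Demo` are hypothetical names, not Keras APIs. In a real model you would call `keras.config.backend()` directly and write each `_..._train_step` with that backend's native APIs.

```python
import os


def active_backend() -> str:
    # Assumption: stand-in for keras.config.backend(); reads KERAS_BACKEND
    # and falls back to "tensorflow" (the Keras 3 default).
    return os.environ.get("KERAS_BACKEND", "tensorflow")


class CrossBackendTrainStepMixin:
    """Hypothetical mixin that routes train_step() to a backend-specific method."""

    # Hypothetical method names; only _tf_train_step appears in the post above.
    _METHOD_FOR_BACKEND = {
        "tensorflow": "_tf_train_step",
        "jax": "_jax_train_step",
        "torch": "_torch_train_step",
    }

    def train_step(self, *args, **kwargs):
        backend = active_backend()
        method_name = self._METHOD_FOR_BACKEND.get(backend)
        method = getattr(self, method_name, None) if method_name else None
        if method is None:
            # Fail loudly instead of silently falling back to another backend.
            raise NotImplementedError(
                f"No custom train_step implemented for backend {backend!r}"
            )
        # Forward positional and keyword arguments unchanged.
        return method(*args, **kwargs)


class Demo(CrossBackendTrainStepMixin):
    # Hypothetical example: only the TensorFlow path is implemented here.
    def _tf_train_step(self, data):
        return {"backend": "tensorflow", "data": data}
```

Raising NotImplementedError for a missing backend keeps a model that was only written for one backend from appearing to work on another.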
BTW, it looks like the previous account is being rate-limited to less than one post per hour (maybe even locked for the day), so I will be very slow to answer questions.