ray.train.tensorflow.keras.ReportCheckpointCallback
- class ray.train.tensorflow.keras.ReportCheckpointCallback(*args: Any, **kwargs: Any)
Bases: _Callback
Keras callback for Ray Train reporting and checkpointing.
Note
Metrics are always reported with checkpoints, even if the event isn’t specified in report_metrics_on.
Example
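# Using it in a Ray Train training function
import tensorflow as tf
from ray.train.tensorflow.keras import ReportCheckpointCallback

def train_loop_per_worker():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # build_model() is your own function that returns a compiled Keras model.
        model = build_model()

    # dataset_shard is this worker's training data, e.g. a tf.data.Dataset.
    model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])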
- Parameters:
metrics – Metrics to report. If this is a list, each item is a metric key from the Keras logs, and it’s reported under the same name. If this is a dict, each key is the reported name and the respective value is the Keras metric key. If this is None, all Keras logs are reported.
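The list, dict, and None forms of metrics differ only in how Keras log keys map to reported names. The following is a minimal sketch of that mapping in plain Python; select_metrics is a hypothetical helper written for illustration, not part of Ray's API, and it assumes the Keras logs arrive as a dict of metric name to value.

```python
from typing import Dict, List, Optional, Union

# Hypothetical helper, not Ray's implementation: it mimics how the
# documented `metrics` argument selects values from Keras logs.
MetricsSpec = Optional[Union[List[str], Dict[str, str]]]


def select_metrics(logs: Dict[str, float], metrics: MetricsSpec) -> Dict[str, float]:
    if metrics is None:
        # None: report every entry of the Keras logs (as a copy).
        return dict(logs)
    if isinstance(metrics, list):
        # List: each Keras key is reported under the same name.
        return {key: logs[key] for key in metrics}
    if isinstance(metrics, dict):
        # Dict: maps reported name -> Keras metric key.
        return {name: logs[key] for name, key in metrics.items()}
    raise TypeError(f"unsupported metrics spec: {metrics!r}")


keras_logs = {"loss": 0.25, "accuracy": 0.9, "val_loss": 0.31}
print(select_metrics(keras_logs, ["loss"]))  # {'loss': 0.25}
print(select_metrics(keras_logs, {"train_loss": "loss", "validation_loss": "val_loss"}))
# {'train_loss': 0.25, 'validation_loss': 0.31}
print(len(select_metrics(keras_logs, None)))  # 3
```

With the real callback, the dict form would be passed as ReportCheckpointCallback(metrics={"train_loss": "loss"}) to report the Keras loss value under the name train_loss.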
report_metrics_on – When to report metrics. Must be one of the Keras event hooks without the on_ prefix, e.g. “train_begin” or “predict_end”. Defaults to “epoch_end”.
checkpoint_on – When to save checkpoints. Must be one of the Keras event hooks without the on_ prefix, e.g. “train_begin” or “predict_end”. Defaults to “epoch_end”.
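To see how report_metrics_on, checkpoint_on, and the Note interact, here is a small pure-Python model of the decision made at each Keras event. plan_event and KERAS_EVENTS are hypothetical names, and the event set assumes the standard tf.keras.callbacks.Callback hook names (on_train_begin, on_epoch_end, and so on) with the on_ prefix dropped. The sketch handles a single event name per parameter and does not reflect Ray's internal code.

```python
from typing import Tuple

# Assumed: standard tf.keras.callbacks.Callback hook names without "on_".
KERAS_EVENTS = {
    "train_begin", "train_end",
    "epoch_begin", "epoch_end",
    "train_batch_begin", "train_batch_end",
    "test_begin", "test_end",
    "test_batch_begin", "test_batch_end",
    "predict_begin", "predict_end",
    "predict_batch_begin", "predict_batch_end",
}


def plan_event(
    event: str,
    report_metrics_on: str = "epoch_end",
    checkpoint_on: str = "epoch_end",
) -> Tuple[bool, bool]:
    """Return (report_metrics, save_checkpoint) for one Keras event.

    Hypothetical model of the documented behavior, not Ray's code.
    """
    for name in (report_metrics_on, checkpoint_on):
        if name not in KERAS_EVENTS:
            raise ValueError(f"{name!r} is not a Keras event hook (without 'on_')")
    save_checkpoint = event == checkpoint_on
    # Per the Note: metrics are always reported together with a checkpoint.
    report_metrics = save_checkpoint or event == report_metrics_on
    return report_metrics, save_checkpoint


print(plan_event("epoch_end"))  # (True, True)
print(plan_event("train_end", checkpoint_on="train_end"))  # (True, True)
print(plan_event("epoch_end", checkpoint_on="train_end"))  # (True, False)
```

The second call is the case the Note describes: with checkpoint_on="train_end", the final checkpoint still carries metrics even though train_end is not the report_metrics_on event.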
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods