ray.train.tensorflow.keras.ReportCheckpointCallback

class ray.train.tensorflow.keras.ReportCheckpointCallback(*args: Any, **kwargs: Any)

Bases: _Callback

Keras callback that reports metrics and checkpoints to Ray Train.

Note

Whenever a checkpoint is saved (on the events listed in checkpoint_on), the current metrics are reported along with it, even if that event isn’t listed in report_metrics_on.
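The rule in the note can be modelled as a small decision function. This is a minimal sketch under stated assumptions: plan_report is a hypothetical helper written for this page, not part of Ray's API, and event names are treated as plain strings such as "epoch_end".

```python
from typing import Set


def plan_report(event: str, report_metrics_on: Set[str], checkpoint_on: Set[str]) -> str:
    """Describe what happens on one Keras event (hypothetical helper, not Ray's API)."""
    if event in checkpoint_on:
        # A checkpoint is never sent alone: metrics are reported with it,
        # whether or not the event also appears in report_metrics_on.
        return "metrics + checkpoint"
    if event in report_metrics_on:
        return "metrics only"
    # Events in neither set are ignored.
    return "nothing"


# Checkpoint at the end of each epoch, report metrics alone only at the end of training.
print(plan_report("epoch_end", {"train_end"}, {"epoch_end"}))        # -> metrics + checkpoint
print(plan_report("train_end", {"train_end"}, {"epoch_end"}))        # -> metrics only
print(plan_report("train_batch_end", {"train_end"}, {"epoch_end"}))  # -> nothing
```

The first call shows the case the note describes: "epoch_end" is not in report_metrics_on, yet metrics still go out because a checkpoint is saved on that event.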

Example

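# Using the callback inside a Ray Train training function
import tensorflow as tf
from ray import train
from ray.train.tensorflow.keras import ReportCheckpointCallback

def train_loop_per_worker():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()  # user-defined helper returning a compiled Keras model

    # "x" and "y" are placeholder feature/label column names.
    dataset_shard = train.get_dataset_shard("train").to_tf("x", "y", batch_size=32)
    model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])
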
Parameters:
  • metrics – Metrics to report. If this is a list, each item is a key in the Keras logs, and the metric is reported under that same name. If this is a dict, each key is the name under which the metric is reported, and the corresponding value is the Keras log key to read it from. If this is None, all Keras logs are reported.

  • report_metrics_on – When to report metrics. Must name one of the Keras callback events, written without the on_ prefix and with _begin spelled _start, e.g. “train_start” (for on_train_begin) or “predict_end” (for on_predict_end). Defaults to “epoch_end”.

  • checkpoint_on – When to save checkpoints. Must name one of the Keras callback events, written without the on_ prefix and with _begin spelled _start, e.g. “train_start” (for on_train_begin) or “predict_end” (for on_predict_end). Defaults to “epoch_end”.
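The three ways of specifying metrics can be illustrated with a plain-Python sketch of the selection step. select_metrics is a hypothetical helper written for this page, not Ray's implementation, and the sample logs dict stands in for the values Keras passes to a callback hook.

```python
from typing import Dict, List, Optional, Union

MetricsSpec = Optional[Union[List[str], Dict[str, str]]]


def select_metrics(logs: Dict[str, float], metrics: MetricsSpec) -> Dict[str, float]:
    """Choose which Keras log entries to report (hypothetical helper, not Ray's API)."""
    if metrics is None:
        # None: report every Keras log entry under its original name.
        return dict(logs)
    if isinstance(metrics, list):
        # List: each item is a Keras log key, reported under the same name.
        return {key: logs[key] for key in metrics}
    if isinstance(metrics, dict):
        # Dict: reported name -> Keras log key to read the value from.
        return {name: logs[key] for name, key in metrics.items()}
    raise TypeError(f"unsupported metrics spec: {type(metrics).__name__}")


# Example logs as Keras might pass them at the end of an epoch.
logs = {"loss": 0.25, "accuracy": 0.9, "val_loss": 0.3}
print(select_metrics(logs, None))                    # all three entries
print(select_metrics(logs, ["loss"]))                # {'loss': 0.25}
print(select_metrics(logs, {"train_loss": "loss"}))  # {'train_loss': 0.25}
```

In this sketch, a key that is missing from logs raises KeyError, which surfaces misspelled metric names immediately rather than silently reporting nothing.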

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods