ray.train.report

ray.train.report(metrics: Dict[str, Any], checkpoint: Checkpoint | None = None, checkpoint_dir_name: str | None = None)

Report metrics and optionally save a checkpoint.

If a checkpoint is provided, it will be persisted to storage.

If this is called by multiple distributed training workers:

  • Only the metrics reported by the rank 0 worker will be attached to the checkpoint.

  • A checkpoint is registered as long as one or more workers report a checkpoint that is not None. See the checkpointing guide.

  • Checkpoints from multiple workers will be merged into one directory in persistent storage. See the distributed checkpointing guide.
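The three rules above can be modeled as a small pure-Python function over one round of reports. This is a toy illustration of the documented behavior, not Ray Train's implementation. `WorkerReport`, `RegisteredCheckpoint`, and `merge_round` are hypothetical names, and each worker's checkpoint directory is modeled as a dict mapping file names to bytes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class WorkerReport:
    """What one worker passed to a single report call in this toy model."""
    rank: int
    metrics: Dict[str, Any]
    # Files the worker wrote into its checkpoint directory; None = no checkpoint.
    checkpoint_files: Optional[Dict[str, bytes]] = None


@dataclass
class RegisteredCheckpoint:
    metrics: Dict[str, Any]
    files: Dict[str, bytes] = field(default_factory=dict)


def merge_round(reports: List[WorkerReport]) -> Optional[RegisteredCheckpoint]:
    """Apply the three documented rules to one round of reports (toy model only)."""
    # A checkpoint is registered if at least one worker reported one.
    with_ckpt = [r for r in reports if r.checkpoint_files is not None]
    if not with_ckpt:
        return None
    # Only the rank 0 worker's metrics are attached, even if rank 0 sent no files.
    rank0 = next(r for r in reports if r.rank == 0)
    registered = RegisteredCheckpoint(metrics=dict(rank0.metrics))
    # Files from all reporting workers land in one directory. In this toy a
    # name collision silently overwrites, so each worker uses distinct names.
    for r in sorted(with_ckpt, key=lambda r: r.rank):
        registered.files.update(r.checkpoint_files)
    return registered


reports = [
    WorkerReport(rank=0, metrics={"loss": 0.5}, checkpoint_files={"shard-0.pt": b"a"}),
    WorkerReport(rank=1, metrics={"loss": 0.7}, checkpoint_files={"shard-1.pt": b"b"}),
]
ckpt = merge_round(reports)
print(ckpt.metrics)        # {'loss': 0.5}
print(sorted(ckpt.files))  # ['shard-0.pt', 'shard-1.pt']
```

The per-rank file names matter in this toy: if two workers wrote the same file name, the merge would keep only one of them.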

Warning

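All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. This method acts as a barrier across all workers, so make sure every worker reaches each call. For example, don't skip the call on workers other than rank 0; have them report with checkpoint=None instead, as in the example below.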
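The barrier requirement can be illustrated without Ray. The sketch below is a toy model that assumes each report call behaves like a `threading.Barrier` wait; `run_workers` and `reports_per_worker` are hypothetical names, and this is not how Ray Train implements synchronization. When one worker reports fewer times than the others, the remaining workers are left waiting at the barrier. The toy adds a timeout so that the mismatch surfaces as an error instead of a hang.

```python
import threading
from typing import List


def run_workers(num_workers: int, reports_per_worker: List[int], timeout: float = 1.0) -> bool:
    """Return True if every worker got through all of its report calls.

    Toy model: each "report" is a barrier wait. Not Ray Train's implementation.
    """
    barrier = threading.Barrier(num_workers, timeout=timeout)
    ok = [True] * num_workers

    def worker(rank: int) -> None:
        for _ in range(reports_per_worker[rank]):
            try:
                barrier.wait()  # stands in for one ray.train.report(...) call
            except threading.BrokenBarrierError:
                # The barrier timed out or broke: some worker never arrived.
                ok[rank] = False
                return

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(ok)


print(run_workers(2, [3, 3]))  # True: both workers report three times
print(run_workers(2, [3, 2]))  # False: rank 0's third call never completes
```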

Example

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state, e.g. restore start_epoch from the saved epoch.
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                train.report(metrics, checkpoint=None)

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
result = trainer.fit()
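The example leaves the save and load steps as placeholders. One way to fill them, sketched here with the standard library only, is to store the last completed epoch in the checkpoint directory and resume one epoch later. The file name `train_state.json` and the helpers `save_state` and `resume_epoch` are hypothetical; a real PyTorch job would more likely write model and optimizer state with `torch.save`. Inside `train_func`, `save_state(temp_checkpoint_dir, epoch)` would go where `# torch.save(...)` sits, and `start_epoch = resume_epoch(checkpoint_dir)` inside the `checkpoint.as_directory()` block.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional

# Hypothetical file name used only by this sketch.
STATE_FILE = "train_state.json"


def save_state(checkpoint_dir: str, epoch: int, extra: Optional[Dict[str, Any]] = None) -> None:
    """Record the last completed epoch (plus any JSON-serializable extras)."""
    # "epoch" is written last so an extra key of the same name cannot override it.
    state = {**(extra or {}), "epoch": epoch}
    with open(os.path.join(checkpoint_dir, STATE_FILE), "w") as f:
        json.dump(state, f)


def resume_epoch(checkpoint_dir: str) -> int:
    """Return the epoch to resume from: one past the last completed epoch."""
    with open(os.path.join(checkpoint_dir, STATE_FILE)) as f:
        state = json.load(f)
    return state["epoch"] + 1


# Round trip: save after epoch 4, then resume at epoch 5.
with tempfile.TemporaryDirectory() as d:
    save_state(d, epoch=4, extra={"loss": 0.25})
    print(resume_epoch(d))  # 5
```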
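
Parameters:
  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.

  • checkpoint_dir_name – Optional custom name for the checkpoint directory. If not provided, a unique directory name is created.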