ray.train.report(metrics: Dict, *, checkpoint: Checkpoint | None = None) -> None

Report metrics and optionally save a checkpoint.

If a checkpoint is provided, it will be persisted to storage.
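
When no checkpoint is passed, only the metrics are recorded. A minimal metrics-only sketch (the loop structure and metric names here are illustrative):

from ray import train

def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        loss = ...  # compute the training loss for this epoch

        # Metrics-only report: no checkpoint is persisted.
        train.report({"loss": loss, "epoch": epoch})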

If this is called in multiple distributed training workers:

  • Each invocation of this method automatically increments the underlying training_iteration number. The physical meaning of this “iteration” is defined by the user, depending on how often they call report; it does not necessarily map to one epoch.

  • All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. Otherwise, your training will hang.

  • This method does NOT act as a barrier for distributed training workers. Workers upload their checkpoint, then continue training immediately. If you need to synchronize workers, you can use a framework-native barrier such as torch.distributed.barrier(), as sketched below.
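
A minimal sketch of an explicit synchronization point after reporting, assuming a PyTorch training function launched by TorchTrainer (which initializes the torch.distributed process group on each worker):

import torch.distributed as dist

from ray import train

def train_func(config):
    # ... one training step ...
    metrics = {"loss": ...}

    # Every worker reports; report() itself does not block on peers.
    train.report(metrics)

    # Explicit barrier: all workers wait here before continuing.
    dist.barrier()

The following end-to-end example reports metrics together with a checkpoint: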


import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer

def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state from `checkpoint_dir`
            # (e.g., model weights and the last completed epoch).
            ...
    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                # Other workers still call report() so that all workers
                # report the same number of times.
                train.report(metrics, checkpoint=None)

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
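
Running the trainer returns a Result object; the most recently reported metrics and checkpoint are accessible on it:

result = trainer.fit()
print(result.metrics)     # the last reported metrics dict
print(result.checkpoint)  # the last reported Checkpoint, or None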

Parameters:

  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.