ray.train.report#

ray.train.report(metrics: Dict, *, checkpoint: Checkpoint | None = None) → None#

Report metrics and optionally save a checkpoint.

If a checkpoint is provided, it will be persisted to storage.

This method can be called from multiple distributed training workers. Every worker must call it the same number of times, and it is sufficient for a single worker (typically rank 0) to attach a checkpoint; see the notes, warnings, and example below.

Note

Each invocation of this method automatically increments the underlying training_iteration number. What one “iteration” means physically is up to the user and depends on how often report is called; it does not necessarily map to one epoch.
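
For example, reporting once per batch instead of once per epoch makes training_iteration advance once per batch. A minimal sketch, where the batch loop and metric names are illustrative rather than part of this API:

from ray import train


def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        for batch in range(config.get("batches_per_epoch", 100)):  # illustrative loop
            loss = ...  # compute the loss for this batch

            # Each call increments training_iteration by one, so here one
            # "iteration" corresponds to one batch rather than one epoch.
            train.report({"loss": loss, "epoch": epoch})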

Warning

All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. Otherwise, your training will hang.
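
For example, skipping the call on non-zero ranks (rather than passing checkpoint=None as in the example below) causes such a hang. A minimal sketch of the mistake and the fix, assuming metrics computed inside the training function:

from ray import train


def train_func(config):
    metrics = {"loss": ...}

    # Hangs: only the rank 0 worker reports; the other workers never call report.
    # if train.get_context().get_world_rank() == 0:
    #     train.report(metrics)

    # Works: every worker calls report the same number of times.
    train.report(metrics)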

Warning

This method does NOT act as a barrier for distributed training workers. Workers will upload their checkpoint, then continue training immediately. If you need to synchronize workers, you can use a framework-native barrier such as torch.distributed.barrier().
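
If the workers do need to wait for each other after reporting (for example, before deleting files that another rank may still read), an explicit barrier can follow the call. A minimal sketch for a Torch training function, assuming the process group set up by TorchTrainer:

import torch.distributed as dist

from ray import train


def train_func(config):
    metrics = {"loss": ...}

    # report() hands off the checkpoint upload and returns immediately.
    train.report(metrics)

    # Explicitly synchronize all workers before continuing, if required.
    dist.barrier()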

Example

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                train.report(metrics, checkpoint=None)

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
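
A usage sketch for running the example above and reading back what was reported, using the Result object returned by fit():

result = trainer.fit()

# Metrics from the final report() call and the latest persisted checkpoint.
print(result.metrics)
latest_checkpoint = result.checkpoint
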
Parameters:
  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.