ray.train.report

ray.train.report(metrics: Dict[str, Any], checkpoint: Checkpoint | None = None, checkpoint_dir_name: str | None = None, checkpoint_upload_mode: CheckpointUploadMode = CheckpointUploadMode.SYNC, delete_local_checkpoint_after_upload: bool | None = None, checkpoint_upload_fn: Callable[[Checkpoint, str], Checkpoint] | None = None, validate_fn: Callable[[Checkpoint, Dict | None], Dict] | None = None, validate_config: Dict | None = None)

Report metrics and optionally save a checkpoint.

If a checkpoint is provided, it will be persisted to storage.

If this is called in multiple distributed training workers:

  • Only the metrics reported by the rank 0 worker will be attached to the checkpoint.

  • A checkpoint will be registered as long as at least one worker reports a checkpoint that is not None. See the checkpointing guide.

  • Checkpoints from multiple workers will be merged into one directory in persistent storage. See the distributed checkpointing guide.

Warning

All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. This method acts as a barrier across all workers, so be sure that every worker reaches this method.

Example

import tempfile

import ray.train
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if ray.train.get_context().get_world_rank() == 0:
                ray.train.report(metrics, checkpoint=checkpoint)
            else:
                ray.train.report(metrics, checkpoint=None)

trainer = TorchTrainer(
    train_func, scaling_config=ray.train.ScalingConfig(num_workers=2)
)
trainer.fit()
Parameters:
  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.

  • checkpoint_dir_name – Custom name for the checkpoint directory. If not provided, a unique directory name is generated automatically. If provided, it must be unique across all checkpoints reported by a worker to avoid naming collisions. Consider including identifiers such as the epoch or batch index in the name, as shown in the sketch after this parameter list.

  • checkpoint_upload_mode – How to upload the checkpoint. Defaults to uploading the checkpoint synchronously. This argument is accepted even when no checkpoint is provided, but it has no effect in that case.

  • delete_local_checkpoint_after_upload – Whether to delete the local checkpoint after it has been uploaded to persistent storage.

  • checkpoint_upload_fn – A user-defined function that is called with the checkpoint to upload it. If not provided, Ray Train defaults to copying the checkpoint to the destination storage_path with the pyarrow.fs.copy_files utility.

  • validate_fn – If provided, Ray Train validates the checkpoint by calling this function (see the sketch after this parameter list).

  • validate_config – The configuration passed to validate_fn. It can contain information such as the validation dataset.
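
The following is a minimal sketch of how the optional checkpointing arguments above could be passed from inside a training function. The epoch-based directory name and the validate_checkpoint helper are hypothetical user code, not part of the Ray Train API, and the rank-based checkpoint handling from the example above is omitted for brevity.

import tempfile

import ray.train


def validate_checkpoint(checkpoint, validate_config):
    # Hypothetical user-defined validation: load the checkpoint and evaluate
    # it, for example on a dataset passed through `validate_config`.
    # Returns a dictionary of validation metrics.
    return {"val_loss": ...}


def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        # Do training...
        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

            ray.train.report(
                metrics,
                checkpoint=checkpoint,
                # Include the epoch so each checkpoint directory name is unique.
                checkpoint_dir_name=f"epoch-{epoch}",
                # Validate the reported checkpoint with the user-defined
                # function above.
                validate_fn=validate_checkpoint,
                validate_config={"validation_info": ...},
            )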