ray.train.report

ray.train.report(metrics: Dict, *, checkpoint: Optional[ray.train.Checkpoint] = None) -> None

Report metrics and optionally save a checkpoint.

Each invocation of this method automatically increments the underlying training_iteration number. What one "iteration" means is defined by the user and depends on how often they call report; it does not necessarily map to one epoch.
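
For instance, reporting once per batch instead of once per epoch simply produces more training_iteration increments. A minimal sketch, assuming hypothetical num_epochs and num_batches entries in config:

from ray import train


def train_func(config):
    # Loop bounds are illustrative only.
    for epoch in range(config.get("num_epochs", 2)):
        for batch_idx in range(config.get("num_batches", 5)):
            loss = 1.0 / (batch_idx + 1)
            # One call per batch: training_iteration advances
            # num_epochs * num_batches times in total, not once per epoch.
            train.report({"loss": loss, "epoch": epoch})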

This method acts as a synchronous barrier for all distributed training workers. All workers must call ray.train.report the same number of times.
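
Because of this barrier, avoid wrapping report in conditions that can differ across workers. A minimal sketch of the safe pattern (the rank guard mentioned in the comment is the anti-pattern to avoid):

from ray import train


def train_func(config):
    for step in range(config.get("num_steps", 10)):
        loss = 1.0 / (step + 1)
        # Every worker reaches this call on every step. Guarding it with,
        # e.g., `if train.get_context().get_world_rank() == 0:` would leave
        # the other workers blocked at the barrier.
        train.report({"loss": loss})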

If a checkpoint is provided, it will be persisted to storage.

Example

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(metrics, checkpoint=checkpoint)

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
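
The metrics and checkpoint passed to the final report call surface on the Result returned by fit(). A short sketch of reading them back, continuing the example above:

result = trainer.fit()
print(result.metrics)     # metrics from the last `train.report(...)` call
print(result.checkpoint)  # latest reported Checkpoint, if any
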
Parameters
  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.