ray.train.report
- ray.train.report(metrics: Dict, *, checkpoint: Checkpoint | None = None) → None
Report metrics and optionally save a checkpoint.
If a checkpoint is provided, it will be persisted to storage.
If this method is called from multiple distributed training workers:
- Only the metrics reported by the rank 0 worker will be tracked by Ray Train. See the metrics logging guide.
- A checkpoint will be registered as long as one or more workers report a checkpoint that is not None. See the checkpointing guide.
- Checkpoints from multiple workers will be merged into one directory in persistent storage, as sketched below. See the distributed checkpointing guide.
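For a concrete picture of the merge behavior, here is a minimal sketch in which every worker attaches its own checkpoint shard. The per-rank file name model-rank={rank}.pt is a naming convention invented for this example, not something Ray Train requires:

import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint


def train_func(config):
    # Placeholder for the part of the model state that this worker owns.
    model_state = {}

    for epoch in range(config.get("num_epochs", 10)):
        metrics = {"loss": 0.0}  # placeholder; only the rank 0 metrics are tracked

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            rank = train.get_context().get_world_rank()
            # Each worker writes its own shard file (hypothetical naming scheme).
            torch.save(
                model_state,
                os.path.join(temp_checkpoint_dir, f"model-rank={rank}.pt"),
            )

            # Every worker attaches a checkpoint; Ray Train merges the shard
            # files from all workers into one checkpoint directory in
            # persistent storage.
            train.report(
                metrics, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir)
            )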
Note
Each invocation of this method automatically increments the underlying training_iteration number. The physical meaning of this "iteration" is defined by the user, depending on how often they call report. It does not necessarily map to one epoch.
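As an illustration of how the reporting granularity determines what training_iteration means, consider the following sketch (the loop bounds and metrics are placeholders):

from ray import train


def train_func(config):
    # Reporting once per epoch: training_iteration ends up equal to num_epochs.
    for epoch in range(config.get("num_epochs", 10)):
        # ... train for one epoch ...
        train.report({"loss": 0.0, "epoch": epoch})

    # If report() were instead called once per batch inside the epoch loop,
    # training_iteration would count batches rather than epochs.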
Warning
All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. Otherwise, your training will hang.
Warning
This method does NOT act as a barrier for distributed training workers. Workers will upload their checkpoint, then continue training immediately. If you need to synchronize workers, you can use a framework-native barrier such as torch.distributed.barrier().
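If you do need the workers to wait for each other after reporting, a minimal sketch with PyTorch's native barrier looks like this (it assumes the default torch.distributed process group is initialized, which TorchTrainer's Torch backend normally handles):

import torch.distributed as dist
from ray import train


def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        metrics = {"loss": 0.0}  # placeholder metrics
        train.report(metrics)

        # Optional: block here until every worker has reached this point.
        # Assumes the default torch.distributed process group is initialized,
        # which TorchTrainer's Torch backend sets up for you.
        dist.barrier()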
Example

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                train.report(metrics, checkpoint=None)


trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
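To actually launch the run from the example above and read back the reported values, the usual pattern is:

result = trainer.fit()
print(result.metrics)     # metrics from the final report() call on the rank 0 worker
print(result.checkpoint)  # the latest registered Checkpoint, if any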
- Parameters:
metrics – The metrics you want to report.
checkpoint – The optional checkpoint you want to report.