ray.train.report
- ray.train.report(metrics: Dict[str, Any], checkpoint: Checkpoint | None = None, checkpoint_dir_name: str | None = None)
Report metrics and optionally save a checkpoint.
If a checkpoint is provided, it will be persisted to storage.
If this is called in multiple distributed training workers:
Only the metrics reported by the rank 0 worker will be attached to the checkpoint.
A checkpoint will be registered as long as one or more workers reports a checkpoint that is not None. See the checkpointing guide.
Checkpoints from multiple workers will be merged into one directory in persistent storage, as sketched below. See the distributed checkpointing guide.
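The merged-checkpoint case described above can look roughly like the following sketch, in which every worker writes its own shard and reports a non-None checkpoint, so that Ray Train combines the per-worker directories into a single checkpoint in persistent storage. The shard file name and the loss value are illustrative assumptions, not API requirements.

import os
import tempfile

from ray import train
from ray.train import Checkpoint


def train_func(config):
    # ... one training step per iteration ...
    metrics = {"loss": 0.123}  # illustrative value

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        rank = train.get_context().get_world_rank()
        # Each worker saves its own shard; the file name below is a
        # hypothetical convention, not something Ray Train requires.
        shard_path = os.path.join(temp_checkpoint_dir, f"model-rank-{rank}.pt")
        # torch.save(model_shard_state_dict, shard_path)

        # Every worker reports a non-None checkpoint, so the per-worker
        # directories are merged into one checkpoint in persistent storage.
        train.report(metrics, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir))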
Warning
All workers must call ray.train.report the same number of times so that Ray Train can properly synchronize the training state across workers. This method acts as a barrier across all workers, so be sure that every worker reaches this method.
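A minimal sketch of this requirement is shown below; the loss_threshold key is a hypothetical config value, not part of the API. Avoid guarding report with a worker-local condition:

from ray import train


def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        metrics = {"loss": ...}

        # Incorrect: a worker-local condition can make workers call
        # `report` a different number of times, which hangs the barrier.
        # if metrics["loss"] < config["loss_threshold"]:
        #     train.report(metrics)

        # Correct: every worker calls `report` once per iteration.
        train.report(metrics)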
Example
import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                train.report(metrics, checkpoint=None)


trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
- Parameters:
metrics – The metrics you want to report.
checkpoint – The optional checkpoint you want to report.
checkpoint_dir_name – Optional custom name for the checkpoint directory in persistent storage. If not provided, a unique directory name is generated automatically.
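For reference, a short sketch of where reported values end up after training, assuming the trainer defined in the example above:

result = trainer.fit()
print(result.metrics)      # Metrics from the last `train.report()` call.
print(result.checkpoint)   # The latest reported checkpoint, if any.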