ray.train.report
- ray.train.report(metrics: Dict, *, checkpoint: Optional[ray.train.Checkpoint] = None) -> None
Report metrics and optionally save a checkpoint.
Each invocation of this method automatically increments the underlying training_iteration number. The physical meaning of this “iteration” is defined by the user and depends on how often report is called; it does not necessarily map to one epoch.

This method acts as a synchronous barrier for all distributed training workers: every worker must call ray.train.report the same number of times.

If a checkpoint is provided, it will be persisted to storage.
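For instance, a metrics-only loop that calls report once per epoch advances training_iteration once per epoch and synchronizes the workers at each call. The following is a minimal sketch of that pattern (the loss value is a placeholder and no checkpoint is saved); the full checkpointing workflow is shown in the Example below.

from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    for epoch in range(config.get("num_epochs", 10)):
        # ... run one epoch of training on this worker ...

        # Every worker reports exactly once per epoch, so this call also
        # serves as the per-epoch barrier, and training_iteration ends up
        # equal to the number of completed epochs.
        train.report({"loss": 0.0, "epoch": epoch})  # placeholder loss value


trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
trainer.fit()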
Example
import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(metrics, checkpoint=checkpoint)


trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
- Parameters
metrics – The metrics you want to report.
checkpoint – The optional checkpoint you want to report.
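The reported values surface on the Result object returned by the trainer. A minimal sketch, assuming the trainer from the example above: result.metrics holds the metrics dict from the last report call, and result.checkpoint holds the latest reported checkpoint, if any.

result = trainer.fit()

# Metrics dict from the final train.report() call.
print(result.metrics)

# Latest reported (and persisted) Checkpoint, or None if none was reported.
last_checkpoint = result.checkpoint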