ray.train.get_all_reported_checkpoints
ray.train.get_all_reported_checkpoints() -> List[ReportedCheckpoint]
Get all the checkpoints reported so far.

Blocks until Ray Train has finished processing every ray.train.report call.

Example
```python
import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(metrics, checkpoint=checkpoint)

    reported_checkpoints = train.get_all_reported_checkpoints()
    # Report artifacts/metrics to experiment tracking framework...


trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
```
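As a rough mental model of the blocking guarantee described above, the following toy class lets report() return immediately while a single background thread "processes" each call, and makes get_all_reported() wait for every earlier report() to finish before returning. The class and method names (ToyReportTracker, get_all_reported) are hypothetical; this is an analogy built on concurrent.futures, not Ray Train's implementation.

```python
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List


class ToyReportTracker:
    """Hypothetical toy model of "report now, process later"; not Ray code."""

    def __init__(self) -> None:
        # One worker thread, so reports are processed in submission order.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._pending: List[Future] = []
        self._processed: List[Dict[str, Any]] = []
        self._lock = threading.Lock()

    def _process(self, metrics: Dict[str, Any]) -> None:
        time.sleep(0.01)  # simulate background work such as a checkpoint upload
        with self._lock:
            self._processed.append(metrics)

    def report(self, metrics: Dict[str, Any]) -> None:
        # Returns immediately; processing happens on the background thread.
        self._pending.append(self._executor.submit(self._process, metrics))

    def get_all_reported(self) -> List[Dict[str, Any]]:
        # Block until every earlier report() call has been processed.
        for fut in self._pending:
            fut.result()
        with self._lock:
            return list(self._processed)
```

Without the wait in get_all_reported(), a caller could observe a list that is missing the most recent reports; the blocking call is what makes the returned list complete with respect to everything reported before it.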
Returns:
List of ReportedCheckpoint objects that represent the checkpoints and corresponding metrics reported by the workers.
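A common use of the returned list is picking the checkpoint with the best reported metric. The sketch below is runnable without Ray: it uses a hypothetical stand-in dataclass, FakeReportedCheckpoint, and assumes the real ReportedCheckpoint exposes the checkpoint and its metrics dict as attributes named checkpoint and metrics. Verify those names against the Ray API reference before adapting it. The helper best_checkpoint is also hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeReportedCheckpoint:
    """Hypothetical stand-in for ray.train's ReportedCheckpoint.

    Assumes the real object carries a checkpoint plus the metrics reported
    with it; here the checkpoint is just a path string.
    """

    checkpoint: str
    metrics: Dict[str, Any] = field(default_factory=dict)


def best_checkpoint(
    reported: List[FakeReportedCheckpoint],
    metric: str = "loss",
    mode: str = "min",
) -> Optional[FakeReportedCheckpoint]:
    """Return the entry with the best value of `metric`, or None.

    Entries that did not report `metric` are skipped.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    candidates = [r for r in reported if metric in r.metrics]
    if not candidates:
        return None
    pick = min if mode == "min" else max
    return pick(candidates, key=lambda r: r.metrics[metric])


reported = [
    FakeReportedCheckpoint("/tmp/ckpt_0", {"loss": 0.9}),
    FakeReportedCheckpoint("/tmp/ckpt_1", {"loss": 0.4}),
    FakeReportedCheckpoint("/tmp/ckpt_2", {"loss": 0.6}),
]
print(best_checkpoint(reported).checkpoint)  # prints /tmp/ckpt_1
```

Skipping entries that lack the metric avoids a KeyError when some report calls include a checkpoint but omit the metric being ranked on.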
PublicAPI (alpha): This API is in alpha and may change before becoming stable.