ray.train.get_all_reported_checkpoints
- ray.train.get_all_reported_checkpoints() -> List[ReportedCheckpoint]
Get all the checkpoints reported so far.

Blocks until Ray Train has finished processing every in-flight ray.train.report call.

Example:
```python
import tempfile

import ray.train
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0

    for epoch in range(start_epoch, config.get("num_epochs", 2)):
        # Do training...

        metrics = {"loss": 0.1}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)

    reported_checkpoints = ray.train.get_all_reported_checkpoints()
    # Report artifacts/metrics to experiment tracking framework...


trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
trainer.fit()
```
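To make the "blocks until every in-flight report call is processed" behavior concrete, here is a toy model in plain Python. It is not Ray's implementation and uses no Ray APIs. ToyCheckpointTracker and ToyReportedCheckpoint are hypothetical names. report() returns immediately and queues the work; a background thread "persists" each checkpoint; the getter waits on queue.Queue.join() so that no queued report is missing from the result.

```python
import queue
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyReportedCheckpoint:
    # Hypothetical stand-in for ReportedCheckpoint: a checkpoint path plus its metrics.
    path: str
    metrics: Dict[str, Any]


class ToyCheckpointTracker:
    """Toy model (not Ray's code): report() enqueues, a worker thread processes."""

    def __init__(self) -> None:
        self._pending: "queue.Queue[ToyReportedCheckpoint]" = queue.Queue()
        self._done: List[ToyReportedCheckpoint] = []
        self._lock = threading.Lock()
        threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self) -> None:
        while True:
            item = self._pending.get()
            time.sleep(0.01)  # simulate a slow checkpoint upload
            with self._lock:
                self._done.append(item)
            self._pending.task_done()  # mark this report as fully processed

    def report(self, metrics: Dict[str, Any], path: str) -> None:
        # Returns immediately; processing happens asynchronously.
        self._pending.put(ToyReportedCheckpoint(path, metrics))

    def get_all_reported_checkpoints(self) -> List[ToyReportedCheckpoint]:
        # Block until task_done() has been called for every queued report.
        self._pending.join()
        with self._lock:
            return list(self._done)


tracker = ToyCheckpointTracker()
for epoch in range(3):
    tracker.report({"epoch": epoch, "loss": 1.0 / (epoch + 1)}, f"ckpt_{epoch}")

all_ckpts = tracker.get_all_reported_checkpoints()
print([c.path for c in all_ckpts])  # ['ckpt_0', 'ckpt_1', 'ckpt_2']
```

Because a single worker drains the queue in FIFO order, the toy returns checkpoints in the order they were reported. Without the join() call, the getter could return before the last upload finished.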
- Returns:
List of ReportedCheckpoint objects that represent the checkpoints and corresponding metrics reported by the workers.
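A common next step is to pick the best checkpoint from the returned list by one of its metrics. The sketch below is written against a hypothetical stand-in dataclass, ReportedCheckpointStandIn, rather than Ray's ReportedCheckpoint, so it runs without Ray. The field names checkpoint_path and metrics, and the helper best_checkpoint, are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ReportedCheckpointStandIn:
    # Hypothetical stand-in: pairs a checkpoint location with the metrics
    # reported alongside it.
    checkpoint_path: str
    metrics: Dict[str, Any]


def best_checkpoint(
    reported: List[ReportedCheckpointStandIn],
    metric: str = "loss",
    mode: str = "min",
) -> Optional[ReportedCheckpointStandIn]:
    """Return the entry with the best value of `metric`, skipping entries that lack it."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    candidates = [r for r in reported if metric in r.metrics]
    if not candidates:
        return None

    def key(r: ReportedCheckpointStandIn) -> Any:
        return r.metrics[metric]

    return min(candidates, key=key) if mode == "min" else max(candidates, key=key)


reported = [
    ReportedCheckpointStandIn("ckpt_0", {"loss": 0.9}),
    ReportedCheckpointStandIn("ckpt_1", {"loss": 0.3}),
    ReportedCheckpointStandIn("ckpt_2", {"loss": 0.5}),
]
print(best_checkpoint(reported).checkpoint_path)  # ckpt_1
```

Skipping entries without the metric keeps the helper safe when some report calls carry a different set of metrics.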
PublicAPI (alpha): This API is in alpha and may change before becoming stable.