ray.train.get_all_reported_checkpoints

ray.train.get_all_reported_checkpoints() → List[ReportedCheckpoint]

Get all checkpoints reported so far.

Blocks until Ray Train has finished processing every in-flight ray.train.report call.
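To make the blocking behavior concrete, here is a toy model in plain Python. It is a conceptual sketch only, not Ray Train's implementation: `ToyCheckpointTracker`, `ToyReportedCheckpoint`, and the `delay` parameter are hypothetical names. `report()` hands the work to a background thread and returns immediately, while `get_all()` waits for every report submitted so far to finish before returning the results in report order.

```python
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyReportedCheckpoint:
    # Hypothetical stand-in: a checkpoint location plus the metrics reported with it.
    path: str
    metrics: Dict[str, Any]


class ToyCheckpointTracker:
    """Conceptual model only (not Ray's code): report() returns immediately,
    while get_all() waits for every in-flight report to finish processing."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=4)
        self._lock = threading.Lock()
        self._pending: List[Future] = []
        self._processed: Dict[int, ToyReportedCheckpoint] = {}

    def _process(self, index: int, item: ToyReportedCheckpoint, delay: float) -> None:
        time.sleep(delay)  # simulate persisting the checkpoint to storage
        with self._lock:
            self._processed[index] = item

    def report(self, metrics: Dict[str, Any], path: str, delay: float = 0.05) -> None:
        item = ToyReportedCheckpoint(path=path, metrics=dict(metrics))
        with self._lock:
            index = len(self._pending)  # report order, used to sort results later
            self._pending.append(self._executor.submit(self._process, index, item, delay))

    def get_all(self) -> List[ToyReportedCheckpoint]:
        with self._lock:
            pending = list(self._pending)
        wait(pending)  # block until all reports submitted so far are processed
        with self._lock:
            return [self._processed[i] for i in sorted(self._processed)]

    def close(self) -> None:
        self._executor.shutdown(wait=True)


tracker = ToyCheckpointTracker()
# The first report takes longer to process than the second one.
tracker.report({"loss": 0.3}, "ckpt_epoch_0", delay=0.1)
tracker.report({"loss": 0.2}, "ckpt_epoch_1", delay=0.01)
print([c.path for c in tracker.get_all()])  # ['ckpt_epoch_0', 'ckpt_epoch_1']
tracker.close()
```

Even though the second report finishes processing first, `get_all()` only returns once both are done, so neither checkpoint is missing from the result.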

Example

import tempfile

import ray.train
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0

    for epoch in range(start_epoch, config.get("num_epochs", 2)):
        # Do training...

        metrics = {"loss": 0.1}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
            ray.train.report(metrics, checkpoint=checkpoint)

    reported_checkpoints = ray.train.get_all_reported_checkpoints()
    # Report artifacts/metrics to experiment tracking framework...

trainer = TorchTrainer(
    train_func, scaling_config=ray.train.ScalingConfig(num_workers=2)
)
trainer.fit()

Returns:

List of ReportedCheckpoint objects that represent the checkpoints and corresponding metrics reported by the workers.
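A common use of the returned list is picking the best checkpoint by a reported metric before logging it to an experiment tracker. The sketch below runs without Ray: `FakeReportedCheckpoint` and `select_best` are hypothetical names, and it assumes each entry exposes the checkpoint and its metrics dict as `checkpoint` and `metrics` attributes. Check the `ReportedCheckpoint` reference for the actual field names before adapting it.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class FakeReportedCheckpoint:
    # Hypothetical stand-in for ReportedCheckpoint: assumes it carries the
    # checkpoint and the metrics dict that was reported with it.
    checkpoint: str
    metrics: Dict[str, Any]


def select_best(
    reported: List[FakeReportedCheckpoint], metric: str = "loss", mode: str = "min"
) -> Optional[FakeReportedCheckpoint]:
    """Return the entry with the best value of `metric`, or None if none has it."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # Skip entries whose metrics do not include the requested metric.
    candidates = [r for r in reported if metric in r.metrics]
    if not candidates:
        return None
    pick = min if mode == "min" else max
    return pick(candidates, key=lambda r: r.metrics[metric])


reported = [
    FakeReportedCheckpoint("ckpt_epoch_0", {"loss": 0.5}),
    FakeReportedCheckpoint("ckpt_epoch_1", {"loss": 0.1}),
    FakeReportedCheckpoint("ckpt_epoch_2", {"loss": 0.3}),
]
best = select_best(reported)
print(best.checkpoint)  # ckpt_epoch_1
```

Returning `None` instead of raising when no entry has the metric keeps the caller in control of whether a missing metric is an error.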

PublicAPI (alpha): This API is in alpha and may change before becoming stable.