ray.train.get_all_reported_checkpoints

ray.train.get_all_reported_checkpoints() → List[ReportedCheckpoint]

Get all the reported checkpoints so far.

Blocks until Ray Train has finished processing every ray.train.report call.

Example

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(metrics, checkpoint=checkpoint)

    reported_checkpoints = train.get_all_reported_checkpoints()
    # Report artifacts/metrics to experiment tracking framework...

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
trainer.fit()
Returns:

List of ReportedCheckpoint objects that represent the checkpoints and corresponding metrics reported by the workers.
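For illustration, a minimal sketch of consuming the returned list inside train_func. Since this is an alpha API, the attribute names used here (metrics and checkpoint on ReportedCheckpoint) are an assumption, not guaranteed by this page:

reported_checkpoints = train.get_all_reported_checkpoints()
for reported in reported_checkpoints:
    # Assumed attributes: `metrics` holds the dict passed to train.report,
    # `checkpoint` is the corresponding Checkpoint object.
    print(reported.metrics.get("loss"), reported.checkpoint)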

PublicAPI (alpha): This API is in alpha and may change before becoming stable.