ray.train.get_checkpoint() -> Optional[ray.train.Checkpoint]

Access the session’s last checkpoint to resume from if applicable.


import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer

def train_func(config):
    """Per-worker training loop that resumes from the latest checkpoint.

    If ``train.get_checkpoint()`` returns a checkpoint (i.e. the session is
    being resumed), its contents are exposed as a local directory so that
    training state can be loaded back. The loop then runs from
    ``start_epoch`` up to ``config.get("num_epochs", 10)``. At the end of
    every epoch it writes a checkpoint into a temporary directory and
    reports it together with the epoch's metrics via ``train.report``.

    Args:
        config: Dict of training hyperparameters. Only ``"num_epochs"``
            is read here (default 10).
    """
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state from ``checkpoint_dir``, e.g. model /
            # optimizer weights. To truly resume, also update ``start_epoch``
            # from the saved state so finished epochs are not repeated.
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        # The temporary directory only needs to live until ``train.report``
        # returns; it is cleaned up when the ``with`` block exits.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint files into ``temp_checkpoint_dir``...

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            train.report(metrics, checkpoint=checkpoint)

# Run ``train_func`` on two distributed workers. Launching the run (e.g. via
# ``trainer.fit()``) is not shown in this example.
trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)

Returns:
    A Checkpoint object if the session is currently being resumed;
    otherwise, None.