ray.train.get_checkpoint
- ray.train.get_checkpoint() -> Checkpoint | None
Access the latest reported checkpoint to resume from if one exists.
Example
import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    """Per-worker training loop that resumes from a checkpoint when one exists.

    Args:
        config: Training configuration; the optional ``"num_epochs"`` key
            sets the epoch at which the loop stops (defaults to 10).
    """
    start_epoch = 0

    # get_checkpoint() returns None when there is nothing to resume from.
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...
        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # Save the checkpoint...
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            # Report inside the `with` block, while the temporary
            # directory backing the checkpoint still exists.
            train.report(metrics, checkpoint=checkpoint)


trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
- Returns:
- A Checkpoint object if the session is currently being resumed; otherwise, None.