ray.air.session.get_checkpoint

ray.air.session.get_checkpoint() -> Optional[ray.air.checkpoint.Checkpoint]

Access the session’s last checkpoint to resume from if applicable.

Returns

Checkpoint object if the session is currently being resumed; otherwise, None.
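For example, here is a minimal sketch of branching on the return value inside a training loop, using a dict-backed checkpoint to carry the resume state. The function name ``train_loop`` and the ``"iter"`` key are illustrative, not part of the API:

from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_loop(config):
    start_iter = 0
    ckpt = session.get_checkpoint()  # None unless the run is being resumed
    if ckpt is not None:
        # Resume the iteration counter stored in the checkpoint.
        start_iter = ckpt.to_dict()["iter"] + 1
    for i in range(start_iter, 5):
        session.report(
            metrics={"iter": i},
            checkpoint=Checkpoint.from_dict({"iter": i}),
        )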

######## Using it in the *per worker* train loop (TrainSession) ######
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def build_model():
    # Placeholder for a user-defined model; the original example leaves
    # this function undefined. Replace with your own architecture.
    import tensorflow as tf
    return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])

def train_func():
    # Returns the checkpoint this run is resuming from, or None on a
    # fresh start.
    ckpt = session.get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as loaded_checkpoint_dir:
            import tensorflow as tf

            model = tf.keras.models.load_model(loaded_checkpoint_dir)
    else:
        model = build_model()

    # Save the model and report it back as a directory-backed checkpoint.
    model.save("my_model", overwrite=True)
    session.report(
        metrics={"iter": 1},
        checkpoint=Checkpoint.from_directory("my_model")
    )

scaling_config = ScalingConfig(num_workers=2)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()

# trainer2 will pick up from the checkpoint saved by the first trainer.
trainer2 = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    # This is ultimately what is accessed through
    # ``session.get_checkpoint()``.
    resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()
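
After the second run finishes, the reported checkpoint can also be consumed outside of training. A minimal sketch, assuming TensorFlow is installed locally:

import tensorflow as tf

# ``result2.checkpoint`` is directory-backed because it was created with
# ``Checkpoint.from_directory``; materialize it locally and load the model.
with result2.checkpoint.as_directory() as ckpt_dir:
    final_model = tf.keras.models.load_model(ckpt_dir)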