ray.air.session.get_checkpoint
ray.air.session.get_checkpoint
- ray.air.session.get_checkpoint() -> Optional[ray.air.checkpoint.Checkpoint] [source]
Access the session’s last checkpoint to resume from if applicable.
- Returns
- The Checkpoint object if the session is currently being resumed;
otherwise, None.
# Using it in the *per worker* train loop (TrainSession).
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer


def build_model():
    """Return a small compiled Keras model.

    Placeholder so the example runs end to end; replace it with your own
    model-construction code.
    """
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile(optimizer="sgd", loss="mse")
    return model


def train_func():
    """Per-worker training loop that resumes from a checkpoint if one exists.

    On a fresh run a new model is built. When the trainer was given
    ``resume_from_checkpoint``, the model is restored from that checkpoint.
    The model is then saved and reported back to Ray as a new checkpoint.
    """
    import tensorflow as tf

    # ``session.get_checkpoint()`` returns None unless the session is being
    # resumed, so compare against None explicitly instead of relying on
    # the truthiness of the Checkpoint object.
    ckpt = session.get_checkpoint()
    if ckpt is not None:
        with ckpt.as_directory() as loaded_checkpoint_dir:
            model = tf.keras.models.load_model(loaded_checkpoint_dir)
    else:
        model = build_model()

    model.save("my_model", overwrite=True)
    session.report(
        metrics={"iter": 1},
        checkpoint=Checkpoint.from_directory("my_model"),
    )


scaling_config = ScalingConfig(num_workers=2)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
)
result = trainer.fit()

# ``trainer2`` picks up from the checkpoint saved by ``trainer``.
trainer2 = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    # This is ultimately what ``session.get_checkpoint()`` returns inside
    # ``train_func`` for this second run.
    resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()