ray.tune.Trainable.load_checkpoint

Trainable.load_checkpoint(checkpoint: Dict | None)

Subclasses should override this to implement restore().

Warning

In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir passed to Trainable.save_checkpoint may change between saving and restoring.

If Trainable.save_checkpoint returned a prefixed string, the prefix of that string may also change, because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the examples below.

Example

>>> import os
>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        my_checkpoint_path = os.path.join(checkpoint_path, "my/path")
...        return my_checkpoint_path
...    def load_checkpoint(self, my_checkpoint_path):
...        print(my_checkpoint_path)
>>> trainer = Example()
>>> # This is used when PAUSED.
>>> checkpoint_result = trainer.save() 
>>> trainer.restore(checkpoint_result) 
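
The warning about absolute paths can be illustrated without Ray installed. The following is a minimal sketch, not Tune's implementation: FileCounter is a hypothetical stand-in that mirrors the save_checkpoint/load_checkpoint method names but does not subclass Trainable, and relocate_and_restore is a hypothetical helper that moves the checkpoint directory to imitate what may happen when a trial is paused.

```python
import os
import shutil
import tempfile


class FileCounter:
    """Hypothetical stand-in for a Trainable subclass (does not import ray)."""

    def __init__(self):
        self.count = 0

    def save_checkpoint(self, checkpoint_dir):
        # Write state to a location relative to checkpoint_dir.
        target = os.path.join(checkpoint_dir, "my", "path")
        os.makedirs(target, exist_ok=True)
        with open(os.path.join(target, "count.txt"), "w") as f:
            f.write(str(self.count))

    def load_checkpoint(self, checkpoint_dir):
        # Rebuild the file path from the directory passed in; never reuse
        # an absolute path remembered from save time.
        path = os.path.join(checkpoint_dir, "my", "path", "count.txt")
        with open(path) as f:
            self.count = int(f.read())


def relocate_and_restore(count):
    """Save, move the checkpoint directory, then load from the new location."""
    saver = FileCounter()
    saver.count = count
    with tempfile.TemporaryDirectory() as root:
        original = os.path.join(root, "original")
        os.makedirs(original)
        saver.save_checkpoint(original)

        # Assumption for illustration: the directory is moved as a whole,
        # so the structure under it is preserved.
        moved = os.path.join(root, "moved")
        shutil.move(original, moved)

        loader = FileCounter()
        loader.load_checkpoint(moved)
        return loader.count


restored = relocate_and_restore(7)
print(restored)
```

Because load_checkpoint only joins relative components onto the directory it receives, the restore succeeds even though the directory's absolute path differs from the one used when saving.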

If Trainable.save_checkpoint returned a dict, Tune passes that dict directly as the argument to this method.

Example

>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        return {"my_data": 1}
...    def load_checkpoint(self, checkpoint_dict):
...        print(checkpoint_dict["my_data"])
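
A dict round trip might look like the sketch below. It is written so that it runs without Ray: DictCounter is a hypothetical stand-in that mirrors the method names but does not subclass Trainable, and train_step is an invented helper that only mutates state for the demonstration. The dict is handed from save_checkpoint to load_checkpoint by hand here, which is an assumption standing in for what Tune does.

```python
import copy


class DictCounter:
    """Hypothetical stand-in for a Trainable subclass (does not import ray)."""

    def __init__(self):
        self.step = 0
        self.weights = [0.0, 0.0]

    def train_step(self):
        # Invented helper: change some state so there is something to save.
        self.step += 1
        self.weights = [w + 0.5 for w in self.weights]

    def save_checkpoint(self, checkpoint_dir):
        # checkpoint_dir is unused because all state fits in the dict.
        return {"step": self.step, "weights": copy.deepcopy(self.weights)}

    def load_checkpoint(self, checkpoint):
        # The argument is the same dict that save_checkpoint returned.
        self.step = checkpoint["step"]
        self.weights = list(checkpoint["weights"])


original = DictCounter()
original.train_step()
original.train_step()
state = original.save_checkpoint(None)

restored = DictCounter()
restored.load_checkpoint(state)
print(restored.step, restored.weights)
```

Copying the list in both directions keeps the saved dict independent of later in-place changes to the object's state.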

New in version 0.8.7.

Parameters:

checkpoint – If a dict, the value returned by save_checkpoint. Otherwise, the directory in which the checkpoint was stored.