Source code for ray.train.rl.rl_checkpoint
import os
from typing import Optional

from packaging import version

import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.rllib.utils.typing import EnvType
from ray.util.annotations import PublicAPI
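
# Files in which a (legacy) RLTrainer checkpoint stores the pickled trainer
# class and its config (loaded again in `get_policy` below).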
RL_TRAINER_CLASS_FILE = "trainer_class.pkl"
RL_CONFIG_FILE = "config.pkl"


@PublicAPI(stability="alpha")
class RLCheckpoint(Checkpoint):
    """A :py:class:`~ray.air.checkpoint.Checkpoint` with RLlib-specific
    functionality.

    Create this from a generic :py:class:`~ray.air.checkpoint.Checkpoint` by calling
    ``RLCheckpoint.from_checkpoint(ckpt)``.
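
    Example (a minimal sketch; assumes ``ckpt_dir`` points to a checkpoint
    directory produced by an ``RLTrainer`` run):

    .. code-block:: python

        from ray.air.checkpoint import Checkpoint
        from ray.train.rl.rl_checkpoint import RLCheckpoint

        ckpt = Checkpoint.from_directory(ckpt_dir)
        rl_ckpt = RLCheckpoint.from_checkpoint(ckpt)
        policy = rl_ckpt.get_policy()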
"""

    def get_policy(self, env: Optional[EnvType] = None) -> Policy:
        """Retrieve the policy stored in this checkpoint.

        Args:
            env: Optional environment to instantiate the trainer with. If not
                given, it is parsed from the saved trainer configuration.

        Returns:
            The policy stored in this checkpoint.
        """
        # TODO: Deprecate this RLCheckpoint class (or move all our
        # Algorithm/Policy.from_checkpoint utils into here).

        # Newer checkpoint version -> use the `Policy.from_checkpoint()` util.
        checkpoint_info = get_checkpoint_info(checkpoint=self)
        if checkpoint_info["checkpoint_version"] > version.Version("0.1"):
            # This is an Algorithm checkpoint, so `from_checkpoint` extracts all
            # policies in that Algorithm; index into "default_policy" in the
            # returned dict.
            return Policy.from_checkpoint(checkpoint=self)["default_policy"]

        # Older checkpoint version.
        with self.as_directory() as checkpoint_path:
            trainer_class_path = os.path.join(checkpoint_path, RL_TRAINER_CLASS_FILE)
            config_path = os.path.join(checkpoint_path, RL_CONFIG_FILE)

            if not os.path.exists(trainer_class_path):
                raise ValueError(
                    f"RLPredictor only works with checkpoints created by "
                    f"RLTrainer. The checkpoint you specified is missing the "
                    f"`{RL_TRAINER_CLASS_FILE}` file."
                )

            if not os.path.exists(config_path):
                raise ValueError(
                    f"RLPredictor only works with checkpoints created by "
                    f"RLTrainer. The checkpoint you specified is missing the "
                    f"`{RL_CONFIG_FILE}` file."
                )

            # Load the pickled trainer class and config saved by RLTrainer.
            with open(trainer_class_path, "rb") as fp:
                trainer_cls = cpickle.load(fp)
            with open(config_path, "rb") as fp:
                config = cpickle.load(fp)

            # Locate the checkpoint data file (a file whose name starts with
            # "checkpoint", excluding Tune metadata files).
            checkpoint_data_path = None
            for file in os.listdir(checkpoint_path):
                if file.startswith("checkpoint") and not file.endswith(
                    ".tune_metadata"
                ):
                    checkpoint_data_path = os.path.join(checkpoint_path, file)
            if not checkpoint_data_path:
                raise ValueError(
                    f"Could not find checkpoint data in RLlib checkpoint. "
                    f"Found files: {list(os.listdir(checkpoint_path))}"
                )
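
            # Drop the `in_evaluation` flag (if present) from the stored
            # evaluation config before re-instantiating the trainer.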
            config.get("evaluation_config", {}).pop("in_evaluation", None)
            trainer = trainer_cls(config=config, env=env)
            trainer.restore(checkpoint_data_path)
            return trainer.get_policy()
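

# A minimal usage sketch (hypothetical names: `checkpoint_dir` is a directory
# written by an RLTrainer run, `obs` an observation from the policy's
# observation space):
#
#     rl_ckpt = RLCheckpoint.from_directory(checkpoint_dir)
#     policy = rl_ckpt.get_policy()
#     action, _, _ = policy.compute_single_action(obs)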