# Source code for ray.rllib.core.rl_module
import logging
import re
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
)
from ray.util import log_once
from ray.util.annotations import PublicAPI
# Logger registered under the "ray.rllib" name; `validate_module_id` uses it
# to emit its invalid-ID warning.
logger = logging.getLogger("ray.rllib")
# [docs]
@PublicAPI(stability="alpha")
def validate_module_id(policy_id: str, error: bool = False) -> None:
    """Checks that the given `policy_id` is a valid ID.

    An ID is considered valid if it is a non-empty string that contains none
    of the characters `<>:"/\\|?*` (which are not allowed in Unix/Windows file
    or directory names) and does not end with a dot `.` or a space ` `.

    Args:
        policy_id: The ID to check.
        error: If True, raise a ValueError for an invalid `policy_id`. If
            False, log a warning instead.

    Raises:
        ValueError: If `policy_id` is not valid and `error` is True.
    """
    is_valid = (
        isinstance(policy_id, str)
        and policy_id != ""
        # The asterisk is part of the documented set of forbidden characters
        # and must therefore be rejected together with the others.
        and re.search(r'[<>:"/\\|?*]', policy_id) is None
        and not policy_id.endswith((" ", "."))
    )
    if is_valid:
        return

    msg = (
        f"PolicyID `{policy_id}` not valid! IDs must be a non-empty string, "
        "must not contain characters that are also disallowed file- or directory "
        "names on Unix/Windows and must not end with a dot `.` or a space ` `."
    )
    if error:
        raise ValueError(msg)
    # NOTE(review): `log_once` presumably returns True only the first time it
    # sees this key, so the warning is not repeated -- confirm in ray.util.
    if log_once("invalid_policy_id"):
        logger.warning(msg)
# Public names of this module: the RLModule/MultiRLModule classes and their
# spec classes imported above, plus the `validate_module_id` helper.
__all__ = [
"MultiRLModule",
"MultiRLModuleSpec",
"RLModule",
"RLModuleSpec",
"validate_module_id",
]