import atexit
import logging
from functools import partial
from types import FunctionType
from typing import Callable, Optional, Type, Union
import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import (
_internal_kv_del,
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_put,
)
from ray.tune.error import TuneError
from ray.util.annotations import DeveloperAPI
# Registry categories. Each string is used as the category component of the
# internal-KV keys built by ``_make_key``.
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
RLLIB_INPUT = "rllib_input"
RLLIB_CONNECTOR = "rllib_connector"
TEST = "__test__"

# The only categories ``_Registry.register`` accepts; anything else raises.
KNOWN_CATEGORIES = [TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR,
                    RLLIB_ACTION_DIST, RLLIB_INPUT, RLLIB_CONNECTOR, TEST]

logger = logging.getLogger(__name__)
def _has_trainable(trainable_name):
    """Return True if ``trainable_name`` is present in the trainable category."""
    registry = _global_registry
    return registry.contains(TRAINABLE_CLASS, trainable_name)
@DeveloperAPI
def get_trainable_cls(trainable_name):
    """Look up the trainable class registered under ``trainable_name``.

    The name is validated first (which may register RLlib's trainables);
    ``TuneError`` is raised for unknown names.
    """
    validate_trainable(trainable_name)
    trainable_cls = _global_registry.get(TRAINABLE_CLASS, trainable_name)
    return trainable_cls
@DeveloperAPI
def validate_trainable(trainable_name: str):
    """Raise ``TuneError`` unless ``trainable_name`` names a known trainable.

    The plain registry is checked first; only on a miss is RLlib imported
    and given a chance to register its trainables.
    """
    known = _has_trainable(trainable_name) or _has_rllib_trainable(trainable_name)
    if not known:
        raise TuneError(f"Unknown trainable: {trainable_name}")
def _has_rllib_trainable(trainable_name: str) -> bool:
    """Register RLlib's components (if importable) and re-check the registry.

    Returns False without touching the registry when ``ray.rllib`` cannot
    be imported.
    """
    try:
        # Make sure everything rllib-related is registered.
        from ray.rllib import _register_all
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return False
    else:
        _register_all()
        return _has_trainable(trainable_name)
@DeveloperAPI
def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
    """Check if a given trainable is a function trainable.

    A class counts only when it subclasses ``FunctionTrainable`` (i.e. a
    function that has already been wrapped). Any non-class callable
    (function, ``functools.partial``, callable object) also counts.
    Strings are first resolved via ``get_trainable_cls``, which raises
    ``TuneError`` for unknown names.
    """
    from ray.tune.trainable import FunctionTrainable

    if isinstance(trainable, str):
        trainable = get_trainable_cls(trainable)

    if isinstance(trainable, type):
        # Classes qualify only if they are already-wrapped function trainables.
        return issubclass(trainable, FunctionTrainable)

    # Non-class values: plain functions, partials and any other callable.
    return isinstance(trainable, (FunctionType, partial)) or callable(trainable)
@DeveloperAPI
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
        warn: Currently unused by this function.

    Raises:
        TypeError: If ``trainable`` is neither a ``Trainable`` subclass nor a
            callable that can be wrapped into one.
    """
    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, (FunctionType, partial)):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    # Check isinstance(type) first: issubclass() on a non-class value (e.g. an
    # int or str) raises its own "arg 1 must be a class" TypeError, which would
    # hide the descriptive message below.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable", trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)
def _unregister_trainables():
    """Unregister all trainables from the global registry."""
    _global_registry.unregister_all(category=TRAINABLE_CLASS)
@DeveloperAPI
def register_env(name: str, env_creator: Callable):
    """Register a custom environment for use with RLlib.

    This enables the environment to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        env_creator: Callable that creates an env.

    Raises:
        TypeError: If ``env_creator`` is not callable.
    """
    if not callable(env_creator):
        raise TypeError("Second argument must be callable.", env_creator)
    _global_registry.register(ENV_CREATOR, name, env_creator)
def _unregister_envs():
    """Unregister all environment creators from the global registry."""
    _global_registry.unregister_all(category=ENV_CREATOR)
@DeveloperAPI
def register_input(name: str, input_creator: Callable):
    """Register a custom input API for RLlib.

    Args:
        name: Name to register.
        input_creator: Callable that creates an input reader.

    Raises:
        TypeError: If ``input_creator`` is not callable.
    """
    if callable(input_creator):
        _global_registry.register(RLLIB_INPUT, name, input_creator)
        return
    raise TypeError("Second argument must be callable.", input_creator)
def _unregister_inputs():
    """Unregister all RLlib input creators from the global registry."""
    _global_registry.unregister_all(category=RLLIB_INPUT)
@DeveloperAPI
def registry_contains_input(name: str) -> bool:
    """Return whether an RLlib input creator is registered under ``name``."""
    found = _global_registry.contains(RLLIB_INPUT, name)
    return found
@DeveloperAPI
def registry_get_input(name: str) -> Callable:
    """Return the RLlib input creator registered under ``name``.

    Lookup errors from the registry's ``get`` propagate unchanged.
    """
    input_creator = _global_registry.get(RLLIB_INPUT, name)
    return input_creator
def _unregister_all():
    """Unregister inputs, then envs, then trainables from the global registry."""
    for unregister_category in (
        _unregister_inputs,
        _unregister_envs,
        _unregister_trainables,
    ):
        unregister_category()
def _check_serializability(key, value):
    """Register ``value`` under the TEST category to exercise its pickling.

    Any serialization error raised by the registry's ``register`` propagates.
    """
    _global_registry.register(category=TEST, key=key, value=value)
def _make_key(prefix: str, category: str, key: str):
    """Generate a binary key for the given category and key.

    The result has the form ``b"TuneRegistry:<prefix>:<category>/<key>"``.

    Args:
        prefix: Namespace prefix (e.g. the registry's job-ID prefix).
        category: The category of the item.
        key: The unique identifier for the item.

    Returns:
        The bytes key to use for storing the value.
    """
    # UTF-8 rather than ASCII: user-chosen names (env / trainable names) may
    # contain non-ASCII characters, which made ASCII encoding raise
    # UnicodeEncodeError. For pure-ASCII input the resulting bytes are identical.
    return (
        b"TuneRegistry:"
        + prefix.encode("utf-8")
        + b":"
        + category.encode("utf-8")
        + b"/"
        + key.encode("utf-8")
    )
class _Registry:
    """Category/key -> pickled-value registry backed by Ray's internal KV store.

    Values are pickled on ``register``. While the internal KV is not yet
    initialized they are kept locally in ``_to_flush``; ``flush_values``
    writes them to the KV (under keys namespaced by ``prefix``) and records
    them in ``_registered``.
    """

    def __init__(self, prefix: Optional[str] = None):
        """If no prefix is given, use runtime context job ID."""
        # (category, key) -> pickled bytes not yet written to the internal KV.
        self._to_flush = {}
        self._prefix = prefix
        # (category, key) pairs that have been written to the internal KV.
        self._registered = set()
        self._atexit_handler_registered = False

    @property
    def prefix(self):
        """Key namespace; lazily defaults to the current job ID."""
        if not self._prefix:
            self._prefix = ray.get_runtime_context().get_job_id()
        return self._prefix

    def _register_atexit(self):
        """Install ``_unregister_all`` as an atexit hook, once, on the driver only."""
        if self._atexit_handler_registered:
            # Already registered
            return
        if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
            # Only cleanup on the driver
            return
        atexit.register(_unregister_all)
        self._atexit_handler_registered = True

    def register(self, category, key, value):
        """Registers the value with the global registry.

        The value is pickled immediately and flushed to the internal KV if it
        is already initialized.

        Raises:
            TuneError: If ``category`` is not in ``KNOWN_CATEGORIES``.
            PicklingError: If ``value`` cannot be pickled.
        """
        if category not in KNOWN_CATEGORIES:
            raise TuneError(
                "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES)
            )
        self._to_flush[(category, key)] = pickle.dumps_debug(value)
        if _internal_kv_initialized():
            self.flush_values()

    def unregister(self, category, key):
        """Remove one entry from the internal KV and from local bookkeeping."""
        entry = (category, key)
        if _internal_kv_initialized():
            _internal_kv_del(_make_key(self.prefix, category, key))
        # Always drop a pending value too, so a later flush cannot resurrect
        # the entry, and forget it in ``_registered`` so it is not revisited.
        self._to_flush.pop(entry, None)
        self._registered.discard(entry)

    def unregister_all(self, category: Optional[str] = None):
        """Unregister every entry in ``category``, or in all categories if None.

        Covers both entries already flushed to the KV and values still pending
        a flush.
        """
        # Snapshot first: ``unregister`` mutates both collections.
        entries = set(self._registered) | set(self._to_flush)
        for cat, key in entries:
            if category is None or cat == category:
                self.unregister(cat, key)

    def contains(self, category, key):
        """Return whether an entry exists.

        Checks the internal KV once it is initialized, otherwise the locally
        pending values.
        """
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self.prefix, category, key))
            return value is not None
        else:
            return (category, key) in self._to_flush

    def get(self, category, key):
        """Return the unpickled value for ``(category, key)``.

        Raises:
            ValueError: If the internal KV is initialized and has no such entry.
            KeyError: If the internal KV is not initialized and no such value
                is pending locally.
        """
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self.prefix, category, key))
            if value is None:
                raise ValueError(
                    "Registry value for {}/{} doesn't exist.".format(category, key)
                )
            return pickle.loads(value)
        else:
            return pickle.loads(self._to_flush[(category, key)])

    def flush_values(self):
        """Write all pending values to the internal KV and mark them registered.

        Also installs the driver-side atexit cleanup hook (at most once).
        """
        self._register_atexit()
        for (category, key), value in self._to_flush.items():
            _internal_kv_put(
                _make_key(self.prefix, category, key), value, overwrite=True
            )
            self._registered.add((category, key))
        self._to_flush.clear()
# Module-wide registry instance used by the register_*/get/contains helpers in
# this module. With no explicit prefix, keys are namespaced by the job ID.
_global_registry = _Registry()
# NOTE(review): hooks in ``_post_init_hooks`` are presumably invoked by Ray after
# ``ray.init()``; this pushes values registered before init into the internal KV.
ray._private.worker._post_init_hooks.append(_global_registry.flush_values)
class _ParameterRegistry:
    """Named parameters that are placed in the Ray object store once Ray runs.

    Values ``put`` before Ray is initialized are held locally in
    ``to_flush``. ``flush`` moves them into the object store: existing
    ``ObjectRef`` values are stored as-is, others are ``ray.put``.
    """

    def __init__(self):
        # Parameter name -> value (or ObjectRef) not yet placed in the object store.
        self.to_flush = {}
        # Parameter name -> ObjectRef in the Ray object store.
        self.references = {}

    def put(self, k, v):
        """Store parameter ``k``; flushed immediately if Ray is initialized."""
        self.to_flush[k] = v
        if ray.is_initialized():
            self.flush()

    def get(self, k):
        """Return parameter ``k``.

        Before Ray is initialized the locally held value is returned. After
        initialization, values still pending (``put`` before ``ray.init()`` and
        not yet flushed by any hook) are flushed first. Otherwise the lookup
        would fail with a spurious KeyError.

        Raises:
            KeyError: If ``k`` was never ``put``.
        """
        if not ray.is_initialized():
            return self.to_flush[k]
        if k in self.to_flush:
            self.flush()
        return ray.get(self.references[k])

    def flush(self):
        """Move every pending value into the object store and clear the queue."""
        for k, v in self.to_flush.items():
            if isinstance(v, ray.ObjectRef):
                # Already in the object store; keep the reference as-is.
                self.references[k] = v
            else:
                self.references[k] = ray.put(v)
        self.to_flush.clear()