# Source code for ray.util.collective.backend_registry

from typing import Dict, Type

from .collective_group.base_collective_group import BaseGroup
from ray.util.annotations import PublicAPI


class BackendRegistry:
    """Singleton registry mapping backend names to collective group classes.

    Every ``BackendRegistry()`` call returns the same instance, so all callers
    in a process share one name -> class map. Names are case-insensitive: they
    are upper-cased before being stored or looked up.
    """

    _instance = None
    # Upper-cased backend name -> BaseGroup subclass implementing that backend.
    _map: Dict[str, Type[BaseGroup]]

    def __new__(cls):
        # Create the shared instance (and its empty map) only on first use;
        # later constructions hand back the same object.
        if cls._instance is None:
            cls._instance = super(BackendRegistry, cls).__new__(cls)
            cls._instance._map = {}
        return cls._instance

    def put(self, name: str, group_cls: Type[BaseGroup]) -> None:
        """Register ``group_cls`` under ``name`` (case-insensitive).

        Args:
            name: Backend name; stored upper-cased.
            group_cls: Class implementing the backend; must subclass BaseGroup.

        Raises:
            TypeError: If ``group_cls`` is not a class deriving from BaseGroup.
            ValueError: If a backend with the same upper-cased name is already
                registered.
        """
        # Check for a class first: issubclass() on a non-class (e.g. an
        # instance passed by mistake) raises its own, less helpful TypeError.
        if not (isinstance(group_cls, type) and issubclass(group_cls, BaseGroup)):
            raise TypeError(f"{group_cls} is not a subclass of BaseGroup")
        key = name.upper()
        if key in self._map:
            raise ValueError(f"Backend {key} already registered")
        self._map[key] = group_cls

    def get(self, name: str) -> Type[BaseGroup]:
        """Return the class registered under ``name`` (case-insensitive).

        Raises:
            ValueError: If no backend is registered under that name.
        """
        key = name.upper()
        try:
            return self._map[key]
        except KeyError:
            raise ValueError(f"Backend {key} not registered") from None

    def is_registered(self, name: str) -> bool:
        """Check if a backend is registered (regardless of availability)."""
        return name.upper() in self._map

    def check(self, name: str) -> bool:
        """Check if a backend is both registered and available.

        Returns whatever the registered class's ``check_backend_availability()``
        returns. Returns False if the name is not registered, or if an
        AttributeError is raised while looking up or running that check.
        """
        try:
            group_cls = self.get(name)
            return group_cls.check_backend_availability()
        except (ValueError, AttributeError):
            return False


# Module-level handle to the process-wide BackendRegistry singleton; backends
# registered through register_collective_backend() are stored here.
_global_registry = BackendRegistry()


@PublicAPI(stability="alpha")
def register_collective_backend(name: str, group_cls: Type[BaseGroup]):
    """Register a custom collective backend with Ray.

    Adds ``group_cls`` to the process-wide backend registry under ``name``. The
    name is case-insensitive and is stored upper-cased. The upper-cased name is
    also exposed as an attribute of ``Backend`` (for example
    ``Backend.MY_BACKEND``), unless ``Backend`` already has an attribute with
    that name.

    The backend class must be a subclass of
    :class:`~ray.util.collective.collective_group.base_collective_group.BaseGroup`
    and implement all required collective operations.

    Important:
        Register the backend on the driver AND in every actor before creating
        collective groups. Each process (the driver and each actor) needs to
        know your backend class in order to instantiate it.

    Args:
        name: The name of the backend (e.g., "MY_BACKEND").
        group_cls: The backend class, which must be a subclass of
            :class:`~ray.util.collective.collective_group.base_collective_group.BaseGroup`.

    Raises:
        TypeError: If ``group_cls`` is not a BaseGroup subclass.
        ValueError: If a backend with the same (upper-cased) name is already
            registered in this process.

    Example:
        >>> import ray
        >>> from ray.util.collective import create_collective_group, init_collective_group
        >>> from ray.util.collective.backend_registry import register_collective_backend
        >>> from ray.util.collective.collective_group.base_collective_group import BaseGroup
        >>>
        >>> class MyCustomBackend(BaseGroup):
        ...     def __init__(self, world_size, rank, group_name):
        ...         super().__init__(world_size, rank, group_name)
        ...     @classmethod
        ...     def backend(cls):
        ...         return "MY_BACKEND"
        ...     @classmethod
        ...     def check_backend_availability(cls) -> bool:
        ...         return True
        ...     def allreduce(self, tensor, allreduce_options=None):
        ...         pass
        ...     def broadcast(self, tensor, broadcast_options=None):
        ...         pass
        ...     def barrier(self, barrier_options=None):
        ...         pass
        >>>
        >>> # Register on the driver
        >>> register_collective_backend("MY_BACKEND", MyCustomBackend)
        >>>
        >>> ray.init()
        >>>
        >>> @ray.remote
        ... class Worker:
        ...     def __init__(self, rank):
        ...         self.rank = rank
        ...     def setup(self, world_size):
        ...         # IMPORTANT: Register on each worker too
        ...         register_collective_backend("MY_BACKEND", MyCustomBackend)
        ...         init_collective_group(
        ...             world_size=world_size,
        ...             rank=self.rank,
        ...             backend="MY_BACKEND",
        ...             group_name="default",
        ...         )
        >>>
        >>> actors = [Worker.remote(rank=i) for i in range(2)]
        >>> create_collective_group(
        ...     actors=actors,
        ...     world_size=2,
        ...     ranks=[0, 1],
        ...     backend="MY_BACKEND",
        ...     group_name="default",
        ... )
        >>> ray.get([a.setup.remote(2) for a in actors])
    """
    # Registry validation runs first; if it raises, Backend is left untouched.
    _global_registry.put(name, group_cls)

    # Deferred import of the sibling ``types`` module (presumably to avoid an
    # import cycle at module load time -- confirm).
    from . import types

    backend_attr = name.upper()
    # Never overwrite an attribute that Backend already defines.
    if hasattr(types.Backend, backend_attr):
        return
    setattr(types.Backend, backend_attr, backend_attr)