Source code for ray.experimental.collective.collective

import threading
import uuid
from typing import Dict, List, Optional, Union

import ray
import ray.experimental.internal_kv as internal_kv
from ray.experimental.collective.communicator import CommunicatorHandle
from ray.experimental.collective.util import get_address_and_port
from ray.util.annotations import PublicAPI
from ray.util.collective.collective_group.torch_gloo_collective_group import (
    get_master_address_metadata_key,
)
from ray.util.collective.types import Backend

_remote_communicator_manager: "Optional[RemoteCommunicatorManager]" = None
_remote_communicator_manager_lock = threading.Lock()


class RemoteCommunicatorManager:
    """Singleton class to store the mapping between actors and communicators
    that the actors are a part of.
    """

    def __init__(self):
        # Handles to communicators that we created. Key is a user-provided
        # name or UUID.
        self._remote_communicators: Dict[str, CommunicatorHandle] = {}

    @staticmethod
    def get() -> "RemoteCommunicatorManager":
        global _remote_communicator_manager
        with _remote_communicator_manager_lock:
            if _remote_communicator_manager is None:
                _remote_communicator_manager = RemoteCommunicatorManager()
            return _remote_communicator_manager

    def add_remote_communicator(self, comm_handle: CommunicatorHandle):
        self._remote_communicators[comm_handle.name] = comm_handle

    def remove_remote_communicator(self, name: str):
        return self._remote_communicators.pop(name, None)

    def get_collective_groups(
        self,
        actors: Optional[List[ray.actor.ActorHandle]] = None,
        backend: Optional[str] = None,
    ):
        """
        Get the collective groups that the given actors are a subset of. Filter by
        backend if provided.
        """
        actors = actors or []
        actors = set(actors)

        collectives = []
        # Find all collective groups that the given actors are a subset
        # of, with the matching backend if provided.
        for collective in self._remote_communicators.values():
            if actors.issubset(set(collective.actors)):
                if backend is None or collective.backend == backend:
                    collectives.append(collective)
        return collectives


def _do_init_collective_group(
    self,
    world_size: int,
    rank: int,
    backend: str = Backend.NCCL,
    name: str = "default",
):
    """Helper method that runs as a task on a remote actor to create a
    collective group.
    """
    ray.util.collective.init_collective_group(
        world_size, rank, backend, group_name=name
    )


def _do_destroy_collective_group(self, name):
    """Helper method that runs as a task on a remote actor to destroy a
    collective group.
    """
    ray.util.collective.destroy_collective_group(name)


[docs] @PublicAPI(stability="alpha") def get_collective_groups( actors: List[ray.actor.ActorHandle], backend: Optional[str] = None ) -> List[CommunicatorHandle]: """ Get the collective groups that the given actors are a subset of. Filter by backend if provided. Args: actors: List of actors. Return handles to all collective groups that these actors are a subset of. backend: An optional backend to filter by. See ray.util.collective.types.Backend for valid backends. Returns: A list of communicator handles that the actors are a subset of. """ manager = RemoteCommunicatorManager.get() return manager.get_collective_groups(actors, backend)
[docs] @PublicAPI(stability="alpha") def create_collective_group( actors: List[ray.actor.ActorHandle], backend: str, name: Optional[str] = None, ) -> CommunicatorHandle: """Create a collective group on the given list of actors. If this function returns successfully, then the collective group has been initialized on all actors, using the given order of actors as the ranks. Currently, an actor can only participate in one collective group per backend at a time. To reuse an actor, destroy its collective group and create a new one. Args: actors: The actors to participate in the collective group. backend: The backend to use. See ray.util.collective.types.Backend for valid backends. name: A name to use for the collective group. If None is provided, a random name will be generated. Returns: Handle to the communicator. """ manager = RemoteCommunicatorManager.get() if name is None: name = str(uuid.uuid4()) # Validate the backend. backend = Backend(backend) world_size = len(actors) for actor in actors: if manager.get_collective_groups([actor], backend): raise RuntimeError( f"Actor {actor} already in group for backend {backend}. Actors can currently only participate in at most one group per backend." ) actor_ids = [actor._ray_actor_id for actor in actors] if len(set(actor_ids)) != len(actor_ids): raise ValueError(f"All actors must be unique, got: {actors}") metadata_key = None if backend == Backend.TORCH_GLOO: # Perform extra setup for torch.distributed. # torch.distributed requires a master address and port. Find a suitable # port on one of the actors. master_addr, master_port = ray.get( actors[0].__ray_call__.remote(lambda self: get_address_and_port()) ) # Store the metadata on a named actor that all of the other # actors can access. metadata_key = get_master_address_metadata_key(name) internal_kv._internal_kv_put(metadata_key, f"{master_addr}:{master_port}") try: init_tasks = [ actor.__ray_call__.remote( _do_init_collective_group, world_size, rank, backend, name ) for rank, actor in enumerate(actors) ] ray.get(init_tasks) finally: # Clean up the metadata once collective group is initialized # (or failed to initialize). if metadata_key is not None: internal_kv._internal_kv_del(metadata_key) # Group was successfully created. comm = CommunicatorHandle(actors, name, backend) manager.add_remote_communicator(comm) return comm
[docs] @PublicAPI(stability="alpha") def destroy_collective_group(group_or_name: Union[CommunicatorHandle, str]): """ Destroy a collective group. If this functions returns successfully, then the actors that were in the collective can be reused to create a new collective group. Args: group_or_name: Either a communicator handle or the name of the group to destroy. """ if isinstance(group_or_name, CommunicatorHandle): name = group_or_name.name elif isinstance(group_or_name, str): name = group_or_name else: raise ValueError("Expected CommunicatorHandle or str (group name).") manager = RemoteCommunicatorManager.get() group = manager.remove_remote_communicator(name) if group is not None: destroy_tasks = [ actor.__ray_call__.remote(_do_destroy_collective_group, name) for actor in group.actors ] ray.get(destroy_tasks) else: raise ValueError(f"No group with name {name} found.")
@PublicAPI(stability="alpha") def destroy_all_collective_groups(): """ Destroy all collective groups. This will destroy all collective groups that were previously created by this process. After this function returns, the actors participating in those collective groups can be reused to create a new collective group. """ manager = RemoteCommunicatorManager.get() for collective in manager.get_collective_groups(): destroy_collective_group(collective.name)