Implementing a custom tensor transport (Advanced)#

Ray Direct Transport (RDT) allows you to register custom tensor transports at runtime. This page explains how to implement a custom tensor transport by implementing the TensorTransportManager abstract interface.

Overview#

To create a custom tensor transport:

  1. Implement the abstract interface ray.experimental.TensorTransportManager.

  2. Define custom metadata classes by extending TensorTransportMetadata and CommunicatorMetadata.

  3. Register your transport using ray.experimental.register_tensor_transport.

When Ray needs to transfer a tensor between actors using your transport, it calls specific methods on your TensorTransportManager implementation at different stages of the transfer lifecycle.

Implementing TensorTransportManager#

The TensorTransportManager abstract class defines the interface for custom tensor transports. You must implement all abstract methods.

The following diagram shows when each method is called during a tensor transfer:

Source Actor                    Owner Process                 Destination Actor
============                    =============                 =================
     |                               |                               |
1. Task returns tensor               |                               |
   ``extract_tensor_transport_metadata``                             |
     |                               |                               |
     | ---- transport_metadata ----> |                               |
     |                               |                               |
     |                     2. Prepare communicator                   |
     |                        ``get_communicator_metadata``          |
     |                               |                               |
     | <---- comm metadata --------- | ---- comm metadata -------->  |
     |                               |                               |
3. ``send_multiple_tensors``         |          3. ``recv_multiple_tensors``
                                     |                               |
     | ------------ tensors ---------------------------------------> |
     |                               |                               |
     |                         (transfer complete)                   |
     |                               |                               |
     |                      4. Ref goes out of scope                 |
     | <---------------------------- |                               |
5. Clean up resources                |                               |
   ``garbage_collect``               |                               |

Note that Ray doesn't call send_multiple_tensors for one-sided transports. The following diagram shows where Ray calls each method in the ray.put / ray.get flow that one-sided transports support.

Source Actor                                                  Destination Actor
============                                                  =================
     |                                                               |
1. User ``ray.put``'s tensor                                         |
   ``extract_tensor_transport_metadata``                             |
     |                                                               |
     |                                                               |
2. User passes ref to another actor                                  |
     | ---- transport_metadata ---------------------------------->   |
     |                                                               |
     |                                                               |
     |                                          3. User ``ray.get``'s on object ref
                                                   ``get_communicator_metadata``
     |                                              ``recv_multiple_tensors``
     | ------------ tensors ---------------------------------------> |
     |                                                               |
     |                         (transfer complete)                   |
     |                                                               |
4. Clean up resources                                                |
   ``garbage_collect``                                               |
(when ref goes out of scope)                                         |

The API reference page for TensorTransportManager has more details on what each method does and how to implement it. See implementations of Ray’s default transports (NCCL, NIXL, etc.) in the python/ray/experimental/rdt/ directory. The following is a walk-through of implementing and using a custom tensor transport.

Example: Shared memory tensor transport#

The following walks through a complete custom tensor transport that transfers numpy arrays through shared memory.

Note that because shared memory is one-sided (the receiver directly reads the memory block the sender wrote to), is_one_sided returns True and Ray never calls send_multiple_tensors.
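One-sidedness is easy to see with the standard library alone: a second handle can attach to a named shared memory block and read it without the writer ever performing a send. The following sketch (the block name rdt_demo_block is made up for illustration) shows the mechanism this transport relies on:

```python
import multiprocessing.shared_memory as shm

# "Sender" side: create a named block and write bytes into it.
payload = b"tensor bytes"
writer = shm.SharedMemory(name="rdt_demo_block", create=True, size=len(payload))
writer.buf[: len(payload)] = payload

# "Receiver" side: attach to the same block by name alone. The writer
# never sends anything, which is what makes the transport one-sided.
reader = shm.SharedMemory(name="rdt_demo_block")
received = bytes(reader.buf[: len(payload)])

# Release both handles and free the block.
reader.close()
writer.close()
writer.unlink()
```

All the receiver needs is the block's name, which is exactly why the transport metadata below carries it.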

Define metadata classes#

Your transport uses two metadata classes that flow through different stages of the transfer:

  • TensorTransportMetadata is created on the source actor during extract_tensor_transport_metadata. It carries per-tensor information (shapes, dtypes, devices) plus any transport-specific identifiers (e.g., shared memory block names, RDMA keys) that the receiver needs to locate and read the data.

  • CommunicatorMetadata is created on the owner/driver process during get_communicator_metadata. It carries any coordination information both actors need, such as ranks in a collective group. For one-sided transports (where the receiver can directly read the sender’s memory), an empty metadata object is typically sufficient.

Start by extending these classes to carry any transport-specific state. ShmTransportMetadata stores the shared memory block name and size so the receiver can locate and read the data. This transport doesn’t need any communicator metadata, so ShmCommunicatorMetadata is empty.

import multiprocessing.shared_memory as shm
import pickle
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy

import ray
from ray.experimental import (
    CommunicatorMetadata,
    TensorTransportManager,
    TensorTransportMetadata,
    register_tensor_transport,
)


@dataclass
class ShmTransportMetadata(TensorTransportMetadata):
    """Custom metadata that stores the shared memory name and size."""

    shm_name: Optional[str] = None
    shm_size: Optional[int] = None


@dataclass
class ShmCommunicatorMetadata(CommunicatorMetadata):
    """No extra communicator metadata needed for shared memory."""

    pass


Extract tensor transport metadata#

Ray calls extract_tensor_transport_metadata on the source actor right after the task produces its result tensors. Record shapes and dtypes, then perform any transport-specific registration. Here, the implementation serializes the tensors into a shared memory block and records the block name and size in the metadata so the receiver can find it.

    def extract_tensor_transport_metadata(
        self,
        obj_id: str,
        rdt_object: List[numpy.ndarray],
    ) -> TensorTransportMetadata:
        # Record shapes and dtypes.
        tensor_meta = []
        if rdt_object:
            for tensor in rdt_object:
                tensor_meta.append((tensor.shape, tensor.dtype))

        # Serialize the tensors and store them in shared memory.
        serialized_rdt_object = pickle.dumps(rdt_object)
        size = len(serialized_rdt_object)
        name = obj_id[:20]  # Truncate: shm segment names have platform length limits.
        shm_obj = shm.SharedMemory(name=name, create=True, size=size)
        shm_obj.buf[:size] = serialized_rdt_object
        self.shared_memory_objects[obj_id] = shm_obj

        return ShmTransportMetadata(
            tensor_meta=tensor_meta, tensor_device="cpu", shm_name=name, shm_size=size
        )

Get communicator metadata#

Ray calls get_communicator_metadata on the owner/driver process before orchestrating the transfer. Return any information both actors need to coordinate, such as ranks in a collective group. For one-sided transports such as shared memory, an empty metadata object is fine.

    def get_communicator_metadata(
        self,
        src_actor: "ray.actor.ActorHandle",
        dst_actor: "ray.actor.ActorHandle",
        backend: Optional[str] = None,
    ) -> CommunicatorMetadata:
        return ShmCommunicatorMetadata()

Transport properties#

Define your TensorTransportManager subclass and implement the property methods. tensor_transport_backend returns the name that users pass to @ray.method(tensor_transport=...). is_one_sided and can_abort_transport tell Ray how to orchestrate transfers and handle errors. actor_has_tensor_transport lets Ray check whether a given actor can use this transport.

class SharedMemoryTransport(TensorTransportManager):
    """A one-sided tensor transport that transfers numpy arrays through shared memory."""

    def __init__(self):
        self.shared_memory_objects: Dict[str, shm.SharedMemory] = {}

    def tensor_transport_backend(self) -> str:
        return "shared_memory"

    @staticmethod
    def is_one_sided() -> bool:
        return True

    @staticmethod
    def can_abort_transport() -> bool:
        return False

    def actor_has_tensor_transport(self, actor: "ray.actor.ActorHandle") -> bool:
        return True

Send and receive#

recv_multiple_tensors runs on the destination actor. For this shared memory transport, it opens the shared memory block by name and deserializes the tensors.

send_multiple_tensors runs on the source actor for two-sided transports. Since shared memory is one-sided, Ray never calls this method, so it raises NotImplementedError as a safety guard.

    def recv_multiple_tensors(
        self,
        obj_id: str,
        tensor_transport_metadata: TensorTransportMetadata,
        communicator_metadata: CommunicatorMetadata,
        target_buffers: Optional[List[Any]] = None,
    ):
        # Open the shared memory block and deserialize.
        shm_name = tensor_transport_metadata.shm_name
        size = tensor_transport_metadata.shm_size
        shm_block = shm.SharedMemory(name=shm_name)
        recv_tensors = pickle.loads(shm_block.buf[:size])
        shm_block.close()
        return recv_tensors
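Taken together, the extract and recv methods above boil down to a pickle round trip through a named shared memory block. The following self-contained sketch traces the same three stages outside Ray (the name demo_obj_id stands in for obj_id, and a list of lists stands in for the list of numpy arrays):

```python
import multiprocessing.shared_memory as shm
import pickle

# Stage 1 (extract_tensor_transport_metadata): serialize and publish.
rdt_object = [[1, 2, 3], [4, 5]]  # stand-in for a list of numpy arrays
data = pickle.dumps(rdt_object)
block = shm.SharedMemory(name="demo_obj_id", create=True, size=len(data))
block.buf[: len(data)] = data

# Stage 2 (recv_multiple_tensors): attach by name, read, deserialize.
recv_block = shm.SharedMemory(name="demo_obj_id")
recv_tensors = pickle.loads(recv_block.buf[: len(data)])
recv_block.close()

# Stage 3 (garbage_collect): the "source" side frees the block.
block.close()
block.unlink()
```

Note that the receiver opens its own handle and closes it after deserializing, while only the source side calls unlink; this mirrors the split between recv_multiple_tensors and garbage_collect in the transport.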

    def send_multiple_tensors(
        self,
        tensors: List[numpy.ndarray],
        tensor_transport_metadata: TensorTransportMetadata,
        communicator_metadata: CommunicatorMetadata,
    ):
        raise NotImplementedError("One-sided transport doesn't use send.")

Cleanup#

garbage_collect runs on the source actor when Ray’s reference counting determines the object is out of scope. Release any transport resources here, in this case closing and unlinking the shared memory block.

abort_transport runs on both actors when a system error occurs during transfer, if can_abort_transport returns True. Since this transport returns False for can_abort_transport, Ray kills the involved actors instead, so abort_transport is a no-op.

    def garbage_collect(
        self,
        obj_id: str,
        tensor_transport_meta: TensorTransportMetadata,
        tensors: List[numpy.ndarray],
    ):
        self.shared_memory_objects[obj_id].close()
        self.shared_memory_objects[obj_id].unlink()
        del self.shared_memory_objects[obj_id]

    def abort_transport(
        self,
        obj_id: str,
        communicator_metadata: CommunicatorMetadata,
    ):
        pass

Registering your transport#

After implementing your transport, the driver process must register it with ray.experimental.register_tensor_transport before creating any actors that use it:

register_tensor_transport(
    "shared_memory",        # Transport name
    ["cpu"],                # Supported device types
    SharedMemoryTransport,  # TensorTransportManager class
    numpy.ndarray,          # Data type for this transport
)


@ray.remote
class MyActor:
    @ray.method(tensor_transport="shared_memory")
    def echo(self, data):
        return data

    def sum(self, data):
        return data.sum().item()


actors = [MyActor.remote() for _ in range(2)]
ref = actors[0].echo.remote(numpy.array([1, 2, 3]))
result = actors[1].sum.remote(ref)
print(ray.get(result))
# 6

Limitations#

Custom tensor transports have the following limitations:

  • Actor restarts aren’t supported. Your actor doesn’t have access to the custom transport after a restart.

  • Register transports before actor creation. If you register a transport after creating an actor, that actor can’t use the new transport.

  • Out-of-order actors: If you have an out-of-order actor (such as an async actor) and the process where you submit the actor task differs from the one where you created the actor, Ray can’t guarantee it has registered your custom transport on the actor at task execution time.

  • Actor creation and task submission from different processes: If the process where you submit an actor task differs from the one where you created the actor, Ray can’t guarantee it has registered your custom transport on the actor at task execution time.

For general RDT limitations, see limitations.

Feel free to reach out through GitHub issues or the Ray Slack with any questions.