Overlap communication and computation

Compiled Graph currently provides experimental support for overlapping GPU communication and computation. When you enable this feature, Compiled Graph automatically runs GPU communication concurrently with computation operations where possible. This hides the communication overhead and reduces end-to-end latency.
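As a rough illustration of why overlapping helps, the following sketch models a receiver that must receive and then process several tensors. It is a simplified cost model, not Ray's scheduler. The function names are hypothetical, and the model assumes that communication and computation use independent resources with unlimited buffering, so the next receive can start while the current computation is still running.

```python
from typing import Sequence


def sequential_time(comm: Sequence[float], comp: Sequence[float]) -> float:
    """Total time when each computation waits for its own receive and
    no receive starts until the previous computation has finished."""
    return sum(c + m for c, m in zip(comm, comp))


def overlapped_time(comm: Sequence[float], comp: Sequence[float]) -> float:
    """Total time when receives and computations run on separate resources.

    Receive i+1 starts as soon as receive i finishes (unlimited buffering is
    assumed). Computation i starts once receive i and computation i-1 are done.
    """
    comm_done = 0.0  # time at which the latest receive finishes
    comp_done = 0.0  # time at which the latest computation finishes
    for c, m in zip(comm, comp):
        comm_done += c
        comp_done = max(comp_done, comm_done) + m
    return comp_done


if __name__ == "__main__":
    comm = [2.0, 2.0, 2.0]  # hypothetical per-tensor receive times
    comp = [3.0, 3.0, 3.0]  # hypothetical per-tensor matmul times
    print(sequential_time(comm, comp))  # 15.0
    print(overlapped_time(comm, comp))  # 11.0
```

With uniform receive time c and compute time m over N tensors, the overlapped total works out to c + (N - 1) * max(c, m) + m instead of N * (c + m): only the first receive and the last computation stay fully exposed.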

To enable this feature, specify _overlap_gpu_communication=True when calling dag.experimental_compile().

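The following example has GPU communication and computation operations that benefit from overlapping. Three sender actors each send a large tensor over NCCL to a single receiver actor, which runs a matrix multiplication on every tensor it receives. Because each actor requests one GPU, the example needs at least four GPUs.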

import ray
import time
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.air._internal import torch_utils

@ray.remote(num_cpus=0, num_gpus=1)
class TorchTensorWorker:
    def __init__(self):
        self.device = torch_utils.get_devices()[0]

    def send(self, shape, dtype, value: int, send_tensor=True):
        if not send_tensor:
            return 1
        return torch.ones(shape, dtype=dtype, device=self.device) * value

    def recv_and_matmul(self, two_d_tensor):
        """
        Receive the tensor and do some expensive computation (matmul).

        Args:
            two_d_tensor: a square 2D tensor (both dimensions have the same size)
        """
        # Check that the tensor is square and was loaded onto this actor's GPU.
        assert two_d_tensor.dim() == 2
        assert two_d_tensor.size(0) == two_d_tensor.size(1)
        assert two_d_tensor.device == self.device
        # The result is discarded; the matmul only serves as expensive GPU work.
        torch.matmul(two_d_tensor, two_d_tensor)
        return (two_d_tensor[0][0].item(), two_d_tensor.shape, two_d_tensor.dtype)

def test(overlap_gpu_communication):
    num_senders = 3
    senders = [TorchTensorWorker.remote() for _ in range(num_senders)]
    receiver = TorchTensorWorker.remote()

    shape = (10000, 10000)
    dtype = torch.float16

    with InputNode() as inp:
        branches = [sender.send.bind(shape, dtype, inp) for sender in senders]
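        # Send each tensor over NCCL. The shape and dtype are the same on every
        # execution (_static_shape), and the method returns the tensor itself
        # rather than wrapping it in other data (_direct_return).
        branches = [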
            branch.with_type_hint(
                TorchTensorType(
                    transport="nccl", _static_shape=True, _direct_return=True
                )
            )
            for branch in branches
        ]
        branches = [receiver.recv_and_matmul.bind(branch) for branch in branches]
        dag = MultiOutputNode(branches)

    # Toggle GPU communication/computation overlap for this compiled graph.
    compiled_dag = dag.experimental_compile(
        _overlap_gpu_communication=overlap_gpu_communication
    )

    start = time.monotonic()
    for i in range(5):
        ref = compiled_dag.execute(i)
        result = ray.get(ref)
        assert result == [(i, shape, dtype)] * num_senders
    duration = time.monotonic() - start
    print(f"{overlap_gpu_communication=}, {duration=}")
    # Release the compiled graph's resources before the next run.
    compiled_dag.teardown()

if __name__ == "__main__":
    for overlap_gpu_communication in [False, True]:
        test(overlap_gpu_communication)

The output of the preceding code includes the following two lines:

overlap_gpu_communication=False, duration=1.0670117866247892
overlap_gpu_communication=True, duration=0.9211348341777921
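To put these timings in context, the following back-of-the-envelope calculation uses the shape, dtype, sender count, and iteration count from the example. It shows that both sides of the workload are substantial: each execute() call moves hundreds of megabytes between GPUs and runs three matrix multiplications of roughly 2 TFLOP each. The sketch ignores protocol overhead and uses the standard 2 * n^3 FLOP estimate for a dense matrix product, so treat the results as order-of-magnitude figures, not measurements.

```python
N = 10_000           # side length from shape = (10000, 10000)
BYTES_PER_FP16 = 2   # torch.float16 stores 2 bytes per element
NUM_SENDERS = 3      # num_senders in the example
NUM_ITERATIONS = 5   # timed loop: for i in range(5)

bytes_per_tensor = N * N * BYTES_PER_FP16
bytes_per_execute = bytes_per_tensor * NUM_SENDERS
total_bytes = bytes_per_execute * NUM_ITERATIONS

# A dense n x n matmul performs n^3 multiply-adds, about 2 * n^3 FLOPs.
flops_per_matmul = 2 * N**3
total_flops = flops_per_matmul * NUM_SENDERS * NUM_ITERATIONS

print(f"{bytes_per_tensor / 1e6:.0f} MB per tensor")          # 200 MB per tensor
print(f"{total_bytes / 1e9:.1f} GB sent in the timed loop")  # 3.0 GB sent in the timed loop
print(f"{total_flops / 1e12:.0f} TFLOP of matmul work")      # 30 TFLOP of matmul work
```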

Actual numbers vary by hardware, but in this run, enabling _overlap_gpu_communication reduced the total time for five executions by about 14%, from about 1.07 s to about 0.92 s.
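The percentage follows directly from the two printed durations. This small helper, a hypothetical name rather than part of Ray, computes the relative reduction in wall-clock time and the corresponding speedup:

```python
def latency_reduction(baseline_s: float, optimized_s: float) -> float:
    """Return the fraction of the baseline time saved by the optimized run."""
    return (baseline_s - optimized_s) / baseline_s


baseline = 1.0670117866247892   # overlap_gpu_communication=False
optimized = 0.9211348341777921  # overlap_gpu_communication=True

reduction = latency_reduction(baseline, optimized)
speedup = baseline / optimized
print(f"latency reduction: {reduction:.1%}")  # latency reduction: 13.7%
print(f"speedup: {speedup:.2f}x")             # speedup: 1.16x
```

A 13.7% reduction in time corresponds to roughly a 1.16x speedup; the two figures describe the same measurement from different directions.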

To verify that Compiled Graph overlaps the communication and computation operations, visualize the execution schedule by setting the environment variable RAY_CGRAPH_VISUALIZE_SCHEDULE=1.
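One way to make sure Ray sees the variable is to set it before the driver script starts, for example by prefixing the command in a shell: RAY_CGRAPH_VISUALIZE_SCHEDULE=1 python overlap_example.py. The following sketch does the same from Python by launching the script in a child process. The helper names and the script name overlap_example.py are hypothetical; only the environment variable comes from this page.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional


def env_with_schedule_visualization(
    base_env: Optional[Mapping[str, str]] = None,
) -> dict:
    """Return a copy of base_env (default: os.environ) with visualization enabled."""
    env = dict(os.environ if base_env is None else base_env)
    env["RAY_CGRAPH_VISUALIZE_SCHEDULE"] = "1"
    return env


def run_with_schedule_visualization(script_path: str) -> subprocess.CompletedProcess:
    """Run a Python script in a child process that sees RAY_CGRAPH_VISUALIZE_SCHEDULE=1."""
    return subprocess.run(
        [sys.executable, script_path],
        env=env_with_schedule_visualization(),
        check=True,  # raise CalledProcessError if the script fails
    )
```

Calling run_with_schedule_visualization("overlap_example.py") then runs the example with schedule visualization enabled, without changing the environment of the calling process.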

Figure: Execution schedule with GPU communication overlap enabled

In the schedule visualization, red nodes mark operations that run in a different order in the optimized schedule than in the original one because _overlap_gpu_communication is enabled.
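The idea behind the red highlighting can be sketched as a comparison of two orderings of the same operations on one actor. The operation names and the optimized order below are hypothetical, chosen to mimic a receiver that starts the next receive before finishing the current matmul. Ray's real schedule uses its own, finer-grained operations, and this is not its visualization code.

```python
from typing import List, Sequence


def reordered_ops(original: Sequence[str], optimized: Sequence[str]) -> List[str]:
    """Return operations whose position in `optimized` differs from `original`.

    Both schedules must contain the same set of unique operation names.
    """
    if sorted(original) != sorted(optimized) or len(set(original)) != len(original):
        raise ValueError("schedules must be permutations of the same unique operations")
    original_pos = {op: i for i, op in enumerate(original)}
    return [op for i, op in enumerate(optimized) if original_pos[op] != i]


# Hypothetical per-actor schedules for the receiver.
original = ["recv_0", "matmul_0", "recv_1", "matmul_1", "recv_2", "matmul_2"]
optimized = ["recv_0", "recv_1", "matmul_0", "recv_2", "matmul_1", "matmul_2"]

print(reordered_ops(original, optimized))
# ['recv_1', 'matmul_0', 'recv_2', 'matmul_1']
```

In this toy schedule, recv_0 and matmul_2 keep their positions, while the four operations in between would be the ones highlighted.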