Torch Data Prefetching Benchmark for Ray Train

We provide a benchmark example that shows how the automatic pipeline for host-to-device data transfer speeds up training on GPUs. Enable this functionality by passing auto_transfer=True to train.torch.prepare_data_loader().

from torch.utils.data import DataLoader
from ray import train

data_loader = DataLoader(my_dataset, batch_size)
train_loader = train.torch.prepare_data_loader(
    data_loader=data_loader, move_to_device=True, auto_transfer=True
)
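
Under the hood, auto_transfer overlaps the host-to-device copy of the next batch with compute on the current batch by issuing the copy on a separate CUDA stream. The snippet below is a minimal, simplified sketch of that general prefetching pattern, not Ray's actual implementation; the prefetching_loader helper is illustrative, and true overlap also assumes pinned host memory (pin_memory=True in the DataLoader).

import torch

def prefetching_loader(loader, device):
    # Illustrative sketch: yield batches already on `device`, copying the next
    # batch on a side CUDA stream while the caller computes on the current one.
    # Ray's auto_transfer handles this (plus memory-lifetime details) internally.
    copy_stream = torch.cuda.Stream(device)
    ready = None
    for batch in loader:
        # Enqueue the copy of this batch on the side stream.
        with torch.cuda.stream(copy_stream):
            on_device = tuple(t.to(device, non_blocking=True) for t in batch)
        if ready is not None:
            yield ready  # the caller trains on the previous batch meanwhile
        # Make the default stream wait for the copy before the batch is handed out.
        torch.cuda.current_stream(device).wait_stream(copy_stream)
        ready = on_device
    if ready is not None:
        yield ready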

Running the following command reports the runtime of training a small model with and without the auto pipeline. The experiment size can be changed by passing different values for --epochs and --num_hidden_layers, e.g.,

python auto_pipeline_for_host_to_device_data_transfer.py --epochs 2 --num_hidden_layers 2

The table below shows the runtime in seconds (excluding preparation work) under different configurations. The auto_transfer=False column reports the runtime without the auto pipeline, and the auto_transfer=True column reports the runtime with it. These experiments were run on an NVIDIA 2080 Ti. The auto pipeline offers a larger improvement as the model size and the number of epochs grow; the sketch after the table shows how to read off the savings. (The actual runtimes may vary when the experiments are run locally or on different hardware.)

| epochs | num_of_layers | auto_transfer=False | auto_transfer=True |
|--------|---------------|---------------------|--------------------|
| 1      | 1             | 2.69                | 2.52               |
| 1      | 4             | 7.21                | 6.85               |
| 1      | 8             | 13.54               | 13.05              |
| 5      | 1             | 12.88               | 12.14              |
| 5      | 4             | 36.48               | 34.33              |
| 5      | 8             | 69.12               | 66.38              |
| 50     | 1             | 132.88              | 123.12             |
| 50     | 4             | 381.67              | 369.42             |
| 50     | 8             | 736.17              | 693.52             |
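
For a quick read of the numbers, the short sketch below (plain Python with the runtimes copied from the table, illustrative only) prints the absolute and relative time saved by the auto pipeline for each configuration.

# Absolute and relative time saved by auto_transfer=True, using the
# runtimes (in seconds) from the table above.
runtimes = {
    (1, 1): (2.69, 2.52),
    (1, 4): (7.21, 6.85),
    (1, 8): (13.54, 13.05),
    (5, 1): (12.88, 12.14),
    (5, 4): (36.48, 34.33),
    (5, 8): (69.12, 66.38),
    (50, 1): (132.88, 123.12),
    (50, 4): (381.67, 369.42),
    (50, 8): (736.17, 693.52),
}
for (epochs, layers), (off, on) in runtimes.items():
    saved = off - on
    print(f"epochs={epochs:>2}, layers={layers}: saved {saved:.2f} s ({100 * saved / off:.1f}%)")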

# The PyTorch data transfer benchmark script.
import argparse
import warnings

import numpy as np
import torch
import torch.nn as nn

import ray.train as train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


class Net(nn.Module):
    def __init__(self, in_d, hidden):
        # output dim = 1
        super(Net, self).__init__()
        dims = [in_d] + hidden + [1]
        self.layers = nn.ModuleList(
            # Chain linear layers: in_d -> hidden[0] -> ... -> hidden[-1] -> 1.
            [nn.Linear(dims[i - 1], dims[i]) for i in range(1, len(dims))]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class BenchmarkDataset(torch.utils.data.Dataset):
    """Create a naive dataset for the benchmark"""

    def __init__(self, dim, size=1000):
        self.x = torch.from_numpy(np.random.normal(size=(size, dim))).float()
        self.y = torch.from_numpy(np.random.normal(size=(size, 1))).float()
        self.size = size

    def __getitem__(self, index):
        return self.x[index, None], self.y[index, None]

    def __len__(self):
        return self.size


def train_epoch(epoch, dataloader, model, loss_fn, optimizer):
    # With more than one worker, the DistributedSampler needs the epoch
    # number to reshuffle the data each epoch.
    if train.get_context().get_world_size() > 1:
        dataloader.sampler.set_epoch(epoch)

    for X, y in dataloader:
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train_func(config):
    data_size = config.get("data_size", 4096 * 50)
    batch_size = config.get("batch_size", 4096)
    hidden_size = config.get("hidden_size", 1)
    use_auto_transfer = config.get("use_auto_transfer", False)
    lr = config.get("lr", 1e-2)
    epochs = config.get("epochs", 10)

    train_dataset = BenchmarkDataset(4096, size=data_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )

    train_loader = train.torch.prepare_data_loader(
        data_loader=train_loader, move_to_device=True, auto_transfer=use_auto_transfer
    )

    model = Net(in_d=4096, hidden=[4096] * hidden_size)
    model = train.torch.prepare_model(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Time the training loop on the GPU with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    choice = "with" if use_auto_transfer else "without"
    print(f"Starting the torch data prefetch benchmark {choice} auto pipeline...")

    torch.cuda.synchronize()
    start.record()
    for epoch in range(epochs):
        train_epoch(epoch, train_loader, model, loss_fn, optimizer)
    end.record()
    torch.cuda.synchronize()

    print(
        f"Finished the torch data prefetch benchmark {choice} "
        f"auto pipeline: {start.elapsed_time(end)} ms."
    )

    return "Experiment done."


def train_linear(num_workers=1, num_hidden_layers=1, use_auto_transfer=True, epochs=3):
    config = {
        "lr": 1e-2,
        "hidden_size": num_hidden_layers,
        "batch_size": 4096,
        "epochs": epochs,
        "use_auto_transfer": use_auto_transfer,
    }
    trainer = TorchTrainer(
        train_func,
        train_loop_config=config,
        scaling_config=ScalingConfig(use_gpu=True, num_workers=num_workers),
    )
    results = trainer.fit()

    print(results.metrics)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--num_hidden_layers",
        type=int,
        default=1,
        help="Number of epochs to train for.",
    )

    args, _ = parser.parse_known_args()

    import ray

    ray.init(address=args.address)

    if not torch.cuda.is_available():
        warnings.warn("GPU is not available. Skipping the auto pipeline benchmark.")
    else:
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=True,
            epochs=args.epochs,
        )
        # Free cached GPU memory between the two runs.
        torch.cuda.empty_cache()
        train_linear(
            num_workers=1,
            num_hidden_layers=args.num_hidden_layers,
            use_auto_transfer=False,
            epochs=args.epochs,
        )

    ray.shutdown()