RaySGD: Distributed Training Wrappers

RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data parallel training.

The main features are:

  • Ease of use: Scale PyTorch’s native DistributedDataParallel and TensorFlow’s tf.distribute.MirroredStrategy without needing to monitor individual nodes.

  • Composability: RaySGD is built on top of the Ray Actor API, enabling seamless integration with existing Ray applications such as RLlib, Tune, and Ray Serve.

  • Scale up and down: Start on a single CPU. Scale up to multi-node, multi-CPU, or multi-GPU clusters by changing two lines of code (see the sketch after this list).
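
As a minimal sketch of that two-line change, the TorchTrainer call from the Getting Started example below scales from a local CPU run to a multi-GPU cluster by adjusting only the num_workers and use_gpu arguments (the worker count shown here is illustrative):

# Same creator functions and config as in the Getting Started example below;
# only the two scaling arguments change.
trainer = TorchTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=torch.nn.MSELoss,
    num_workers=8,   # scale out across more workers/nodes
    use_gpu=True,    # run each worker on a GPU
    config={"batch_size": 64})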

Note

This API is new and may be revised in future Ray releases. If you encounter any bugs, please file an issue on GitHub.

Getting Started

You can start a TorchTrainer with the following:

import ray
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch.examples.train_example import LinearDataset

import torch
from torch.utils.data import DataLoader


def model_creator(config):
    """Returns a torch.nn.Module to be trained."""
    return torch.nn.Linear(1, 1)


def optimizer_creator(model, config):
    """Returns optimizer."""
    return torch.optim.SGD(model.parameters(), lr=1e-2)


def data_creator(config):
    """Returns training and validation data loaders."""
    train_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
    val_loader = DataLoader(LinearDataset(2, 5), config["batch_size"])
    return train_loader, val_loader

ray.init()

trainer1 = TorchTrainer(
    model_creator=model_creator,
    data_creator=data_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=torch.nn.MSELoss,
    num_workers=2,
    use_gpu=False,
    config={"batch_size": 64})

stats = trainer1.train()
print(stats)
trainer1.shutdown()
print("success!")

Tip

Get in touch with us if you’re using or considering using RaySGD!