RaySGD: Distributed Training Wrappers

Warning

This is an older version of Ray SGD. A newer, more lightweight version of Ray SGD, named Ray Train, is in alpha as of Ray 1.7. See the documentation here. To migrate from v1 to v2, you can follow the migration guide.

RaySGD is a lightweight library for distributed deep learning, providing thin wrappers around PyTorch and TensorFlow native modules for data-parallel training.

The main features are:

  • Ease of use: Scale PyTorch’s native DistributedDataParallel and TensorFlow’s tf.distribute.MirroredStrategy without needing to monitor individual nodes.

  • Composability: RaySGD is built on top of the Ray Actor API, enabling seamless integration with existing Ray applications such as RLlib, Tune, and Ray Serve.

  • Scale up and down: Start on a single CPU. Scale up to multi-node, multi-CPU, or multi-GPU clusters by changing two lines of code, as sketched below.
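
For example, the only change needed to go from a local CPU run to a multi-GPU cluster is the num_workers and use_gpu arguments of TorchTrainer. This is a minimal sketch, assuming the CustomTrainingOperator defined in the Getting Started section below; the worker counts are arbitrary:

from ray.util.sgd import TorchTrainer

# Single worker on CPU (e.g. for local development).
trainer = TorchTrainer(
    training_operator_cls=CustomTrainingOperator,
    num_workers=1,
    use_gpu=False,
    config={"batch_size": 64})

# Scale out to 8 GPU workers: only num_workers and use_gpu change.
trainer = TorchTrainer(
    training_operator_cls=CustomTrainingOperator,
    num_workers=8,
    use_gpu=True,
    config={"batch_size": 64})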

Getting Started

You can start a TorchTrainer with the following example:

import ray
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset

import torch
from torch.utils.data import DataLoader

class CustomTrainingOperator(TrainingOperator):
    def setup(self, config):
        # Load data.
        train_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config["batch_size"])
        val_loader = DataLoader(
            LinearDataset(2, 5), batch_size=config["batch_size"])

        # Create model.
        model = torch.nn.Linear(1, 1)

        # Create optimizer.
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

        # Create loss.
        loss = torch.nn.MSELoss()

        # Register model, optimizer, and loss.
        self.model, self.optimizer, self.criterion = self.register(
            models=model,
            optimizers=optimizer,
            criterion=loss)

        # Register data loaders.
        self.register_data(train_loader=train_loader, validation_loader=val_loader)


ray.init()

trainer1 = TorchTrainer(
    training_operator_cls=CustomTrainingOperator,
    num_workers=2,
    use_gpu=False,
    config={"batch_size": 64})

stats = trainer1.train()
print(stats)
trainer1.shutdown()
print("success!")