Get Started with Distributed Training using PyTorch

This tutorial walks through the process of converting an existing PyTorch script to use Ray Train.

Learn how to:

  1. Configure a model to run in a distributed fashion on the correct CPU or GPU device.

  2. Configure a dataloader to shard data across the workers and place data on the correct CPU or GPU device.

  3. Configure a training function to report metrics and save checkpoints.

  4. Configure scaling and CPU or GPU resource requirements for a training job.

  5. Launch a distributed training job with a TorchTrainer class.

Quickstart

For reference, the final code will look something like the following:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_func():
    # Your PyTorch training code here.
    ...

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()

  1. train_func is the Python code that executes on each distributed training worker.

  2. ScalingConfig defines the number of distributed training workers and whether to use GPUs.

  3. TorchTrainer launches the distributed training job.

Compare a PyTorch training script with and without Ray Train. The plain PyTorch version comes first, followed by the same script converted to use Ray Train.

import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

# Model, Loss, Optimizer
model = resnet18(num_classes=10)
model.conv1 = torch.nn.Conv2d(
    1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
model.to("cuda")
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
for epoch in range(10):
    for images, labels in train_loader:
        images, labels = images.to("cuda"), labels.to("cuda")
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    metrics = {"loss": loss.item(), "epoch": epoch}
    checkpoint_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(checkpoint_dir, "model.pt")
    torch.save(model.state_dict(), checkpoint_path)
    print(metrics)

The same script converted to use Ray Train:

import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func():
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    # model.to("cuda")  # This is done by `prepare_model`
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

Set up a training function

First, update your training code to support distributed training. Begin by wrapping your code in a training function:

def train_func():
    # Your model training code here.
    ...

Each distributed training worker executes this function.

You can also specify the input argument for train_func as a dictionary via the Trainer’s train_loop_config. For example:

def train_func(config):
    lr = config["lr"]
    num_epochs = config["num_epochs"]

config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
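
One way to keep train_func tolerant of a partial train_loop_config is to merge the dictionary it receives with a set of defaults. The sketch below is only an illustration: DEFAULT_CONFIG and resolve_config are hypothetical names, not part of Ray Train, and train_func returns its values purely so the behavior can be checked without launching a trainer.

```python
# Hypothetical defaults for illustration; not part of Ray Train.
DEFAULT_CONFIG = {"lr": 1e-3, "num_epochs": 10, "batch_size": 128}


def resolve_config(config=None):
    """Return the defaults, overridden by any keys present in `config`."""
    resolved = dict(DEFAULT_CONFIG)
    resolved.update(config or {})
    # Reject misspelled keys early instead of silently ignoring them.
    unknown = set(resolved) - set(DEFAULT_CONFIG)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return resolved


def train_func(config):
    config = resolve_config(config)
    lr = config["lr"]
    num_epochs = config["num_epochs"]
    # Your model training code here. The return value exists only for this demo.
    return lr, num_epochs
```

Because the dictionary holds only small scalar hyperparameters, it stays cheap to send to every worker.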

Warning

Avoid passing large data objects through train_loop_config; doing so adds serialization and deserialization overhead. Instead, initialize large objects (e.g. datasets, models) directly in train_func.

 def load_dataset():
     # Return a large in-memory dataset
     ...

 def load_model():
     # Return a large in-memory model instance
     ...

-config = {"data": load_dataset(), "model": load_model()}

-def train_func(config):
-    data = config["data"]
-    model = config["model"]
+def train_func():
+    data = load_dataset()
+    model = load_model()
     ...

-trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)
+trainer = ray.train.torch.TorchTrainer(train_func, ...)

Set up a model

Use the ray.train.torch.prepare_model() utility function to:

  1. Move your model to the correct device.

  2. Wrap it in DistributedDataParallel.

-from torch.nn.parallel import DistributedDataParallel
+import ray.train.torch

 def train_func():

     ...

     # Create model.
     model = ...

     # Set up distributed training and device placement.
-    device_id = ... # Your logic to get the right device.
-    model = model.to(device_id or "cpu")
-    model = DistributedDataParallel(model, device_ids=[device_id])
+    model = ray.train.torch.prepare_model(model)

     ...
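
Because the prepared model may be a DistributedDataParallel wrapper, which exposes the original model as its module attribute, code that needs the underlying model (for example, to save its weights) can unwrap it first. The following is a minimal sketch: unwrap_model is a hypothetical helper, and PlainModel and WrappedModel are stand-ins used only so the example runs without PyTorch.

```python
def unwrap_model(model):
    """Return `model.module` if the model is wrapped, else the model itself.

    Assumption: the unwrapped model has no attribute of its own named
    `module`; otherwise this check would misfire.
    """
    return getattr(model, "module", model)


class PlainModel:
    """Stand-in for an ordinary model."""

    def state_dict(self):
        return {"fc.weight": 1.0}


class WrappedModel:
    """Stand-in for a DistributedDataParallel-style wrapper."""

    def __init__(self, module):
        self.module = module


plain = PlainModel()
wrapped = WrappedModel(plain)

# Both calls yield the same underlying model, so the saving code does not
# need to know whether wrapping happened.
state_from_wrapped = unwrap_model(wrapped).state_dict()
state_from_plain = unwrap_model(plain).state_dict()
```

This keeps the same training function usable with a single worker, where no wrapper is applied, and with several workers.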

Set up a dataset

Use the ray.train.torch.prepare_data_loader() utility function, which:

  1. Adds a DistributedSampler to your DataLoader.

  2. Moves the batches to the right device.

Note that this step isn’t necessary if you’re passing in Ray Data to your Trainer. See Data Loading and Preprocessing.

 from torch.utils.data import DataLoader
+import ray.train.torch

 def train_func():

     ...

     dataset = ...

     data_loader = DataLoader(dataset, batch_size=worker_batch_size, shuffle=True)
+    data_loader = ray.train.torch.prepare_data_loader(data_loader)

     for epoch in range(10):
+        if ray.train.get_context().get_world_size() > 1:
+            data_loader.sampler.set_epoch(epoch)

         for X, y in data_loader:
-            X = X.to(device)
-            y = y.to(device)

     ...
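
To see why the set_epoch() call matters when shuffling, the following pure-Python sketch imitates how a distributed sampler could split shuffled indices across workers. It is not PyTorch's implementation: shard_indices is a hypothetical helper, it uses Python's random module instead of PyTorch's generator (so real index orders will differ), and it assumes the tail is padded with repeated samples rather than dropped.

```python
import random


def shard_indices(dataset_size, world_size, rank, epoch, seed=0):
    """Return the sample indices one worker visits in one epoch."""
    indices = list(range(dataset_size))
    # All workers seed with (seed + epoch), so they agree on one permutation
    # per epoch. If the epoch is never updated, the order never changes.
    random.Random(seed + epoch).shuffle(indices)
    # Pad with repeated samples so every worker gets the same count.
    padding = (-len(indices)) % world_size
    if padding:
        repeats = -(-padding // len(indices))  # Ceiling division.
        indices += (indices * repeats)[:padding]
    # Each worker takes every world_size-th index, starting at its rank.
    return indices[rank::world_size]


# Four workers sharing a 10-sample dataset: 12 slots after padding, 3 each.
shards = [shard_indices(10, world_size=4, rank=r, epoch=0) for r in range(4)]
```

Together the shards cover every sample, and changing the epoch changes the permutation that all workers agree on.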

Tip

Keep in mind that the batch_size passed to DataLoader is the batch size for each worker. The global batch size can be calculated from the worker batch size (and vice versa) with the following equation:

global_batch_size = worker_batch_size * ray.train.get_context().get_world_size()
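
Going the other way, a small helper can derive the per-worker batch size from a target global batch size. This is a sketch under one assumption: the global batch size should divide evenly across workers. The helper name worker_batch_size_for is hypothetical.

```python
def worker_batch_size_for(global_batch_size, world_size):
    """Split a global batch size evenly across `world_size` workers."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    if global_batch_size % world_size != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible "
            f"by world_size={world_size}"
        )
    return global_batch_size // world_size


# A global batch of 256 samples across 4 workers.
per_worker = worker_batch_size_for(256, 4)
```

Inside train_func, the world_size argument would come from ray.train.get_context().get_world_size().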

Note

If you already manually set up your DataLoader with a DistributedSampler, prepare_data_loader() will not add another one, and will respect the configuration of the existing sampler.

Note

DistributedSampler does not work with a DataLoader that wraps IterableDataset. If you want to work with a dataset iterator, consider using Ray Data instead of PyTorch DataLoader, since it provides performant streaming data ingestion for large-scale datasets.

See Data Loading and Preprocessing for more details.

Report checkpoints and metrics

To monitor progress, you can report intermediate metrics and checkpoints using the ray.train.report() utility function.

+import os
+import tempfile

+import ray.train

 def train_func():

     ...

     with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
         torch.save(
             model.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
         )

+        metrics = {"loss": loss.item()}  # Training/validation metrics.

         # Build a Ray Train checkpoint from a directory
+        checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

         # Ray Train will automatically save the checkpoint to persistent storage,
         # so the local `temp_checkpoint_dir` can be safely cleaned up after.
+        ray.train.report(metrics=metrics, checkpoint=checkpoint)

     ...

For more details, see Monitoring and Logging Metrics and Saving and Loading Checkpoints.
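
If a checkpoint is written by calling state_dict() on a model wrapped in DistributedDataParallel, its keys carry a "module." prefix, which an unwrapped model does not expect when loading. The sketch below shows one way to normalize the keys at load time. The helper strip_prefix is hypothetical, and placeholder strings stand in for tensors so the example runs without PyTorch.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with `prefix` removed from matching keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they would appear when saved from a wrapped model.
saved = {"module.conv1.weight": "w", "module.fc.bias": "b"}
loadable = strip_prefix(saved)
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to checkpoints from either a wrapped or an unwrapped model.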

Configure scale and GPUs

Outside of your training function, create a ScalingConfig object to configure:

  1. num_workers - The number of distributed training worker processes.

  2. use_gpu - Whether each worker should use a GPU (or CPU).

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

For more details, see Configuring Scale and GPUs.

Configure persistent storage

Create a RunConfig object to specify the path where results (including checkpoints and artifacts) will be saved.

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")
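
As the comments above indicate, the run's output location is the storage path followed by the run name. The following sketch reproduces only that composition for illustration. It is not Ray's code, and run_output_location is a hypothetical helper.

```python
import posixpath


def run_output_location(storage_path, name):
    """Join a storage path or URI with a run name."""
    if "://" in storage_path:
        # Treat as a URI: join with a single forward slash.
        return storage_path.rstrip("/") + "/" + name
    return posixpath.join(storage_path, name)


locations = [
    run_output_location("/some/local/path", "unique_run_name"),
    run_output_location("s3://bucket", "unique_run_name"),
    run_output_location("/mnt/nfs", "unique_run_name"),
]
```

Choosing a unique name for each run keeps the results of separate runs from landing in the same directory.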

Warning

Specifying a shared storage location (such as cloud storage or NFS) is optional for single-node clusters, but it is required for multi-node clusters. Using a local path will raise an error during checkpointing for multi-node clusters.

For more details, see Configuring Persistent Storage.

Launch a training job

Tying this all together, you can now launch a distributed training job with a TorchTrainer.

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

Access training results

After training completes, trainer.fit() returns a Result object that contains information about the training run, including the metrics and checkpoints reported during training.

result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.
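
A small helper can turn these four attributes into a one-line summary, checking for an error first. This is a sketch: summarize_result is a hypothetical helper that relies only on the attributes listed above, and SimpleNamespace objects stand in for real Result objects so the example runs without launching a training job.

```python
from types import SimpleNamespace


def summarize_result(result):
    """Build a short summary from an object with Result-like attributes."""
    if result.error is not None:
        return f"Training failed: {result.error!r}"
    return (
        f"Final metrics: {result.metrics}; "
        f"latest checkpoint: {result.checkpoint}; "
        f"results stored in: {result.path}"
    )


# Stand-ins for Result objects, with illustrative values.
succeeded = SimpleNamespace(
    metrics={"loss": 0.1, "epoch": 9},
    checkpoint="<checkpoint>",
    path="/tmp/run",
    error=None,
)
failed = SimpleNamespace(
    metrics={}, checkpoint=None, path="/tmp/run", error=RuntimeError("boom")
)

summary_ok = summarize_result(succeeded)
summary_failed = summarize_result(failed)
```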

For more usage examples, see Inspecting Training Results.

Next steps

After you have converted your PyTorch training script to use Ray Train:

  • See User Guides to learn more about how to perform specific tasks.

  • Browse the Examples for end-to-end examples of how to use Ray Train.

  • Dive into the API Reference for more details on the classes and methods used in this tutorial.