Source code for ray.train.torch.torch_trainer

from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.torch.config import TorchConfig
from ray.train.trainer import GenDataset
from ray.util import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class TorchTrainer(DataParallelTrainer):
    """A Trainer for data parallel PyTorch training.

    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
    Actors. These actors already have the necessary torch process group
    configured for distributed PyTorch training.

    The ``train_loop_per_worker`` function is expected to take in either 0 or 1
    arguments:

    .. testcode::

        def train_loop_per_worker():
            ...

    .. testcode::

        from typing import Dict, Any
        def train_loop_per_worker(config: Dict[str, Any]):
            ...

    If ``train_loop_per_worker`` accepts an argument, then
    ``train_loop_config`` will be passed in as the argument. This is useful if
    you want to tune the values in ``train_loop_config`` as hyperparameters.

    If the ``datasets`` dict contains a training dataset (denoted by
    the "train" key), then it will be split into multiple dataset
    shards that can then be accessed by ``session.get_dataset_shard("train")``
    inside ``train_loop_per_worker``. All the other datasets will not be split
    and ``session.get_dataset_shard(...)`` will return the entire Dataset.

    Inside the ``train_loop_per_worker`` function, you can use any of the
    :ref:`Ray AIR session methods <air-session-ref>`. See the full example code
    below.

    .. testcode::

        def train_loop_per_worker():
            # Report intermediate results for callbacks or logging and
            # checkpoint data.
            session.report(...)

            # Get dict of last saved checkpoint.
            session.get_checkpoint()

            # Session returns the Dataset shard for the given key.
            session.get_dataset_shard("my_dataset")

            # Get the total number of workers executing training.
            session.get_world_size()

            # Get the rank of this worker.
            session.get_world_rank()

            # Get the rank of the worker on the current node.
            session.get_local_rank()

    You can also use any of the Torch-specific function utilities, such as
    :func:`ray.train.torch.get_device` and
    :func:`ray.train.torch.prepare_model`.

    .. testcode::

        def train_loop_per_worker():
            # Prepares the model for distributed training by wrapping it in
            # `DistributedDataParallel` and moving it to the correct device.
            train.torch.prepare_model(...)

            # Configures the dataloader for distributed training by adding a
            # `DistributedSampler`.
            # You should NOT use this if you are doing
            # `session.get_dataset_shard(...).iter_torch_batches(...)`
            train.torch.prepare_data_loader(...)

            # Get the current torch device.
            train.torch.get_device()

    Any returns from ``train_loop_per_worker`` will be discarded and not
    used or persisted anywhere.

    To save a model to use with ``TorchPredictor``, you must save it under
    the "model" kwarg in the ``Checkpoint`` passed to ``session.report()``.

    .. note::
        When you wrap the ``model`` with ``prepare_model``, the keys of its
        ``state_dict`` are prefixed by ``module.``. For example,
        ``layer1.0.bn1.bias`` becomes ``module.layer1.0.bn1.bias``.
        However, when saving the ``model`` through ``session.report()``,
        all ``module.`` prefixes are stripped.
        As a result, when you load from a saved checkpoint, make sure that
        you first load the ``state_dict`` into the model
        before calling ``prepare_model``.
        Otherwise, you will run into errors like
        ``Error(s) in loading state_dict for DistributedDataParallel:
        Missing key(s) in state_dict: "module.conv1.weight", ...``.
        See the snippet below.

        .. testcode::

            from torchvision.models import resnet18
            from ray.air import session
            from ray.air.checkpoint import Checkpoint
            import ray.train as train

            def train_func():
                ...
                model = resnet18()
                model = train.torch.prepare_model(model)
                for epoch in range(3):
                    ...
                    ckpt = Checkpoint.from_dict({
                        "epoch": epoch,
                        "model": model.state_dict(),
                        # "model": model.module.state_dict(),
                        # ** The above two are equivalent **
                    })
                    session.report({"foo": "bar"}, ckpt)

    Example:

        .. code-block:: python

            import torch
            import torch.nn as nn

            import ray
            from ray import train
            from ray.air import session, Checkpoint
            from ray.train.torch import TorchTrainer
            from ray.air.config import ScalingConfig
            from ray.air.config import RunConfig
            from ray.air.config import CheckpointConfig

            # If using GPUs, set this to True.
            use_gpu = False

            # Define the NN layer architecture, epochs, and number of workers.
            input_size = 1
            layer_size = 32
            output_size = 1
            num_epochs = 200
            num_workers = 3

            # Define your network structure.
            class NeuralNetwork(nn.Module):
                def __init__(self):
                    super(NeuralNetwork, self).__init__()
                    self.layer1 = nn.Linear(input_size, layer_size)
                    self.relu = nn.ReLU()
                    self.layer2 = nn.Linear(layer_size, output_size)

                def forward(self, input):
                    return self.layer2(self.relu(self.layer1(input)))

            # Define your train worker loop.
            def train_loop_per_worker():

                # Fetch the training dataset shard from the session.
                dataset_shard = session.get_dataset_shard("train")
                model = NeuralNetwork()

                # Define the loss function and optimizer, and prepare the
                # model for distributed execution. This wraps the model and
                # moves it to the correct device.
                loss_fn = nn.MSELoss()
                optimizer = torch.optim.Adam(
                    model.parameters(), lr=0.01, weight_decay=0.01
                )
                model = train.torch.prepare_model(model)

                # Iterate over epochs and batches.
                for epoch in range(num_epochs):
                    for batches in dataset_shard.iter_torch_batches(
                        batch_size=32, dtypes=torch.float
                    ):

                        # Unsqueeze to add a feature dimension: [32] -> [32, 1].
                        inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
                        output = model(inputs)

                        # Squeeze the output so its shape matches the labels.
                        loss = loss_fn(output.squeeze(), labels)

                        # Zero out grads, do backward, and update the optimizer.
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                        # Print the loss every 20 epochs.
                        if epoch % 20 == 0:
                            print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

                    # Report and record metrics, and checkpoint the model at the
                    # end of each epoch.
                    session.report(
                        {"loss": loss.item(), "epoch": epoch},
                        checkpoint=Checkpoint.from_dict(
                            dict(epoch=epoch, model=model.state_dict())
                        ),
                    )

            torch.manual_seed(42)
            train_dataset = ray.data.from_items(
                [{"x": x, "y": 2 * x + 1} for x in range(200)]
            )

            # Define the scaling and run configs.
            scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
            run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

            trainer = TorchTrainer(
                train_loop_per_worker=train_loop_per_worker,
                scaling_config=scaling_config,
                run_config=run_config,
                datasets={"train": train_dataset})

            result = trainer.fit()

            best_checkpoint_loss = result.metrics['loss']

            # Assert that the final loss is at most 0.09.
            assert best_checkpoint_loss <= 0.09  # doctest: +SKIP

    Args:
        train_loop_per_worker: The training function to execute.
            This can either take in no arguments or a ``config`` dict.
        train_loop_config: Configurations to pass into
            ``train_loop_per_worker`` if it accepts an argument.
        torch_config: Configuration for setting up the PyTorch backend. If set to
            None, use the default configuration. This replaces the
            ``backend_config`` arg of ``DataParallelTrainer``.
        scaling_config: Configuration for how to scale data parallel training.
        dataset_config: Configuration for dataset ingest.
        run_config: Configuration for the execution of the training run.
        datasets: Any Datasets to use for training. Use
            the key "train" to denote which dataset is the training
            dataset.
            If a ``preprocessor`` is provided and has not already been fit,
            it will be fit on the training dataset. All datasets will be
            transformed by the ``preprocessor`` if one is provided.
        preprocessor: A ``ray.data.Preprocessor`` to preprocess the
            provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        torch_config: Optional[TorchConfig] = None,
        scaling_config: Optional[ScalingConfig] = None,
        dataset_config: Optional[Dict[str, DatasetConfig]] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional["Preprocessor"] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        if not torch_config:
            torch_config = TorchConfig()

        super(TorchTrainer, self).__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=torch_config,
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )
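

# A minimal sketch (not part of the module above) illustrating the loading
# order described in the docstring's note: restore the checkpointed weights
# into the bare model *before* calling ``prepare_model``, so that the
# un-prefixed ``state_dict`` keys match. The ``{"epoch": ..., "model": ...}``
# checkpoint layout is assumed to follow the docstring example; adapt the
# keys to your own checkpoints.
#
#   from torchvision.models import resnet18
#
#   import ray.train as train
#   from ray.air import session
#
#   def train_loop_per_worker():
#       model = resnet18()
#       checkpoint = session.get_checkpoint()
#       if checkpoint:
#           # Load the un-prefixed state_dict into the bare model first ...
#           model.load_state_dict(checkpoint.to_dict()["model"])
#       # ... and only then wrap it with DistributedDataParallel and move it
#       # to the right device, which adds the ``module.`` prefix internally.
#       model = train.torch.prepare_model(model)
#       ...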