ray.train.torch.TorchTrainer
- class ray.train.torch.TorchTrainer(*args, **kwargs)
Bases: ray.train.data_parallel_trainer.DataParallelTrainer
A Trainer for data parallel PyTorch training.
At a high level, this Trainer does the following:
1. Launches multiple workers as defined by the scaling_config.
2. Sets up a distributed PyTorch environment on these workers as defined by the torch_config.
3. Ingests the input datasets based on the dataset_config.
4. Runs the input train_loop_per_worker(train_loop_config) on all workers.
For more details, see the PyTorch User Guide.
Example
import os
import tempfile

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
use_gpu = False
# Number of processes to run training on.
num_workers = 4

# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Training loop.
def train_loop_per_worker(config):
    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):
        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (
            model.module
            if isinstance(model, DistributedDataParallel)
            else model
        )
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)

# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets,
)

# Train the model.
result = trainer.fit()

# Inspect the results.
final_loss = result.metrics["loss"]
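As a follow-up to the example, the trained weights can be read back from the returned result. The snippet below is a minimal sketch, not part of the original example: it assumes the checkpoint layout written by train_loop_per_worker above (a model.pt file containing a "model_state_dict" entry) and reuses the NeuralNetwork class and the result object from the example.

# Sketch: load the model saved by the training loop above.
# Assumes the checkpoint contains "model.pt" with a "model_state_dict" key.
with result.checkpoint.as_directory() as checkpoint_dir:
    state = torch.load(os.path.join(checkpoint_dir, "model.pt"))

model = NeuralNetwork()
model.load_state_dict(state["model_state_dict"])
model.eval()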
- Parameters
  - train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.
  - train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters.
  - torch_config – The configuration for setting up the PyTorch Distributed backend. If set to None, a default configuration will be used in which GPU training uses NCCL and CPU training uses Gloo.
  - scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.
  - run_config – The configuration for the execution of the training run. See RunConfig for more info.
  - datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.
  - dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.
  - resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().
  - metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable. A sketch covering the less commonly used constructor parameters follows this list.
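The main example leaves torch_config, dataset_config, resume_from_checkpoint, and metadata at their defaults. The sketch below shows one way these parameters could be passed to the constructor; the backend choice, dataset, checkpoint path, and metadata values are illustrative placeholders rather than recommendations, and train_loop_per_worker stands for a training function like the one in the example above.

import ray
from ray.train import Checkpoint, DataConfig, ScalingConfig
from ray.train.torch import TorchConfig, TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3},
    # Force the Gloo backend (e.g. for CPU-only training).
    torch_config=TorchConfig(backend="gloo"),
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
    # Only shard the "train" dataset across workers.
    dataset_config=DataConfig(datasets_to_split=["train"]),
    datasets={"train": ray.data.range(1000)},
    # Resume from an earlier checkpoint directory (hypothetical path).
    resume_from_checkpoint=Checkpoint.from_directory("/tmp/previous_checkpoint"),
    # Arbitrary JSON-serializable metadata attached to saved checkpoints.
    metadata={"preprocessing_version": 2},
)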
PublicAPI: This API is stable across Ray releases.
Methods
- as_trainable() – Converts self to a tune.Trainable class.
- can_restore(path[, storage_filesystem]) – Checks whether a given directory contains a restorable Train experiment.
- fit() – Runs training.
- get_dataset_config() – Returns a copy of this Trainer's final dataset configs.
- restore(path[, train_loop_per_worker, ...]) – Restores a DataParallelTrainer from a previously interrupted/failed run.
- setup() – Called during fit() to perform initial setup on the Trainer.
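A common pairing of can_restore and restore is sketched below. The experiment path and the fallback to a fresh trainer are illustrative, and the snippet assumes the train_loop_per_worker function and ScalingConfig import from the example above.

# Sketch: resume a previously interrupted run if its experiment
# directory is restorable; the path is a hypothetical example.
experiment_path = "~/ray_results/my_torch_experiment"

if TorchTrainer.can_restore(experiment_path):
    # Re-pass objects that are not saved with the experiment state,
    # such as the training loop function.
    trainer = TorchTrainer.restore(
        experiment_path,
        train_loop_per_worker=train_loop_per_worker,
    )
else:
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=4),
    )

result = trainer.fit()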