ray.train.horovod.HorovodTrainer
class ray.train.horovod.HorovodTrainer(*args, **kwargs)
Bases: ray.train.data_parallel_trainer.DataParallelTrainer
A Trainer for data parallel Horovod training.
This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the necessary Horovod setup configured for distributed Horovod training.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

def train_loop_per_worker(): ...

def train_loop_per_worker(config: Dict): ...
If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
If the datasets dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards that can then be accessed by session.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split and session.get_dataset_shard(...) will return the entire Dataset.
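For example, the following sketch (the "valid" dataset name is hypothetical) shows that only the "train" dataset is sharded across workers, while other datasets are returned whole to each worker:

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodTrainer

def train_loop_per_worker():
    # Sharded: each worker receives a different subset of "train".
    train_shard = session.get_dataset_shard("train")
    # Not sharded: every worker receives the entire "valid" dataset.
    valid_shard = session.get_dataset_shard("valid")
    session.report({"world_rank": session.get_world_rank()})

train_ds = ray.data.from_items([{"x": i, "y": i + 1} for i in range(32)])
valid_ds = ray.data.from_items([{"x": i, "y": i + 1} for i in range(8)])

trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_ds, "valid": valid_ds},
)
trainer.fit()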
Inside the train_loop_per_worker function, you can use any of the Ray AIR session methods.

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    session.report(...)

    # Returns dict of last saved checkpoint.
    session.get_checkpoint()

    # Returns the Ray Dataset shard for the given key.
    session.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    session.get_world_size()

    # Returns the rank of this worker.
    session.get_world_rank()

    # Returns the rank of the worker on the current node.
    session.get_local_rank()
Any returns from train_loop_per_worker will be discarded and not used or persisted anywhere.

You could use TensorflowPredictor or TorchPredictor in conjunction with HorovodTrainer. You must save the model under the "model" kwarg in the Checkpoint passed to session.report(), so that it can be used by the corresponding predictor.

Example:
import ray
import ray.train as train
import ray.train.torch  # Need this to use `train.torch.get_device()`
import horovod.torch as hvd
import torch
import torch.nn as nn
from ray.air import session
from ray.train.horovod import HorovodTrainer
from ray.train.torch import TorchCheckpoint
from ray.air.config import ScalingConfig

# If using GPUs, set this to True.
use_gpu = False

input_size = 1
layer_size = 15
output_size = 1
num_epochs = 3

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

def train_loop_per_worker():
    hvd.init()
    dataset_shard = session.get_dataset_shard("train")
    model = NeuralNetwork()
    device = train.torch.get_device()
    model.to(device)
    loss_fn = nn.MSELoss()
    lr_scaler = 1
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * lr_scaler)
    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        op=hvd.Average,
    )
    for epoch in range(num_epochs):
        model.train()
        for batch in dataset_shard.iter_torch_batches(
            batch_size=32, dtypes=torch.float
        ):
            inputs, labels = torch.unsqueeze(batch["x"], 1), batch["y"]
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch: {epoch}, loss: {loss.item()}")
        session.report(
            {},
            checkpoint=TorchCheckpoint.from_state_dict(
                model.state_dict()
            ),
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
)
result = trainer.fit()
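As a follow-up sketch (reusing the NeuralNetwork class and the result object from the example above), the reported checkpoint can be loaded into a TorchPredictor for inference:

import numpy as np

from ray.train.torch import TorchPredictor

# The checkpoint was created with TorchCheckpoint.from_state_dict, so the
# model architecture must be supplied again when building the predictor.
predictor = TorchPredictor.from_checkpoint(result.checkpoint, model=NeuralNetwork())

# Run inference on a small batch shaped like the training inputs.
predictions = predictor.predict(np.array([[1.0], [2.0], [3.0]], dtype=np.float32))
print(predictions)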
Parameters
    train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.
    train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.
    horovod_config – Configuration for setting up the Horovod backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer (see the sketch after this list).
    scaling_config – Configuration for how to scale data parallel training.
    dataset_config – Configuration for dataset ingest.
    run_config – Configuration for the execution of the training run.
    datasets – Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.
    preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.
    resume_from_checkpoint – A checkpoint to resume training from.
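For illustration, a sketch of customizing the backend via horovod_config. HorovodConfig is importable from ray.train.horovod; the timeout_s and verbose fields shown are assumptions used only to illustrate the pattern and may differ across Ray versions:

from ray.air.config import ScalingConfig
from ray.train.horovod import HorovodConfig, HorovodTrainer

# Assumed fields; consult the HorovodConfig reference for the exact
# options available in your Ray version.
horovod_config = HorovodConfig(timeout_s=300, verbose=1)

trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,  # defined as in the example above
    horovod_config=horovod_config,
    scaling_config=ScalingConfig(num_workers=2),
)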
PublicAPI (beta): This API is in beta and may change before becoming stable.