ray.train.trainer.BaseTrainer
- class ray.train.trainer.BaseTrainer(*args, **kwargs)
Bases: ABC
Defines interface for distributed training on Ray.
Note: The base BaseTrainer class cannot be instantiated directly. Only one of its subclasses can be used.

Note to developers: If a new trainer is added, please update air/_internal/usage.py.

How does a trainer work?
First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.

Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:

- trainer.setup(): Any heavyweight Trainer setup should be specified here.
- trainer.training_loop(): Executes the main training logic.

Calling trainer.fit() will return a ray.result.Result object where you can access metrics from your training run, as well as any checkpoints that may have been saved.
How do I create a new Trainer?
Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.

import torch

from ray.train.trainer import BaseTrainer
from ray import train


class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] has already been set up.
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)

                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Report intermediate results to Tune.
            train.report({"loss": loss, "epoch": epoch_idx})


# Initialize the Trainer, and call Trainer.fit()
import ray

train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(10)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()
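Continuing the example, the Result returned by fit() exposes the metrics reported during training and any checkpoints that were saved. A minimal sketch (the metric values shown in the comments are illustrative):

# Inspect the outcome of the run.
print(result.metrics)     # last reported metrics, e.g. {"loss": ..., "epoch": 9}
print(result.checkpoint)  # latest saved Checkpoint, or None if none was saved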
- Parameters:
scaling_config – Configuration for how to scale training.
run_config – Configuration for the execution of the training run.
datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.
metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable. (See the sketch after this list.)
resume_from_checkpoint – A checkpoint to resume training from.
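As an illustration, metadata passed to the constructor can be read back inside the training loop. A minimal sketch, reusing the MyPytorchTrainer and train_dataset defined above (the metadata values are hypothetical):

from ray import train

# Hypothetical, JSON-serializable run metadata.
my_trainer = MyPytorchTrainer(
    datasets={"train": train_dataset},
    metadata={"data_version": 2},
)

# Inside training_loop(), the same dict can be read back:
# meta = train.get_context().get_metadata()  # -> {"data_version": 2}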
DeveloperAPI: This API may change across minor Ray releases.
Methods
as_trainable – Converts self to a tune.Trainable class.
can_restore – Checks whether a given directory contains a restorable Train experiment.
fit – Runs training.
preprocess_datasets – Deprecated.
restore – Restores a Train experiment from a previously interrupted/failed run (see the sketch after this list).
setup – Called during fit() to perform initial setup on the Trainer.
training_loop – Loop called by fit() to run training and report results to Tune.
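A hedged sketch of the restore path, reusing MyPytorchTrainer and train_dataset from above (the experiment directory path is hypothetical):

experiment_path = "~/ray_results/my_experiment"  # hypothetical path

if MyPytorchTrainer.can_restore(experiment_path):
    # Re-supply objects that must be re-specified on restore, such as datasets.
    restored_trainer = MyPytorchTrainer.restore(
        experiment_path,
        datasets={"train": train_dataset},
    )
    result = restored_trainer.fit()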