ray.train.trainer.BaseTrainer
- class ray.train.trainer.BaseTrainer(*args, **kwargs)
Bases: abc.ABC

Defines the interface for distributed training on Ray.

Note: The base BaseTrainer class cannot be instantiated directly. Only one of its subclasses can be used.

Note to developers: If a new trainer is added, please update air/_internal/usage.py.

How does a trainer work?
First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.

Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:

- trainer.setup(): Any heavyweight Trainer setup should be specified here.
- trainer.training_loop(): Executes the main training logic.

Calling trainer.fit() will return a ray.train.Result object where you can access metrics from your training run, as well as any checkpoints that may have been saved.
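As a conceptual sketch of this lifecycle (illustrative only, not Ray's actual internals; the actor class and method names here are made up):

```python
import ray

# Illustrative only: fit() behaves roughly as if the trainer were
# shipped to a remote actor that runs the lifecycle hooks in order.
@ray.remote
class _ConceptualTrainActor:  # hypothetical name, not a Ray API
    def __init__(self, trainer):
        # The trainer arrives here already serialized; its __init__
        # ran locally on the driver, which is why it must stay cheap.
        self.trainer = trainer

    def run(self):
        self.trainer.setup()          # heavyweight setup, done remotely
        self.trainer.training_loop()  # main training logic
```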
How do I create a new Trainer?
Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.

```python
import torch

import ray
from ray import train
from ray.train.trainer import BaseTrainer


class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] was passed in through the constructor.
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)

                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Report intermediate results to Ray Train/Tune.
            train.report({"loss": loss, "epoch": epoch_idx})


# Initialize the Trainer, and call Trainer.fit().
train_dataset = ray.data.from_items([{"x": i, "y": i} for i in range(10)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()
```
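Continuing the example, here is a minimal sketch of inspecting the returned Result (the metric keys are just the ones reported via train.report() above):

```python
# Inspect the Result returned by my_trainer.fit() above.
print(result.metrics["loss"])   # last value reported via train.report()
print(result.metrics["epoch"])  # final epoch index (9 here)

# result.checkpoint is the latest saved checkpoint, or None when the
# training loop never saved one, as in this example.
print(result.checkpoint)
```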
- Parameters:
scaling_config – Configuration for how to scale training.
run_config – Configuration for the execution of the training run.
datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.
metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
resume_from_checkpoint – A checkpoint to resume training from.
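As a minimal sketch of wiring these parameters together (ScalingConfig and RunConfig come from ray.train; MyPytorchTrainer and train_dataset are from the example above, and the metadata values are made up):

```python
from ray.train import RunConfig, ScalingConfig

my_trainer = MyPytorchTrainer(
    datasets={"train": train_dataset},            # "train" is the training dataset
    scaling_config=ScalingConfig(num_workers=2),  # how to scale training
    run_config=RunConfig(name="my_experiment"),   # execution settings for the run
    metadata={"data_version": 1},                 # must be JSON-serializable
)
result = my_trainer.fit()
```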
DeveloperAPI: This API may change across minor Ray releases.
Methods
- as_trainable: Converts self to a tune.Trainable class.
- can_restore: Checks whether a given directory contains a restorable Train experiment.
- fit: Runs training.
- preprocess_datasets: Deprecated.
- restore: Restores a Train experiment from a previously interrupted/failed run.
- setup: Called during fit() to perform initial setup on the Trainer.
- training_loop: Loop called by fit() to run training and report results to Tune.
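As a hedged sketch of the restore workflow (the experiment path is hypothetical, and datasets must be passed again because they are not captured in the saved experiment state):

```python
experiment_path = "~/ray_results/my_experiment"  # hypothetical path

# Resume only if the directory holds a restorable Train experiment.
if MyPytorchTrainer.can_restore(experiment_path):
    restored = MyPytorchTrainer.restore(
        experiment_path,
        datasets={"train": train_dataset},  # re-specify datasets
    )
    result = restored.fit()
```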