RaySGD API Reference
PyTorch
TorchTrainer
class ray.util.sgd.torch.TorchTrainer(*, training_operator_cls, initialization_hook=None, config=None, num_workers=1, num_cpus_per_worker=1, use_gpu='auto', backend='auto', wrap_ddp=True, timeout_s=1800, use_fp16=False, use_tqdm=False, add_dist_sampler=True, scheduler_step_freq=None, use_local=False, num_replicas=None, batch_size=None, model_creator=None, data_creator=None, optimizer_creator=None, scheduler_creator=None, loss_creator=None, serialize_data_creation=None, data_loader_args=None, apex_args=None)
Train a PyTorch model using distributed PyTorch.
Launches a set of actors which connect via distributed PyTorch and coordinate gradient updates to train the provided model. If Ray is not initialized, TorchTrainer will automatically initialize a local Ray cluster for you. To leverage multi-node training, be sure to call ray.init(address="auto") first.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
# LinearDataset is the toy dataset from RaySGD's bundled examples.
from ray.util.sgd.torch.examples.train_example import LinearDataset


class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(
            model.parameters(), lr=config.get("lr", 1e-4))
        loss = torch.nn.MSELoss()

        batch_size = config["batch_size"]
        train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
        train_loader = DataLoader(train_data, batch_size=batch_size)
        val_loader = DataLoader(val_data, batch_size=batch_size)

        # register() also returns the criterion when one is passed in.
        self.model, self.optimizer, self.criterion = self.register(
            models=model,
            optimizers=optimizer,
            criterion=loss)

        self.register_data(
            train_loader=train_loader,
            validation_loader=val_loader)


trainer = TorchTrainer(
    training_operator_cls=MyTrainingOperator,
    config={"batch_size": 32},
    use_gpu=True
)

for i in range(4):
    trainer.train()
- Parameters
training_operator_cls (type) – Custom training operator class that subclasses the TrainingOperator class. This class will be copied onto all remote workers and used to specify training components and custom training and validation operations.
initialization_hook (function) – A function to call on all training workers when they are first initialized. This could be useful to set environment variables for all the worker processes.
config (dict) – Custom configuration value to be passed to all operator constructors.
num_workers (int) – The number of workers used in distributed training. If 1, the worker will not be wrapped with DistributedDataParallel. TorchTrainer will scale down the number of workers if not enough resources are available, and will scale back up once they are. The total number of workers will never exceed num_workers.
num_cpus_per_worker (int) – Sets the CPU requirement for each worker.
use_gpu (bool) – If True, allocates 1 GPU to each worker and automatically moves both the model and optimizer to the available CUDA device.
backend (string) – Backend used by distributed PyTorch. Currently supports "nccl", "gloo", and "auto". If "auto", RaySGD will automatically use "nccl" if use_gpu is True, and "gloo" otherwise.
wrap_ddp (bool) – Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself.
timeout_s (float) – Seconds before the torch process group times out. Useful when machines are unreliable. Defaults to 30 minutes (1800 seconds), which is the same default as torch.distributed.init_process_group(...).
add_dist_sampler (bool) – Whether to automatically add a DistributedSampler to all created dataloaders. Only applicable if num_workers > 1.
use_fp16 (bool) – Enables mixed precision training via apex if apex is installed. This is automatically done after the model and optimizers are constructed and will work for multi-model training. Please see https://github.com/NVIDIA/apex for more details.
scheduler_step_freq – "batch", "epoch", "manual", or None. This determines when scheduler.step is called. If "batch", step will be called after every optimizer step. If "epoch", step will be called after one pass of the DataLoader. If "manual", the scheduler will not be incremented automatically; you are expected to call trainer.update_scheduler manually. If a scheduler is passed in, this value is expected not to be None.
use_local (bool) – If True, 1 worker will be a local worker running on the driver process, and all other workers will be remote. If False, all workers will be remote. Setting this to True makes it easy to debug the worker on the driver process, but it can also lead to issues with CUDA devices. Defaults to False.
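The scheduler_step_freq modes can be modeled with a small pure-Python simulation that counts how often scheduler.step would fire over a run. This is a sketch of the documented behavior only, not RaySGD's implementation. count_scheduler_steps is a hypothetical helper.

```python
def count_scheduler_steps(scheduler_step_freq, num_epochs, batches_per_epoch):
    """Count automatic scheduler.step() calls in a simulated run (hypothetical helper).

    "batch"  -> one step after every optimizer step (i.e. every batch)
    "epoch"  -> one step after each full pass over the DataLoader
    "manual" -> never stepped automatically; call trainer.update_scheduler yourself
    None     -> no scheduler, so nothing is stepped
    """
    if scheduler_step_freq not in ("batch", "epoch", "manual", None):
        raise ValueError(f"unsupported scheduler_step_freq: {scheduler_step_freq!r}")
    steps = 0
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            # optimizer.step() would run here for each batch
            if scheduler_step_freq == "batch":
                steps += 1
        if scheduler_step_freq == "epoch":
            steps += 1
    return steps


for freq in ("batch", "epoch", "manual"):
    print(freq, count_scheduler_steps(freq, num_epochs=3, batches_per_epoch=10))
```

With 3 epochs of 10 batches, this prints 30 steps for "batch", 3 for "epoch", and 0 for "manual".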
train(num_steps=None, profile=False, reduce_results=True, max_retries=3, info=None, dataset=None)
Runs a training epoch.
Under the hood, calls operator.train_epoch() on N parallel workers simultaneously.
Set max_retries to enable fault handling in case of instance preemption.
- Parameters
num_steps (int) – Number of batches to compute update steps on per worker. This also corresponds to the number of times TrainingOperator.train_batch is called per worker.
profile (bool) – Returns time stats for the training procedure.
reduce_results (bool) – Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value (or nested dictionaries), one value will be randomly selected among the workers. If False, returns a list of dicts.
max_retries (int) – Must be non-negative. If set to N, TorchTrainer will detect and recover from training failure. The recovery process will kill all current workers, query the Ray global state for total available resources, and re-launch up to the available resources. Behavior is not well-defined in case of shared cluster usage. Defaults to 3.
info (dict) – Optional dictionary passed to the training operator for train_epoch and train_batch.
dataset (Dataset) – Optional dataset to train with. If specified, the dataloader passed in via data_creator will be ignored.
- Returns
- (dict | list) A dictionary of metrics for training.
You can provide custom metrics by implementing a custom training loop. If reduce_results=False, this will return a list of metric dictionaries whose length is equal to num_workers.
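The reduce_results rule can be sketched in plain Python. The helper reduce_worker_metrics is hypothetical and assumes every worker reports the same metric keys. It illustrates the documented averaging and random selection, not RaySGD's actual code.

```python
import random
from numbers import Number


def reduce_worker_metrics(worker_stats, reduce_results=True):
    """Combine per-worker metric dicts as reduce_results describes (hypothetical helper).

    Assumes every worker reports the same keys.
    """
    if not reduce_results:
        # Unreduced: one dict per worker, so the list length equals num_workers.
        return list(worker_stats)
    reduced = {}
    for key in worker_stats[0]:
        values = [stats[key] for stats in worker_stats]
        if all(isinstance(v, Number) for v in values):
            # Numeric metrics are averaged across workers.
            reduced[key] = sum(values) / len(values)
        else:
            # Non-numeric values (strings, nested dicts, ...) come from one random worker.
            reduced[key] = random.choice(values)
    return reduced


stats = [
    {"train_loss": 0.4, "num_samples": 64, "extra": {"worker": 0}},
    {"train_loss": 0.2, "num_samples": 64, "extra": {"worker": 1}},
]
print(reduce_worker_metrics(stats))
```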
apply_all_workers(fn)
Run a function on all workers.
- Parameters
fn (Callable) – A function that takes in no arguments.
- Returns
A list of objects returned by fn on each worker.
apply_all_operators(fn)
Run a function on all operators on the workers.
- Parameters
fn (Callable[TrainingOperator]) – A function that takes in a TrainingOperator.
- Returns
A list of objects returned by fn on each operator.
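The only difference between apply_all_workers and apply_all_operators is what fn receives. The sketch below uses hypothetical in-process stand-ins (FakeWorker, FakeOperator) instead of Ray actors. With a real trainer, the operator variant would look something like trainer.apply_all_operators(lambda operator: operator.world_rank).

```python
class FakeOperator:
    """Hypothetical stand-in for the TrainingOperator living on one worker."""

    def __init__(self, world_rank):
        self.world_rank = world_rank


class FakeWorker:
    """Hypothetical stand-in for one remote training worker."""

    def __init__(self, world_rank):
        self.operator = FakeOperator(world_rank)

    def apply(self, fn):
        # apply_all_workers semantics: fn takes no arguments.
        return fn()

    def apply_operator(self, fn):
        # apply_all_operators semantics: fn receives this worker's operator.
        return fn(self.operator)


workers = [FakeWorker(rank) for rank in range(3)]

# Like apply_all_workers: one result per worker, fn sees nothing.
messages = [w.apply(lambda: "ran on worker") for w in workers]
# Like apply_all_operators: fn can inspect each operator's state.
ranks = [w.apply_operator(lambda op: op.world_rank) for w in workers]
print(messages, ranks)
```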
validate(num_steps=None, profile=False, reduce_results=True, info=None)
Evaluates the model on the validation data set.
- Parameters
num_steps (int) – Number of batches to evaluate per worker. This also corresponds to the number of times TrainingOperator.validate_batch is called per worker.
profile (bool) – Returns time stats for the evaluation procedure.
reduce_results (bool) – Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value (or nested dictionaries), one value will be randomly selected among the workers. If False, returns a list of dicts.
info (dict) – Optional dictionary passed to the training operator for validate and validate_batch.
- Returns
- A dictionary of metrics for validation.
You can provide custom metrics by passing in a custom training_operator_cls.
update_scheduler(metric)
Calls scheduler.step(metric) on all registered schedulers. This is useful for learning rate schedulers such as ReduceLROnPlateau.
get_local_operator()
Returns the local TrainingOperator object.
Be careful not to perturb its state, or else you can cause the system to enter an inconsistent state.
- Returns
The local TrainingOperator object.
- Return type
TrainingOperator
save(checkpoint)
Saves the Trainer state to the provided checkpoint path.
- Parameters
checkpoint (str) – Path to target checkpoint file.
load(checkpoint)
Loads the Trainer and all workers from the provided checkpoint.
- Parameters
checkpoint (str) – Path to target checkpoint file.
shutdown(force=False)
Shuts down workers and releases resources.
- Parameters
force (bool) – If True, forcefully kill all workers. If False, attempt a graceful shutdown first, and then forcefully kill if unsuccessful.
classmethod as_trainable(*args, override_tune_step=None, **kwargs)
Creates a BaseTorchTrainable class compatible with Tune.
Any configuration parameters will be overridden by the Tune Trial configuration. You can also pass in a custom override_tune_step to implement your own iterative optimization routine and override the default implementation.
from ray import tune
from ray.util.sgd.torch import TorchTrainer


def step(trainer, info):
    # Implement custom objective function here.
    train_stats = trainer.train()
    ...
    # Return the metrics to report to tune.
    # Do not call tune.report here.
    return train_stats


# MyTrainingOperator is the operator defined in the TorchTrainer example.
TorchTrainable = TorchTrainer.as_trainable(
    training_operator_cls=MyTrainingOperator,
    num_workers=2,
    use_gpu=True,
    override_tune_step=step
)
analysis = tune.run(
    TorchTrainable,
    config={"lr": tune.grid_search([0.01, 0.1])}
)
- Parameters
override_tune_step (Callable[[TorchTrainer, Dict], Dict]) – A function to override the default training step to be used for Ray Tune. It accepts two arguments: the first is an instance of your TorchTrainer, and the second is an info dictionary containing information about the Trainer state. If None is passed in, the default step function will be used: run 1 epoch of training, 1 epoch of validation, and report both results to Tune. Passing in override_tune_step is useful to define custom step functions, for example if you need to manually update the scheduler or want to run more than 1 training epoch for each Tune iteration.
PyTorch TrainingOperator
class ray.util.sgd.torch.TrainingOperator(config, world_rank, local_rank, is_distributed=False, device=None, use_gpu=False, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None)
Abstract class to define training and validation state and logic.
You must subclass this class and override the setup method to define your training components such as the model, optimizer, data, loss, and scheduler. When you pass this class to TorchTrainer, a copy of this class will be made on each worker.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ray.util.sgd.torch import TorchTrainer, TrainingOperator
# LinearDataset is the toy dataset from RaySGD's bundled examples.
from ray.util.sgd.torch.examples.train_example import LinearDataset


class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(
            model.parameters(), lr=config.get("lr", 1e-4))
        loss = torch.nn.MSELoss()

        batch_size = config["batch_size"]
        train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
        train_loader = DataLoader(train_data, batch_size=batch_size)
        val_loader = DataLoader(val_data, batch_size=batch_size)

        # register() also returns the criterion when one is passed in.
        self.model, self.optimizer, self.criterion = self.register(
            models=model,
            optimizers=optimizer,
            criterion=loss)

        self.register_data(
            train_loader=train_loader,
            validation_loader=val_loader)


trainer = TorchTrainer(
    training_operator_cls=MyTrainingOperator,
    config={"batch_size": 32},
    use_gpu=True
)

for i in range(4):
    trainer.train()
This class provides default implementations for training and validation. Set self.model, self.optimizer, and self.criterion to leverage the default training and validation loops. If self.scheduler is set, it will only be stepped at a batch or epoch frequency, depending on the user parameter: set scheduler_step_freq in TorchTrainer to either "batch" or "epoch" to increment the scheduler correctly during training. If you use a learning rate scheduler that depends on validation loss, you can use trainer.update_scheduler.
You can also use this class to provide custom training and validation loops. You can customize at two granularities, per epoch or per batch. You do not need to override both.
If you are using multiple models, optimizers, or schedulers, you must implement custom training and validation.
- Raises
ValueError – You are expected to either set the self.model, self.optimizer, and self.criterion instance attributes in setup or implement custom training & validation.
setup(config)
Override this method to implement operator setup.
You should call self.register and self.register_data here to register training components and data loaders with Ray SGD.
- Parameters
config (dict) – Custom configuration value to be passed to all creator and operator constructors. Same as self.config.
register(*, models, optimizers, criterion=None, schedulers=None, ddp_args=None, apex_args=None)
Registers parameters with Ray SGD and sets up training components.
By calling this method to register your models, optimizers, criterion, and schedulers, Ray SGD will automatically handle necessary setup such as GPU/devices, Distributed Data Parallel, and Fp16. The registered components are returned and should be set as instance attributes to access during training/validation.
If more than one model, optimizer, or scheduler is passed in, you should implement your own custom training loop.
- Calling register will perform the following steps in this order:
  - If using GPU, moves model(s) and criterion to the corresponding CUDA device.
  - If using fp16, initializes amp with model(s), optimizer(s), and apex_args.
  - If using distributed training and wrap_ddp is True, wraps model(s) with DistributedDataParallel.
class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        model = ...
        optimizer = ...
        train_loader = ...
        val_loader = ...
        loss = ...

        self.model, self.optimizer, self.criterion = self.register(
            models=model, optimizers=optimizer, criterion=loss)

        # At this point DDP, Cuda, and Fp16
        # are set up for all our components. We then use
        # self.model, self.optimizer, etc. in our training loop.

        self.register_data(train_loader=train_loader, validation_loader=val_loader)
- Parameters
models (torch.nn.Module or Iterable[nn.Module]) – PyTorch model or multiple PyTorch models to use for training. If use_gpu=True is passed into TorchTrainer and CUDA is available, models will automatically be placed on GPU. If wrap_ddp=True is passed into TorchTrainer, models will be wrapped in DDP. If wrap_ddp is False, you should handle DDP for your models in setup.
optimizers (torch.optim.Optimizer or Iterable[torch.optim.Optimizer]) – PyTorch optimizer or multiple PyTorch optimizers to use for training.
criterion (Callable, optional) – Function to return loss metric given features and target. If not provided, must implement a custom training loop.
schedulers (torch.optim.lr_scheduler or Iterable[torch.optim.lr_scheduler], optional) – A learning rate scheduler or multiple learning rate schedulers.
ddp_args (dict|None) – Dict containing keyword args for DistributedDataParallel if distributed training is being used. module and device_ids are automatically passed in, but this dict is useful for passing in other args such as find_unused_parameters=True.
apex_args (dict|None) – Dict containing keyword args for amp.initialize if fp16 is being used. See https://nvidia.github.io/apex/amp.html#module-apex.amp. By default, the models and optimizers are passed in. Consider using “num_losses” if operating over multiple models and optimizers.
- Returns
Tuple of model, optimizer, criterion if not None, and scheduler if not None.
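How many names to unpack depends on what is passed in. The hypothetical helper below reproduces only the documented ordering: model(s), optimizer(s), then the criterion and scheduler(s) if provided. Device placement, amp, and DDP wrapping are left out.

```python
def register_return_value(models, optimizers, criterion=None, schedulers=None):
    """Mimic only the documented return ordering of register() (hypothetical helper)."""
    returned = [models, optimizers]
    if criterion is not None:
        returned.append(criterion)
    if schedulers is not None:
        returned.append(schedulers)
    return tuple(returned)


# A criterion was passed, so three values come back and three names are unpacked,
# matching self.model, self.optimizer, self.criterion = self.register(...).
model, optimizer, criterion = register_return_value("model", "optimizer", criterion="loss")

# With a scheduler as well, a fourth value is appended at the end.
print(register_return_value("model", "optimizer", criterion="loss", schedulers="sched"))
```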
register_data(*, train_loader=None, validation_loader=None)
Registers data loaders with Ray SGD.
Calling this method will automatically set up a DistributedSampler for these data loaders if add_dist_sampler=True is passed into the TorchTrainer. This method does not return the wrapped data loaders; use the iterators passed into train_epoch and validate instead.
class MyTrainingOperator(TrainingOperator):
    def setup(self, config):
        model = ...
        optimizer = ...
        train_loader = ...
        val_loader = ...
        loss = ...

        self.model, self.optimizer, self.criterion = self.register(
            models=model, optimizers=optimizer, criterion=loss)

        self.register_data(train_loader=train_loader, validation_loader=val_loader)

        # At this point the data loaders are registered with
        # Ray SGD and are wrapped with Distributed Samplers if
        # applicable.

    def train_epoch(self, iterator, info):
        # If providing custom training or validation methods,
        # the registered data loaders are passed in through the
        # iterator parameter.
        ...
- Parameters
train_loader (Iterator) – An iterator for training data. If None is explicitly passed in, a Ray SGD Dataset must be passed in through TorchTrainer.train. Ray SGD will automatically use a Distributed Sampler if TorchTrainer(…, add_dist_sampler=True).
validation_loader (Iterator) – An iterator for validation data. Ray SGD will automatically use a Distributed Sampler if TorchTrainer(…, add_dist_sampler=True).
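The sketch below shows which indices each worker sees once a DistributedSampler is attached. It mimics torch.utils.data.DistributedSampler with shuffle=False and drop_last=False: the index list is padded by wrapping around until it divides evenly, and then each rank takes every num_replicas-th index. shard_indices is a hypothetical helper, not part of RaySGD.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see (hypothetical sketch of DistributedSampler,
    shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every replica gets the same number of samples.
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
```

With 10 samples and 4 workers, ranks 2 and 3 receive one repeated index each (0 and 1) so that every worker processes 3 samples per epoch.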
train_epoch(iterator, info)
Runs one standard training pass over the training dataloader.
By default, this method will iterate over the given iterator and call self.train_batch over each batch. If scheduler_step_freq is set, this default method will also step the scheduler accordingly.
You do not need to call train_batch in this method if you plan to implement a custom optimization/training routine here.
You may find ray.util.sgd.utils.AverageMeterCollection useful when overriding this method. See example below:
from ray.util.sgd.utils import AverageMeterCollection


def train_epoch(self, iterator, info):
    meter_collection = AverageMeterCollection()
    self.model.train()
    for batch in iterator:
        # do some processing
        metrics = {"metric_1": 1, "metric_2": 3}  # dict of metrics

        # This keeps track of all metrics across multiple batches
        meter_collection.update(metrics, n=len(batch))

    # Returns stats of the meters.
    stats = meter_collection.summary()
    return stats
- Parameters
iterator (iter) – Iterator over the training data for the entire epoch. This iterator is expected to be entirely consumed.
info (dict) – Dictionary for information to be used for custom training operations.
- Returns
A dict of metrics from training.
train_batch(batch, batch_info)
Computes loss and updates the model over one batch.
This method is responsible for computing the loss and gradient and updating the model.
By default, this method implementation assumes that batches are in (*features, labels) format, so models with multiple inputs are also supported. If using amp/fp16 training, it will also scale the loss automatically.
You can provide custom loss metrics and training operations if you override this method.
You do not need to override this method if you plan to override train_epoch.
- Parameters
batch – One item of the training iterator.
batch_info (dict) – Information dict passed in from train_epoch.
- Returns
- A dictionary of metrics.
By default, this dictionary contains "loss" and "num_samples". "num_samples" corresponds to the number of datapoints in the batch. However, you can provide any number of other values. Consider returning "num_samples" in the metrics because by default, train_epoch uses "num_samples" to calculate averages.
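Under the (*features, labels) convention, every element of a batch except the last is fed to the model as a positional input. The sketch below shows that unpacking with plain callables standing in for an nn.Module and a criterion. step_on_batch, two_input_model and mean_abs_error are hypothetical. The backward pass and optimizer step are omitted, and num_samples is taken as the number of labels.

```python
def step_on_batch(model, criterion, batch):
    """Hypothetical sketch of the (*features, labels) batch convention."""
    *features, target = batch  # all but the last element are model inputs
    output = model(*features)  # a multi-input model receives every feature
    loss = criterion(output, target)
    # backward pass and optimizer.step() would follow in real training
    return {"loss": loss, "num_samples": len(target)}


def two_input_model(a, b):
    # Stand-in for a model with two inputs: element-wise sum.
    return [x + y for x, y in zip(a, b)]


def mean_abs_error(output, target):
    return sum(abs(o - t) for o, t in zip(output, target)) / len(target)


batch = ([1.0, 2.0], [3.0, 4.0], [4.0, 6.0])  # (features_a, features_b, labels)
print(step_on_batch(two_input_model, mean_abs_error, batch))
```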
validate(val_iterator, info)
Runs one standard validation pass over the val_iterator.
This will call model.eval() and torch.no_grad when iterating over the validation dataloader. You also do not need to call validate_batch if overriding this method.
- Parameters
val_iterator (iter) – Iterable constructed from the validation dataloader.
info (dict) – Dictionary for information to be used for custom validation operations.
- Returns
- A dict of metrics from the evaluation.
By default, returns "val_accuracy" and "val_loss", which are computed by aggregating the "loss" and "correct" values from validate_batch and dividing by the sum of num_samples from all calls to self.validate_batch.
validate_batch(batch, batch_info)
Calculates the loss and accuracy over a given batch.
You can override this method to provide arbitrary metrics.
Same as train_batch, this method implementation assumes by default that batches are in (*features, labels) format, so models with multiple inputs are also supported.
- Parameters
batch – One item of the validation iterator.
batch_info (dict) – Contains information per batch from validate().
- Returns
- A dict of metrics.
By default, returns "val_loss", "val_accuracy", and "num_samples". When overriding, consider returning "num_samples" in the metrics because by default, validate uses "num_samples" to calculate averages.
state_dict()
Override this to return a representation of the operator state. Any argument passed into self.register and self.register_data will automatically be saved. Use this method to save any additional state. If your TorchTrainer is on a CPU-only machine, make sure this method converts all state to be CPU-compatible.
- Returns
The state dict of the operator.
- Return type
dict
load_state_dict(state_dict)
Override this to load the representation of the operator state. Anything passed into self.register and self.register_data will automatically be loaded. Use this method to load any additional state.
- Parameters
state_dict (dict) – State dict as returned by the operator.
classmethod from_ptl(lightning_module_cls, train_dataloader=None, val_dataloader=None)
Create a custom TrainingOperator class from a LightningModule.
MyLightningOperator = TrainingOperator.from_ptl(MyLightningModule)
trainer = TorchTrainer(
    training_operator_cls=MyLightningOperator,
    # ... other TorchTrainer arguments
)
- Parameters
lightning_module_cls – Your LightningModule class. An object of this class will get instantiated on each worker.
train_dataloader – The data loader to use for training. If None is provided, LightningModule.train_dataloader will be used instead.
val_dataloader – The data loader to use for validation. If None is provided, LightningModule.val_dataloader will be used instead.
- Returns
A TrainingOperator class properly configured given the LightningModule.
classmethod from_creators(model_creator, optimizer_creator, data_creator=None, loss_creator=None, scheduler_creator=None, serialize_data_creation=True)
Create a custom TrainingOperator class from creator functions.
This method is useful for backwards compatibility with previous versions of Ray. To provide custom training and validation, you should subclass the class that is returned by this method instead of TrainingOperator.
MyCreatorOperator = TrainingOperator.from_creators(
    model_creator, optimizer_creator)
trainer = TorchTrainer(
    training_operator_cls=MyCreatorOperator,
    # ... other TorchTrainer arguments
)
- Parameters
model_creator (dict -> Model(s)) – Constructor function that takes in config and returns the model(s) to be optimized. These must be torch.nn.Module objects. If multiple models are returned, a training_operator_cls must be specified. You do not need to handle GPU/devices in this function; RaySGD will do that under the hood.
data_creator (dict -> Iterable(s)) – Constructor function that takes in the passed config and returns one or two Iterable objects. If two are returned, one is used for training and the other for validation. If not provided, you must pass in a Dataset to TorchTrainer.train.
optimizer_creator ((models, dict) -> optimizers) – Constructor function that takes in the return values from model_creator and the passed config, and returns one or more Torch optimizer objects. You do not need to handle GPU/devices in this function; RaySGD will do that for you.
loss_creator (torch.nn.*Loss class | dict -> loss) – A constructor function for the training loss. This can be either a function that takes in the provided config for customization or a subclass of torch.nn.modules.loss._Loss, which covers most PyTorch loss classes. For example, loss_creator=torch.nn.BCELoss. If not provided, you must provide a custom TrainingOperator.
scheduler_creator ((optimizers, dict) -> scheduler) – A constructor function for the torch scheduler. This is a function that takes in the generated optimizers (from optimizer_creator) and the provided config for customization. Be sure to set scheduler_step_freq to increment the scheduler correctly.
serialize_data_creation (bool) – A filelock will be used to prevent race conditions in data downloading among different workers on the same node (using the local file system). Defaults to True.
- Returns
A CreatorOperator class: a subclass of TrainingOperator with a setup method that uses the passed-in creator functions.
property device
The appropriate torch device, at your convenience.
- Type
torch.device
property config
The config dict provided to TorchTrainer.
- Type
dict
property world_rank
The rank of the parent runner. Always 0 if not distributed.
- Type
int
property local_rank
Local rank of parent runner. Always 0 if not distributed.
- Type
int
property use_gpu
Returns True if CUDA is available and use_gpu is True.
property use_fp16
Whether the model and optimizer have been FP16 enabled.
- Type
bool
property use_tqdm
Whether tqdm progress bars are enabled.
- Type
bool
property device_ids
Device IDs for the model.
This is useful for using batch norm with DistributedDataParallel. Not applicable if not using GPU.
- Type
Optional[List[int]]
property scheduler_step_freq
The scheduler_step_freq passed into TorchTrainer. This is useful to determine when to call scheduler.step.
- Type
Optional[str]
CreatorOperator
class ray.util.sgd.torch.training_operator.CreatorOperator(config, world_rank, local_rank, is_distributed=False, device=None, use_gpu=False, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None)
A subclass of TrainingOperator with training defined by creator functions.
This class allows for backwards compatibility with pre Ray 1.0 versions.
This class is returned by TrainingOperator.from_creators(…). If you need to add custom functionality, you should subclass this class, implement the appropriate methods and pass the subclass into TorchTrainer.
MyCreatorOperator = TrainingOperator.from_creators(
    model_creator, optimizer_creator)
trainer = TorchTrainer(
    training_operator_cls=MyCreatorOperator,
    # ... other TorchTrainer arguments
)
property model
First or only model created by the provided model_creator.
property optimizer
First or only optimizer created by the optimizer_creator.
property scheduler
First or only scheduler created by the scheduler_creator.
property criterion
Criterion created by the provided loss_creator.
property models
List of models created by the provided model_creator.
property optimizers
List of optimizers created by the optimizer_creator.
property schedulers
List of schedulers created by the scheduler_creator.
PyTorch Lightning LightningOperator
class ray.util.sgd.torch.lightning_operator.LightningOperator(config, world_rank, local_rank, is_distributed=False, device=None, use_gpu=False, use_fp16=False, use_tqdm=False, wrap_ddp=False, add_dist_sampler=False, scheduler_step_freq=None)
A subclass of TrainingOperator created from a PTL LightningModule.
This class is returned by TrainingOperator.from_ptl, and its training state is defined by the PyTorch Lightning LightningModule that is passed into from_ptl. Training and validation functionality have already been implemented according to PyTorch Lightning's Trainer. If you need to modify training, subclass this class and override the appropriate methods before passing the subclass to TorchTrainer.
MyLightningOperator = TrainingOperator.from_ptl(MyLightningModule)
trainer = TorchTrainer(
    training_operator_cls=MyLightningOperator,
    # ... other TorchTrainer arguments
)
property model
The LightningModule to use for training.
The returned model is wrapped in DDP if using distributed training.
property scheduler_dicts
Returns list of scheduler dictionaries.
List is empty if no schedulers are returned in the configure_optimizers method of your LightningModule.
Default configuration is used if configure_optimizers returns scheduler objects.
See https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers
property optimizers
Returns list of optimizers as returned by configure_optimizers.
property schedulers
Returns list of schedulers as returned by configure_optimizers.
List is empty if no schedulers are returned in configure_optimizers.
BaseTorchTrainable
class ray.util.sgd.torch.BaseTorchTrainable(config=None, logger_creator=None)
Base class for converting TorchTrainer to a Trainable class.
This class is produced when you call TorchTrainer.as_trainable(...). By default, one step of training runs trainer.train() once and trainer.validate() once. You can implement custom iterative training procedures by passing an override_tune_step function to as_trainable:
from ray import tune
from ray.util.sgd.torch import TorchTrainer


def custom_step(trainer, info):
    for i in range(5):
        train_stats = trainer.train()
    validation_stats = trainer.validate()
    train_stats.update(validation_stats)
    return train_stats


# TorchTrainable is a subclass of BaseTorchTrainable.
# MyTrainingOperator is the operator defined in the TorchTrainer example.
TorchTrainable = TorchTrainer.as_trainable(
    training_operator_cls=MyTrainingOperator,
    num_workers=2,
    use_gpu=True,
    override_tune_step=custom_step
)
analysis = tune.run(
    TorchTrainable,
    config={"lr": tune.grid_search([0.01, 0.1])}
)
load_checkpoint(checkpoint_path)
Restores the trainer state.
Override this if you have state external to the Trainer object.
property trainer
An instantiated TorchTrainer object.
Use this when specifying custom training procedures for Tune.
TensorFlow
TFTrainer
class ray.util.sgd.tf.TFTrainer(model_creator, data_creator, config=None, num_replicas=1, num_cpus_per_worker=1, use_gpu=False, verbose=False)
__init__(model_creator, data_creator, config=None, num_replicas=1, num_cpus_per_worker=1, use_gpu=False, verbose=False)
Sets up the TensorFlow trainer.
- Parameters
model_creator (dict -> Model) – This function takes in the config dict and returns a compiled TF model.
data_creator (dict -> tf.Dataset, tf.Dataset) – Creates the training and validation data sets using the config. config dict is passed into the function.
config (dict) – Configuration passed to model_creator and data_creator. Also contains fit_config, which is passed into model.fit(data, **fit_config), and evaluate_config, which is passed into model.evaluate.
num_cpus_per_worker (int) – Sets the CPU requirement for each worker.
num_replicas (int) – Sets number of workers used in distributed training. Workers will be placed arbitrarily across the cluster.
use_gpu (bool) – Enables all workers to use GPU.
verbose (bool) – Prints output of one model if true.
save(checkpoint)
Saves the model at the provided checkpoint.
- Parameters
checkpoint (str) – Path to target checkpoint file.
RaySGD Dataset
Dataset
class ray.util.sgd.data.Dataset(data, batch_size=32, download_func=None, max_concurrency=0, transform=None)
A simple Dataset abstraction for RaySGD.
This dataset is designed to work with RaySGD trainers (currently just Torch) to provide support for streaming large external datasets, and built-in sharding.
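import torch
from ray.util import iter as parallel_it
from ray.util.sgd.data import Dataset
from ray.util.sgd.torch import TorchTrainer


# Minimal creators assumed for this example; any compatible creators work.
def model_creator(config):
    return torch.nn.Linear(1, 1)


def optimizer_creator(model, config):
    return torch.optim.SGD(model.parameters(), lr=1e-2)


def to_mat(x):
    return torch.tensor([[x]]).float()


data = [i * 0.001 for i in range(1000)]
p_iter = parallel_it.from_items(data, num_shards=1, repeat=True)
dataset = Dataset(
    p_iter,
    batch_size=32,
    max_concurrency=1,
    download_func=lambda x: (to_mat(x), to_mat(x)))

trainer = TorchTrainer(
    model_creator=model_creator,
    data_creator=None,
    optimizer_creator=optimizer_creator,
    loss_creator=torch.nn.MSELoss,
    num_workers=5,
)

for i in range(10):
    # Train for another epoch using the dataset
    trainer.train(dataset=dataset, num_steps=200)

model = trainer.get_model()
print("f(0.5)=", float(model(to_mat(0.5))[0][0]))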
- Parameters
data (iterable[U] or ParallelIterator[U]) – Any existing python iterable (or iterator), or an existing parallel iterator to use.
batch_size (int) – The batch size for training/inference (default 32).
download_func (U -> (S, Y)) – A function which returns two values, the input and the label (default is the identity function).
max_concurrency (int) – The maximum number of concurrent calls to the download function. See ParallelIterator.for_each for details.
transform (S -> X) – A final transformation to be applied to the input only. This is guaranteed to run on the same worker that training will occur on.
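The Dataset data flow can be traced with a single-process sketch: download_func produces an (input, label) pair, transform is applied to the input only, and the results are grouped into batches. iterate_batches is hypothetical. The real Dataset shards and streams this work across Ray actors, and it may handle a trailing partial batch differently from this sketch.

```python
def _identity(item):
    return item


def iterate_batches(data, batch_size=32, download_func=_identity, transform=None):
    """Single-process sketch of the Dataset data flow (hypothetical helper).

    With the identity download_func, each item must already be an (input, label) pair.
    """
    batch = []
    for item in data:
        x, y = download_func(item)  # (input, label)
        if transform is not None:
            x = transform(x)  # applied to the input only
        batch.append((x, y))
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # this sketch keeps the trailing partial batch


batches = list(iterate_batches(
    range(5),
    batch_size=2,
    download_func=lambda v: (v, 2 * v),
    transform=lambda x: x + 100))
print(batches)
```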
RaySGD Utils
Utils
class ray.util.sgd.utils.AverageMeter
Utility for computing and storing the average and most recent value.
Example
>>> meter = AverageMeter()
>>> meter.update(5)
>>> meter.val, meter.avg, meter.sum
(5, 5.0, 5)
>>> meter.update(10, n=4)
>>> meter.val, meter.avg, meter.sum
(10, 9.0, 45)
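The doctest implies that update(val, n) counts val as n samples. The following is a minimal re-implementation consistent with that doctest. SimpleAverageMeter is a hypothetical name used for illustration, and the real class may differ internally.

```python
class SimpleAverageMeter:
    """Hypothetical sketch consistent with the AverageMeter doctest."""

    def __init__(self):
        self.val = 0    # most recent value
        self.sum = 0    # sum of val * n over all updates
        self.count = 0  # total number of samples seen
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n  # val is weighted as n samples
        self.count += n
        self.avg = self.sum / self.count


meter = SimpleAverageMeter()
meter.update(5)
meter.update(10, n=4)
print(meter.val, meter.avg, meter.sum)
```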
class ray.util.sgd.utils.AverageMeterCollection
A grouping of AverageMeters.
This utility is used in TrainingOperator.train_epoch and TrainingOperator.validate to collect averages and most recent value across all batches. One AverageMeter object is used for each metric.
Example
>>> meter_collection = AverageMeterCollection()
>>> meter_collection.update({"loss": 0.5, "acc": 0.5}, n=32)
>>> meter_collection.summary()
{'batch_count': 1, 'num_samples': 32, 'loss': 0.5, 'last_loss': 0.5, 'acc': 0.5, 'last_acc': 0.5}
>>> meter_collection.update({"loss": 0.1, "acc": 0.9}, n=32)
>>> meter_collection.summary()
{'batch_count': 2, 'num_samples': 64, 'loss': 0.3, 'last_loss': 0.1, 'acc': 0.7, 'last_acc': 0.9}
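The following sketch is consistent with the AverageMeterCollection doctest. Each metric keeps a sample-weighted running sum, and summary() reports the weighted average plus a last_<metric> entry, alongside batch_count and num_samples. SimpleMeterCollection is hypothetical, not the library class.

```python
class SimpleMeterCollection:
    """Hypothetical sketch consistent with the AverageMeterCollection doctest."""

    def __init__(self):
        self._batch_count = 0
        self._num_samples = 0
        self._sums = {}    # metric -> sum of value * n
        self._counts = {}  # metric -> samples seen for that metric
        self._last = {}    # metric -> most recent value

    def update(self, metrics, n=1):
        self._batch_count += 1
        self._num_samples += n
        for key, value in metrics.items():
            self._sums[key] = self._sums.get(key, 0) + value * n
            self._counts[key] = self._counts.get(key, 0) + n
            self._last[key] = value

    def summary(self):
        stats = {"batch_count": self._batch_count, "num_samples": self._num_samples}
        for key, total in self._sums.items():
            stats[key] = total / self._counts[key]  # sample-weighted average
            stats["last_" + key] = self._last[key]
        return stats


collection = SimpleMeterCollection()
collection.update({"loss": 0.5, "acc": 0.5}, n=32)
collection.update({"loss": 0.1, "acc": 0.9}, n=32)
print(collection.summary())
```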