Ray Train API#
This page covers framework-specific integrations with Ray Train and the Ray Train Developer APIs.
For core Ray AIR APIs, take a look at the AIR package reference.
Ray Train Base Classes (Developer APIs)#
Trainer Base Classes#
BaseTrainer | Defines interface for distributed training on Ray.
DataParallelTrainer | A Trainer for data parallel training.
GBDTTrainer | Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.
BaseTrainer API#
BaseTrainer.fit | Runs training.
BaseTrainer.setup | Called during fit() to perform initial setup on the Trainer.
BaseTrainer.preprocess_datasets | Called during fit() to preprocess dataset attributes with preprocessor.
BaseTrainer.training_loop | Loop called by fit() to run training and report results to Tune.
BaseTrainer.as_trainable | Convert self to a Tune Trainable class.
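A custom trainer can subclass BaseTrainer and implement training_loop, reporting results with session.report. A minimal sketch (the dummy loss values stand in for real training logic):

    from ray.air import session
    from ray.train.trainer import BaseTrainer

    class MyTrainer(BaseTrainer):
        def training_loop(self) -> None:
            # Called by fit(); report intermediate results to Tune.
            for epoch in range(3):
                session.report({"epoch": epoch, "loss": 1.0 / (epoch + 1)})

    trainer = MyTrainer()
    result = trainer.fit()
    print(result.metrics["loss"])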
Train Backend Base Classes#
Backend | Singleton for distributed communication backend.
BackendConfig | Parent class for configurations of training backend.
Ray Train Integrations#
PyTorch#
TorchTrainer | A Trainer for data parallel PyTorch training.
TorchConfig | Configuration for torch process group setup.
TorchCheckpoint | A Checkpoint with Torch-specific functionality.
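A minimal sketch of the TorchTrainer workflow (the random tensors stand in for a real dataset):

    import torch
    import torch.nn as nn

    import ray.train.torch
    from ray.air import session, ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker():
        # Runs on each distributed training worker.
        model = ray.train.torch.prepare_model(nn.Linear(4, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        X, y = torch.randn(32, 4), torch.randn(32, 1)
        for epoch in range(3):
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(X), y)
            loss.backward()
            optimizer.step()
            session.report({"epoch": epoch, "loss": loss.item()})

    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()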
PyTorch Training Loop Utilities#
ray.train.torch.prepare_model | Prepares the model for distributed execution.
ray.train.torch.prepare_optimizer | Wraps optimizer to support automatic mixed precision.
ray.train.torch.prepare_data_loader | Prepares DataLoader for distributed execution.
ray.train.torch.get_device | Gets the correct torch device configured for this process.
ray.train.torch.accelerate | Enables training optimizations.
ray.train.torch.backward | Computes the gradient of the specified tensor w.r.t. graph leaves.
ray.train.torch.enable_reproducibility | Limits sources of nondeterministic behavior.
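These utilities are meant to be called inside train_loop_per_worker. A sketch combining prepare_data_loader and prepare_model (the toy TensorDataset is illustrative):

    import torch
    import ray.train.torch
    from ray.air import session

    def train_loop_per_worker():
        dataset = torch.utils.data.TensorDataset(
            torch.randn(128, 4), torch.randn(128, 1)
        )
        # Adds a DistributedSampler and moves each batch to the right device.
        loader = ray.train.torch.prepare_data_loader(
            torch.utils.data.DataLoader(dataset, batch_size=32)
        )
        # Moves the model to the device and wraps it in DistributedDataParallel.
        model = ray.train.torch.prepare_model(torch.nn.Linear(4, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        for X, y in loader:
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(X), y)
            loss.backward()
            optimizer.step()
        session.report({"loss": loss.item()})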
PyTorch Lightning#
LightningTrainer | A Trainer for data parallel PyTorch Lightning training.
LightningConfigBuilder | Configuration Class to pass into LightningTrainer.
LightningCheckpoint | A Checkpoint with Lightning-specific functionality.
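A sketch of the LightningTrainer workflow; MyLightningModule (a user-defined pl.LightningModule whose constructor is assumed to take lr) and train_loader are assumptions here:

    from ray.air.config import ScalingConfig
    from ray.train.lightning import LightningTrainer, LightningConfigBuilder

    lightning_config = (
        LightningConfigBuilder()
        .module(cls=MyLightningModule, lr=1e-3)      # assumed user-defined module
        .trainer(max_epochs=2, accelerator="cpu")    # kwargs forwarded to pl.Trainer
        .fit_params(train_dataloaders=train_loader)  # assumed user-defined dataloader
        .build()
    )
    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()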
Tensorflow/Keras#
TensorflowTrainer | A Trainer for data parallel Tensorflow training.
TensorflowConfig | Configuration for TensorFlow setup (PublicAPI (beta): this API is in beta and may change before becoming stable).
TensorflowCheckpoint | A Checkpoint with TensorFlow-specific functionality.
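A minimal sketch of the TensorflowTrainer workflow (random NumPy arrays stand in for a real dataset):

    import numpy as np

    from ray.air import session, ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_loop_per_worker():
        import tensorflow as tf

        # Build and compile the model inside the strategy scope so that
        # variables are mirrored across the distributed workers.
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        with strategy.scope():
            model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
            model.compile(optimizer="sgd", loss="mse")

        history = model.fit(
            np.random.rand(32, 4), np.random.rand(32, 1), epochs=2, verbose=0
        )
        session.report({"loss": history.history["loss"][-1]})

    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()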
Tensorflow/Keras Training Loop Utilities#
ray.train.tensorflow.prepare_dataset_shard | A utility function that overrides default config for Tensorflow Dataset.
ReportCheckpointCallback | Keras callback for Ray AIR reporting and checkpointing.
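Inside a Tensorflow training loop, the callback can be passed directly to model.fit (a sketch; model, train_x, and train_y are assumed to already exist):

    from ray.air.integrations.keras import ReportCheckpointCallback

    # Reports metrics and saves a checkpoint to Ray AIR after each epoch.
    model.fit(
        train_x,
        train_y,
        epochs=2,
        callbacks=[ReportCheckpointCallback()],
    )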
Horovod#
HorovodTrainer | A Trainer for data parallel Horovod training.
HorovodConfig | Configurations for Horovod setup.
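A sketch of the HorovodTrainer workflow with a PyTorch model (random tensors stand in for a real dataset):

    import horovod.torch as hvd
    import torch

    from ray.air import session, ScalingConfig
    from ray.train.horovod import HorovodTrainer

    def train_loop_per_worker():
        hvd.init()
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        # Average gradients across workers on every step.
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters()
        )
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        X, y = torch.randn(32, 4), torch.randn(32, 1)
        for _ in range(3):
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(X), y)
            loss.backward()
            optimizer.step()
        session.report({"loss": loss.item()})

    trainer = HorovodTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()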
XGBoost#
XGBoostTrainer | A Trainer for data parallel XGBoost training.
XGBoostCheckpoint | A Checkpoint with XGBoost-specific functionality.
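A minimal sketch of the XGBoostTrainer workflow (the synthetic dataset is illustrative):

    import ray
    from ray.air.config import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = XGBoostTrainer(
        label_column="y",
        params={"objective": "reg:squarederror"},
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_dataset},
    )
    result = trainer.fit()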
LightGBM#
LightGBMTrainer | A Trainer for data parallel LightGBM training.
LightGBMCheckpoint | A Checkpoint with LightGBM-specific functionality.
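LightGBMTrainer follows the same pattern, with LightGBM training parameters (a sketch with a synthetic dataset):

    import ray
    from ray.air.config import ScalingConfig
    from ray.train.lightgbm import LightGBMTrainer

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = LightGBMTrainer(
        label_column="y",
        params={"objective": "regression"},
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": train_dataset},
    )
    result = trainer.fit()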
HuggingFace#
HuggingFaceTrainer | A Trainer for data parallel HuggingFace Transformers on PyTorch training.
HuggingFaceCheckpoint | A Checkpoint with HuggingFace-specific functionality.
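A sketch of the HuggingFaceTrainer workflow; the body of trainer_init_per_worker and the Ray Dataset train_ds are left to the user:

    from ray.air.config import ScalingConfig
    from ray.train.huggingface import HuggingFaceTrainer

    def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
        # Build and return a transformers.Trainer here, e.g. from a model,
        # TrainingArguments, and the per-worker dataset shards.
        ...

    trainer = HuggingFaceTrainer(
        trainer_init_per_worker=trainer_init_per_worker,
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
        datasets={"train": train_ds},  # train_ds: a Ray Dataset (assumed)
    )
    result = trainer.fit()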
Scikit-Learn#
SklearnTrainer | A Trainer for scikit-learn estimator training.
SklearnCheckpoint | A Checkpoint with sklearn-specific functionality.
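A minimal sketch of the SklearnTrainer workflow (the synthetic dataset is illustrative; sklearn training runs on a single actor, so resources are set via trainer_resources rather than num_workers):

    import ray
    from ray.air.config import ScalingConfig
    from ray.train.sklearn import SklearnTrainer
    from sklearn.ensemble import RandomForestRegressor

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = SklearnTrainer(
        estimator=RandomForestRegressor(),
        label_column="y",
        datasets={"train": train_dataset},
        scaling_config=ScalingConfig(trainer_resources={"CPU": 2}),
    )
    result = trainer.fit()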
Mosaic#
MosaicTrainer | A Trainer for data parallel Mosaic Composers on PyTorch training.
Reinforcement Learning (RLlib)#
RLTrainer | Reinforcement learning trainer.
RLCheckpoint | A Checkpoint with RLlib-specific functionality.
Ray Train Experiment Restoration#
BaseTrainer.restore | Restores a Train experiment from a previously interrupted/failed run.
Note
All trainer classes have a restore method that takes in a path pointing to the directory of the experiment to be restored. restore also exposes a subset of constructor arguments that can be re-specified. See Restoration API for Built-in Trainers below for details on restore arguments for different AIR trainer integrations.
Restoration API for Built-in Trainers#
DataParallelTrainer.restore | Restores a DataParallelTrainer from a previously interrupted/failed run.
HuggingFaceTrainer.restore | Restores a HuggingFaceTrainer from a previously interrupted/failed run.
Note
TorchTrainer.restore, TensorflowTrainer.restore, and HorovodTrainer.restore can take in the same parameters as their parent class's DataParallelTrainer.restore. Unless otherwise specified, other trainers will accept the same parameters as BaseTrainer.restore.
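For example (the experiment directory path is hypothetical, and train_loop_per_worker is assumed to be defined elsewhere):

    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer.restore(
        "~/ray_results/torch_experiment",  # hypothetical experiment directory
        # Optionally re-specify a subset of the original constructor arguments:
        train_loop_per_worker=train_loop_per_worker,
    )
    result = trainer.fit()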
See also
See How do I restore a Ray Train experiment? for more details on when and how trainer restore should be used.