Ray Train API
This page covers framework-specific integrations with Ray Train and the Ray Train Developer APIs.
For core Ray AIR APIs, take a look at the AIR package reference.
Ray Train Base Classes (Developer APIs)
Trainer Base Classes
Defines interface for distributed training on Ray.
A Trainer for data parallel training.
Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.
BaseTrainer
Methods
Runs training.
Called during fit() to perform initial setup on the Trainer.
Called during fit() to preprocess dataset attributes with preprocessor.
Loop called by fit() to run training and report results to Tune.
Convert self to a |
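The method descriptions above suggest a lifecycle inside fit(): initial setup, then dataset preprocessing, then the training loop that reports results. The following is a minimal pure-Python sketch of that lifecycle. It does not use Ray. The names `ToyTrainer`, `MeanTrainer`, `report`, `setup`, `preprocess_datasets`, and `training_loop` are illustrative assumptions chosen to mirror the descriptions, and the call order shown is assumed rather than stated on this page.

```python
from abc import ABC, abstractmethod


class ToyTrainer(ABC):
    """Illustrative stand-in for a trainer base class; NOT Ray's implementation."""

    def __init__(self, datasets=None, preprocessor=None):
        self.datasets = dict(datasets or {})
        # Hypothetical preprocessor: any callable applied to each dataset.
        self.preprocessor = preprocessor
        # Stands in for results that a real trainer would report to Tune.
        self.reported = []

    def setup(self):
        """Hook for one-time setup; a no-op by default."""

    def preprocess_datasets(self):
        """Apply the preprocessor, if any, to every dataset."""
        if self.preprocessor is not None:
            self.datasets = {
                name: self.preprocessor(ds) for name, ds in self.datasets.items()
            }

    @abstractmethod
    def training_loop(self):
        """Subclasses run training here and call report() with metrics."""

    def report(self, metrics):
        self.reported.append(dict(metrics))

    def fit(self):
        # Assumed order: setup, then preprocessing, then the training loop.
        self.setup()
        self.preprocess_datasets()
        self.training_loop()
        return self.reported[-1] if self.reported else None


class MeanTrainer(ToyTrainer):
    """Toy subclass whose 'training' reports the mean of the train dataset."""

    def training_loop(self):
        data = self.datasets["train"]
        for epoch in range(2):
            self.report({"epoch": epoch, "mean": sum(data) / len(data)})


trainer = MeanTrainer(
    datasets={"train": [1, 2, 3]},
    preprocessor=lambda ds: [x * 2 for x in ds],
)
result = trainer.fit()
print(result)
```

In this sketch a subclass only overrides the training loop, while the base class owns the order of the hooks. That split is the design choice the descriptions above point to.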
Train Backend Base Classes
Singleton for distributed communication backend.
Parent class for configurations of training backend.
Ray Train Integrations
PyTorch
A Trainer for data parallel PyTorch training.
Configuration for torch process group setup.
A |
PyTorch Training Loop Utilities
Prepares the model for distributed execution.
Wraps optimizer to support automatic mixed precision.
Prepares DataLoader for distributed execution.
Gets the correct torch device to use for training.
Enables training optimizations.
Computes the gradient of the specified tensor w.r.t.
Limits sources of nondeterministic behavior.
Tensorflow/Keras
A Trainer for data parallel Tensorflow training.
PublicAPI (beta): This API is in beta and may change before becoming stable.
A |
Tensorflow/Keras Training Loop Utilities
A utility function that overrides default config for Tensorflow Dataset.
Keras callback for Ray AIR reporting and checkpointing.
Horovod
A Trainer for data parallel Horovod training.
Configurations for Horovod setup.
XGBoost
A Trainer for data parallel XGBoost training.
A |
LightGBM
A Trainer for data parallel LightGBM training.
A |
HuggingFace
A Trainer for data parallel HuggingFace Transformers on PyTorch training.
A |
Scikit-Learn
A Trainer for scikit-learn estimator training.
A |
Mosaic
A Trainer for data parallel Mosaic Composers on PyTorch training.
Reinforcement Learning (RLlib)
Reinforcement learning trainer.
A |