Ray Train API#
Important
These API references are for the revamped Ray Train V2 implementation, which is available starting from Ray 2.43 and is enabled by setting the environment variable RAY_TRAIN_V2_ENABLED=1. The APIs below assume that this environment variable is set.
See the Ray Train V1 API for the old API references and the Ray Train V2 Migration Guide for instructions on migrating.
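The flag can be exported in the shell or set from Python before `ray.train` is imported; a minimal sketch (assuming the variable is not already set in the environment):

```python
import os

# Opt in to Ray Train V2 before any ray.train import.
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

import ray.train  # noqa: E402  (import after setting the flag)
```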
PyTorch Ecosystem#
| API | Description |
| --- | --- |
| `TorchTrainer` | A Trainer for data parallel PyTorch training. |
| `TorchConfig` | Configuration for torch process group setup. |
| `TorchXLAConfig` | Configuration for torch XLA setup. |
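For orientation, a minimal sketch of launching a `TorchTrainer`; the training-function body is a placeholder and the worker counts are illustrative:

```python
import ray.train
import ray.train.torch


def train_func():
    # Per-worker training loop goes here; report metrics back to Ray Train.
    ray.train.report({"loss": 0.0})


trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()
print(result.metrics)
```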
PyTorch#
| API | Description |
| --- | --- |
| `prepare_model` | Prepares the model for distributed execution. |
| `prepare_data_loader` | Prepares `DataLoader` for distributed execution. |
| `enable_reproducibility` | Limits sources of nondeterministic behavior. |
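These utilities are called from inside the training function. A hedged sketch with a toy model and synthetic data; the model, data, and hyperparameters are placeholders:

```python
import torch
import ray.train.torch


def train_func():
    # Optionally limit sources of nondeterminism (seeds, cuDNN settings).
    ray.train.torch.enable_reproducibility(seed=0)

    model = torch.nn.Linear(4, 1)
    # Wraps the model (e.g. in DistributedDataParallel) and moves it to the right device.
    model = ray.train.torch.prepare_model(model)

    dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    loader = torch.utils.data.DataLoader(dataset, batch_size=16)
    # Adds a DistributedSampler and per-worker device transfer.
    loader = ray.train.torch.prepare_data_loader(loader)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for x, y in loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
```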
PyTorch Lightning#
| API | Description |
| --- | --- |
| `prepare_trainer` | Prepare the PyTorch Lightning Trainer for distributed execution. |
| `RayLightningEnvironment` | Setup Lightning DDP training environment for Ray cluster. |
| `RayDDPStrategy` | Subclass of DDPStrategy to ensure compatibility with Ray orchestration. |
| `RayFSDPStrategy` | Subclass of FSDPStrategy to ensure compatibility with Ray orchestration. |
| `RayDeepSpeedStrategy` | Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration. |
| `RayTrainReportCallback` | A simple callback that reports checkpoints to Ray on train epoch end. |
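A hedged sketch of wiring these pieces into a Lightning `Trainer` inside the training function; `MyLightningModule` and `train_loader` are hypothetical placeholders defined elsewhere:

```python
import lightning.pytorch as pl
import ray.train.lightning


def train_func():
    trainer = pl.Trainer(
        accelerator="auto",
        devices="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
    )
    # Adjusts the Trainer so it cooperates with Ray Train's worker setup.
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(MyLightningModule(), train_dataloaders=train_loader)  # placeholders
```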
Hugging Face Transformers#
| API | Description |
| --- | --- |
| `prepare_trainer` | Prepare your HuggingFace Transformer Trainer for Ray Train. |
| `RayTrainReportCallback` | A simple callback to report checkpoints and metrics to Ray Train. |
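A hedged sketch with a Hugging Face `transformers.Trainer`; the model, dataset, and `TrainingArguments` values are placeholders:

```python
import transformers
import ray.train.huggingface.transformers


def train_func():
    args = transformers.TrainingArguments(output_dir="out", max_steps=10)
    trainer = transformers.Trainer(
        model=model,              # placeholder model
        args=args,
        train_dataset=train_ds,   # placeholder dataset
    )
    # Report checkpoints and metrics to Ray Train on the Trainer's save/eval events.
    trainer.add_callback(ray.train.huggingface.transformers.RayTrainReportCallback())
    # Adjust the Trainer for the Ray Train distributed environment.
    trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
    trainer.train()
```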
More Frameworks#
TensorFlow/Keras#
| API | Description |
| --- | --- |
| `TensorflowTrainer` | A Trainer for data parallel TensorFlow training. |
| `TensorflowConfig` | PublicAPI (beta): This API is in beta and may change before becoming stable. |
| `prepare_dataset_shard` | A utility function that overrides default config for TensorFlow Dataset. |
| `ReportCheckpointCallback` | Keras callback for Ray Train reporting and checkpointing. |
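A hedged sketch combining `TensorflowTrainer` with the Keras `ReportCheckpointCallback`; the model and data are illustrative:

```python
import numpy as np
import tensorflow as tf

import ray.train
from ray.train.tensorflow import TensorflowTrainer
from ray.train.tensorflow.keras import ReportCheckpointCallback


def train_func():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
        model.compile(optimizer="sgd", loss="mse")

    x, y = np.random.rand(64, 4), np.random.rand(64, 1)
    # Report Keras metrics (and optionally checkpoints) to Ray Train each epoch.
    model.fit(x, y, epochs=2, callbacks=[ReportCheckpointCallback()])


trainer = TensorflowTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
trainer.fit()
```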
XGBoost#
| API | Description |
| --- | --- |
| `XGBoostTrainer` | A Trainer for distributed data-parallel XGBoost training. |
| `RayTrainReportCallback` | XGBoost callback to save checkpoints and report metrics. |
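A hedged sketch of `XGBoostTrainer` in the V2 style, where each worker runs a training function; the data here is synthetic, and a real job would typically read a per-worker shard via `ray.train.get_dataset_shard`:

```python
import numpy as np
import xgboost

import ray.train
from ray.train.xgboost import XGBoostTrainer, RayTrainReportCallback


def train_func():
    # Synthetic per-worker data; replace with ray.train.get_dataset_shard("train").
    x, y = np.random.rand(128, 4), np.random.rand(128)
    dtrain = xgboost.DMatrix(x, label=y)

    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=10,
        # Saves checkpoints and reports metrics back to Ray Train.
        callbacks=[RayTrainReportCallback()],
    )


trainer = XGBoostTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
trainer.fit()
```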
LightGBM#
| API | Description |
| --- | --- |
| `LightGBMTrainer` | A Trainer for distributed data-parallel LightGBM training. |
| `get_network_params` | Returns the network parameters to enable LightGBM distributed training. |
| `RayTrainReportCallback` | Creates a callback that reports metrics and checkpoints the model. |
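A similar hedged sketch for LightGBM; the exact name of the network-parameter helper (`get_network_params` below) is assumed from the entry above, and the data is synthetic:

```python
import numpy as np
import lightgbm

import ray.train
import ray.train.lightgbm
from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback


def train_func():
    x, y = np.random.rand(128, 4), np.random.rand(128)
    train_set = lightgbm.Dataset(x, label=y)

    params = {"objective": "regression"}
    # Merge in the machine/port settings Ray Train derives for distributed training.
    params.update(ray.train.lightgbm.get_network_params())  # assumed helper name

    lightgbm.train(
        params,
        train_set,
        num_boost_round=10,
        callbacks=[RayTrainReportCallback()],
    )


trainer = LightGBMTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
trainer.fit()
```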
JAX#
| API | Description |
| --- | --- |
| `JaxTrainer` | A Trainer for Single-Program Multi-Data (SPMD) JAX training. |
Ray Train Configuration#
| API | Description |
| --- | --- |
| `CheckpointConfig` | Configurable parameters for defining the checkpointing strategy. |
| `DataConfig` | Class responsible for configuring Train dataset preprocessing. |
| `FailureConfig` | Configuration related to failure handling of each training run. |
| `RunConfig` | Runtime configuration for training runs. |
| `ScalingConfig` | Configuration for scaling training. |
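These configuration objects are passed to a trainer. A minimal sketch with illustrative values; the storage path and run name are placeholders:

```python
import ray.train

scaling_config = ray.train.ScalingConfig(num_workers=4, use_gpu=True)

run_config = ray.train.RunConfig(
    name="my-train-run",              # placeholder run name
    storage_path="/tmp/ray_results",  # placeholder storage location
    failure_config=ray.train.FailureConfig(max_failures=2),
    checkpoint_config=ray.train.CheckpointConfig(num_to_keep=2),
)

# DataConfig controls how Ray Datasets are split across workers, e.g.:
# dataset_config=ray.train.DataConfig(datasets_to_split=["train"])
```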
Ray Train Utilities#
Classes
| API | Description |
| --- | --- |
| `Checkpoint` | A reference to data persisted as a directory in local or remote storage. |
| `CheckpointUploadMode` | The manner in which we want to upload the checkpoint. |
| `TrainContext` | Abstract interface for training context. |
Functions
| API | Description |
| --- | --- |
| `get_all_reported_checkpoints` | Get all the reported checkpoints so far. |
| `get_checkpoint` | Access the latest reported checkpoint to resume from if one exists. |
| `get_context` | Get or create a singleton training context. |
| `get_dataset_shard` | Returns the `ray.data.DataIterator` shard for this worker. |
| `report` | Report metrics and optionally save a checkpoint. |
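A hedged sketch of the checkpoint/report loop inside a training function, combining `get_context`, `get_checkpoint`, `report`, and `Checkpoint.from_directory`; the file layout inside the checkpoint directory is an arbitrary choice for the example:

```python
import os
import tempfile

import ray.train


def train_func():
    ctx = ray.train.get_context()
    start_epoch = 0

    # Resume from the latest reported checkpoint if the run was restarted.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            with open(os.path.join(ckpt_dir, "epoch.txt")) as f:
                start_epoch = int(f.read())

    for epoch in range(start_epoch, 3):
        loss = 1.0 / (epoch + 1)  # placeholder metric
        with tempfile.TemporaryDirectory() as tmp:
            with open(os.path.join(tmp, "epoch.txt"), "w") as f:
                f.write(str(epoch + 1))
            # Every worker calls report; the checkpoint is persisted by Ray Train.
            ray.train.report(
                {"loss": loss, "world_rank": ctx.get_world_rank()},
                checkpoint=ray.train.Checkpoint.from_directory(tmp),
            )
```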
Collective
| API | Description |
| --- | --- |
| `barrier` | Create a barrier across all workers. |
| `broadcast_from_rank_zero` | Broadcast small (<1kb) data from the rank 0 worker to all other workers. |
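A hedged sketch of the collective utilities, assuming they are importable from `ray.train.collective` as the entries above suggest:

```python
import ray.train
import ray.train.collective  # assumed module path for the collective utilities


def train_func():
    rank = ray.train.get_context().get_world_rank()

    # Rank 0 decides a small piece of shared state; all other workers receive it.
    config = {"lr": 1e-3} if rank == 0 else None
    config = ray.train.collective.broadcast_from_rank_zero(config)

    # Block until every worker reaches this point.
    ray.train.collective.barrier()
```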
Ray Train Output#
| API | Description |
| --- | --- |
| `ReportedCheckpoint` | A user-reported checkpoint and its associated metrics. |
Ray Train Errors#
| API | Description |
| --- | --- |
| `ControllerError` | Exception raised when training fails due to a controller error. |
| `WorkerGroupError` | Exception raised from the worker group during training. |
| `TrainingFailedError` | Exception raised when training fails from a controller or worker group error. |
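A hedged sketch of handling these failures around `trainer.fit()`. The import location of the exception classes is an assumption here (shown as top-level `ray.train` attributes); check the module path for your Ray version:

```python
import ray.train
import ray.train.torch

trainer = ray.train.torch.TorchTrainer(
    lambda: None,  # placeholder training function
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)

try:
    trainer.fit()
except ray.train.WorkerGroupError as exc:   # assumed import location
    # One or more training workers raised; inspect the per-worker errors.
    print(f"worker group failure: {exc}")
except ray.train.ControllerError as exc:    # assumed import location
    # The Train controller itself failed.
    print(f"controller failure: {exc}")
```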
Ray Tune Integration Utilities#
| API | Description |
| --- | --- |
| `TuneReportCallback` | Propagate metrics and checkpoint paths from Ray Train workers to Ray Tune. |
Ray Train Developer APIs#
Trainer Base Class#
| API | Description |
| --- | --- |
| `DataParallelTrainer` | Base class for distributed data parallel training on Ray. |
Train Backend Base Classes#
| API | Description |
| --- | --- |
| `Backend` | Singleton for distributed communication backend. |
| `BackendConfig` | Parent class for configurations of training backend. |
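A hedged sketch of the developer-API subclassing pattern for a custom communication backend; `MyCommLibBackend` and `MyCommLibConfig` are hypothetical names, and the hook signatures follow the existing `Backend` interface (`on_start`/`on_shutdown`), which may differ slightly across Ray versions:

```python
from dataclasses import dataclass

from ray.train.backend import Backend, BackendConfig


class MyCommLibBackend(Backend):
    def on_start(self, worker_group, backend_config):
        # Initialize the communication library on each worker at startup.
        pass

    def on_shutdown(self, worker_group, backend_config):
        # Tear down any worker-side state.
        pass


@dataclass
class MyCommLibConfig(BackendConfig):
    timeout_s: int = 30  # hypothetical backend option

    @property
    def backend_cls(self):
        return MyCommLibBackend
```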
Trainer Callbacks#
| API | Description |
| --- | --- |
| `UserCallback` | Callback interface for custom user-defined callbacks to handle events during training. |