Ray Train API#

Important

These API references cover the revamped Ray Train V2 implementation, available starting from Ray 2.43 by setting the environment variable RAY_TRAIN_V2_ENABLED=1. The APIs below assume that this environment variable is set.

See Ray Train V1 API for the old API references.

PyTorch Ecosystem#

TorchTrainer

A Trainer for data parallel PyTorch training.

TorchConfig

Configuration for torch process group setup.

TorchXLAConfig

Configuration for torch XLA setup.
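
A minimal sketch of wiring these together, where train_func is a per-worker training function you define yourself:

    import ray.train
    import ray.train.torch

    def train_func(config):
        ...  # per-worker PyTorch training logic

    trainer = ray.train.torch.TorchTrainer(
        train_func,
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=False),
    )
    result = trainer.fit()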

PyTorch#

get_device

Gets the correct torch device configured for this process.

get_devices

Gets the list of torch devices configured for this process.

prepare_model

Prepares the model for distributed execution.

prepare_data_loader

Prepares DataLoader for distributed execution.

enable_reproducibility

Limits sources of nondeterministic behavior.
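
A sketch of how these utilities typically compose inside a training function; the model and dataset are hypothetical stand-ins:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import ray.train.torch

    def train_func():
        # Hypothetical toy model and dataset.
        model = torch.nn.Linear(4, 1)
        loader = DataLoader(
            TensorDataset(torch.randn(128, 4), torch.randn(128, 1)),
            batch_size=32,
        )

        # Wrap the model for distributed execution and move it to the right device.
        model = ray.train.torch.prepare_model(model)
        # Add a distributed sampler and device transfer to the DataLoader.
        loader = ray.train.torch.prepare_data_loader(loader)

        device = ray.train.torch.get_device()
        ...  # training loop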

PyTorch Lightning#

prepare_trainer

Prepare the PyTorch Lightning Trainer for distributed execution.

RayLightningEnvironment

Sets up the Lightning DDP training environment for a Ray cluster.

RayDDPStrategy

Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

RayFSDPStrategy

Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

RayDeepSpeedStrategy

Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

RayTrainReportCallback

A simple callback that reports checkpoints to Ray on train epoch end.
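
A sketch of the usual Lightning wiring; MyLightningModule is a hypothetical stand-in for your own LightningModule:

    import pytorch_lightning as pl

    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
        prepare_trainer,
    )

    def train_func():
        model = MyLightningModule()  # hypothetical LightningModule
        trainer = pl.Trainer(
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(model)  # dataloaders defined on the module in this sketch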

Hugging Face Transformers#

prepare_trainer

Prepare your Hugging Face Transformers Trainer for Ray Train.

RayTrainReportCallback

A simple callback to report checkpoints and metrics to Ray Train.
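
A sketch of the Transformers integration, assuming model and training_args are built the usual Hugging Face way:

    import transformers

    from ray.train.huggingface.transformers import (
        RayTrainReportCallback,
        prepare_trainer,
    )

    def train_func():
        ...  # build model, tokenizer, datasets, and training_args
        trainer = transformers.Trainer(model=model, args=training_args)
        trainer.add_callback(RayTrainReportCallback())
        trainer = prepare_trainer(trainer)
        trainer.train()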

More Frameworks#

TensorFlow/Keras#

TensorflowTrainer

A Trainer for data parallel TensorFlow training.

TensorflowConfig

Configuration for TensorFlow distributed training setup. PublicAPI (beta): this API is in beta and may change before becoming stable.

prepare_dataset_shard

A utility function that overrides the default config for a TensorFlow Dataset.

ReportCheckpointCallback

Keras callback for Ray Train reporting and checkpointing.
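
A sketch of the Keras pattern, where build_model is a hypothetical helper of your own that constructs and compiles a Keras model:

    import tensorflow as tf

    import ray.train
    from ray.train.tensorflow import TensorflowTrainer
    from ray.train.tensorflow.keras import ReportCheckpointCallback

    def train_func():
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        with strategy.scope():
            model = build_model()  # hypothetical helper
        dataset = ...  # per-worker tf.data pipeline
        model.fit(dataset, callbacks=[ReportCheckpointCallback()])

    trainer = TensorflowTrainer(
        train_func,
        scaling_config=ray.train.ScalingConfig(num_workers=2),
    )
    result = trainer.fit()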

XGBoost#

XGBoostTrainer

A Trainer for distributed data-parallel XGBoost training.

RayTrainReportCallback

XGBoost callback to save checkpoints and report metrics.
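
A sketch of a per-worker training loop, assuming a Ray Dataset named "train" with a "target" label column:

    import pandas as pd
    import xgboost

    import ray.train
    from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer

    def train_func():
        shard = ray.train.get_dataset_shard("train")
        df = pd.concat(shard.iter_batches(batch_format="pandas"))
        dtrain = xgboost.DMatrix(df.drop(columns=["target"]), label=df["target"])
        xgboost.train(
            {"objective": "reg:squarederror"},
            dtrain,
            num_boost_round=10,
            callbacks=[RayTrainReportCallback()],
        )

    trainer = XGBoostTrainer(
        train_func,
        scaling_config=ray.train.ScalingConfig(num_workers=2),
        datasets={"train": train_dataset},  # a ray.data.Dataset you provide
    )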

LightGBM#

LightGBMTrainer

A Trainer for distributed data-parallel LightGBM training.

RayTrainReportCallback

Creates a callback that reports metrics and checkpoints the model.
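
The LightGBM integration mirrors the XGBoost sketch above; a minimal per-worker loop might look like:

    import lightgbm

    from ray.train.lightgbm import RayTrainReportCallback

    def train_func():
        ...  # build train_set as a lightgbm.Dataset from your data shard
        lightgbm.train(
            {"objective": "regression"},
            train_set,
            num_boost_round=10,
            callbacks=[RayTrainReportCallback()],
        )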

Ray Train Configuration#

CheckpointConfig

Configurable parameters for defining the checkpointing strategy.

DataConfig

Class responsible for configuring Train dataset preprocessing.

FailureConfig

Configuration related to failure handling of each training run.

RunConfig

Runtime configuration for training runs.

ScalingConfig

Configuration for scaling training.
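
A sketch of how these configs compose when constructing a trainer; the run name and storage path are placeholders:

    import ray.train
    import ray.train.torch

    run_config = ray.train.RunConfig(
        name="my-train-run",              # hypothetical run name
        storage_path="/tmp/ray_results",  # hypothetical storage path
        checkpoint_config=ray.train.CheckpointConfig(num_to_keep=2),
        failure_config=ray.train.FailureConfig(max_failures=2),
    )

    trainer = ray.train.torch.TorchTrainer(
        train_func,  # as in the sketches above
        scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
        run_config=run_config,
    )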

Ray Train Utilities#

Classes

Checkpoint

A reference to data persisted as a directory in local or remote storage.

TrainContext

Context providing run and worker metadata (e.g., world size and ranks) inside the training loop.

Functions

get_checkpoint

Access the latest reported checkpoint to resume from if one exists.

get_context

Get or create a singleton training context.

get_dataset_shard

Returns the ray.data.DataIterator shard for this worker.

report

Report metrics and optionally save a checkpoint.
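
A sketch of how these utilities interact inside a training function:

    import tempfile

    import ray.train

    def train_func():
        # Resume from the latest reported checkpoint, if any.
        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                ...  # restore model/optimizer state from ckpt_dir

        rank = ray.train.get_context().get_world_rank()

        for epoch in range(10):
            ...  # train one epoch, e.g. over ray.train.get_dataset_shard("train")
            with tempfile.TemporaryDirectory() as tmpdir:
                ...  # save state into tmpdir
                ray.train.report(
                    {"epoch": epoch},
                    checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
                )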

Ray Train Output#

Result

The output of a completed training run: the final reported metrics, the latest checkpoint, and error information if the run failed.
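
A sketch of inspecting a Result, where trainer is any fitted trainer from the sketches above:

    result = trainer.fit()
    print(result.metrics)  # final reported metrics
    if result.checkpoint:
        with result.checkpoint.as_directory() as ckpt_dir:
            ...  # load artifacts from the checkpoint directory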

Ray Train Errors#

TrainingFailedError

Exception raised by <Framework>Trainer.fit() when training fails.
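
A sketch of handling this error; the import path is an assumption worth verifying against your Ray version's API reference:

    # Assumed import path; verify for your Ray version.
    from ray.train import TrainingFailedError

    try:
        result = trainer.fit()
    except TrainingFailedError:
        ...  # training exhausted its retries or hit an unrecoverable error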

Ray Tune Integration Utilities#

tune.integration.ray_train.TuneReportCallback

Propagate metrics and checkpoint paths from Ray Train workers to Ray Tune.
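
A sketch of the Train-in-Tune pattern this callback supports, where train_func and the search space are your own:

    import ray.train
    from ray import tune
    from ray.train.torch import TorchTrainer
    from ray.tune.integration.ray_train import TuneReportCallback

    def launch_training(tune_config):
        trainer = TorchTrainer(
            train_func,
            train_loop_config=tune_config,
            scaling_config=ray.train.ScalingConfig(num_workers=2),
            run_config=ray.train.RunConfig(callbacks=[TuneReportCallback()]),
        )
        trainer.fit()

    tuner = tune.Tuner(
        launch_training,
        param_space={"lr": tune.loguniform(1e-4, 1e-1)},
    )
    tuner.fit()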

Ray Train Developer APIs#

Trainer Base Class#

DataParallelTrainer

Base class for distributed data parallel training on Ray.
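
A sketch of using the base class directly; the import path follows Ray Train's module layout and is worth verifying for your version:

    import ray.train
    from ray.train.data_parallel_trainer import DataParallelTrainer

    def train_func():
        ...  # framework-agnostic per-worker logic

    # Framework trainers such as TorchTrainer build on this base class.
    trainer = DataParallelTrainer(
        train_func,
        scaling_config=ray.train.ScalingConfig(num_workers=2),
    )
    result = trainer.fit()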

Train Backend Base Classes#

Backend

Singleton for distributed communication backend.

BackendConfig

Parent class for training backend configurations.
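
A sketch of the subclassing pattern, with the import path and hook signature assumed from ray.train.backend; verify both for your version:

    from ray.train.backend import Backend, BackendConfig

    class MyBackend(Backend):
        # Assumed hook: invoked when the worker group starts.
        def on_start(self, worker_group, backend_config):
            ...  # e.g., initialize a communication group across workers

    class MyBackendConfig(BackendConfig):
        @property
        def backend_cls(self):
            return MyBackend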

Trainer Callbacks#

UserCallback

Callback interface for custom user-defined callbacks that handle events during training.