Ray Train V1 API#

Important

Ray Train V2 is an overhaul of Ray Train’s implementation and select APIs. Starting in Ray 2.43, it can be enabled by setting the environment variable RAY_TRAIN_V2_ENABLED=1.

This page contains the deprecated V1 API references. See Ray Train API for the new V2 API references.

PyTorch Ecosystem#

TorchTrainer

A Trainer for data parallel PyTorch training.

TorchConfig

Configuration for torch process group setup.

TorchXLAConfig

Configuration for torch XLA setup.
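
For orientation, a minimal launch looks roughly like the sketch below. The per-worker function body is elided, and the `gloo` backend is chosen only to keep the example CPU-friendly (`nccl` is typical for GPU training).

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchConfig, TorchTrainer


def train_func():
    # Per-worker training logic goes here (see the PyTorch utilities below).
    ...


trainer = TorchTrainer(
    train_func,
    torch_config=TorchConfig(backend="gloo"),  # "nccl" is the usual choice on GPU
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()
```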

PyTorch#

get_device

Gets the correct torch device configured for this process.

get_devices

Gets the correct torch device list configured for this process.

prepare_model

Prepares the model for distributed execution.

prepare_data_loader

Prepares DataLoader for distributed execution.

enable_reproducibility

Limits sources of nondeterministic behavior.
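
These utilities are all meant to be called from inside the train loop passed to `TorchTrainer`. A rough sketch of how they compose, with a toy model and dataset standing in for real ones:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import ray.train.torch


def train_func():
    # Pin seeds and disable nondeterministic kernels.
    ray.train.torch.enable_reproducibility(seed=42)

    # Moves the model to this worker's device and wraps it in DistributedDataParallel.
    model = ray.train.torch.prepare_model(torch.nn.Linear(4, 1))

    dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    # Adds a DistributedSampler and moves each batch to this worker's device.
    loader = ray.train.torch.prepare_data_loader(DataLoader(dataset, batch_size=8))

    device = ray.train.torch.get_device()  # also available directly if needed
    for x, y in loader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
```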

PyTorch Lightning#

prepare_trainer

Prepares the PyTorch Lightning Trainer for distributed execution.

RayLightningEnvironment

Sets up the Lightning DDP training environment for a Ray cluster.

RayDDPStrategy

Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

RayFSDPStrategy

Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

RayDeepSpeedStrategy

Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

RayTrainReportCallback

A simple callback that reports checkpoints to Ray on train epoch end.
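
These integrations wire into a standard `pl.Trainer` built inside the training function, which then runs under a `TorchTrainer`. A sketch of the wiring; `MyLightningModule` and `train_loader` are placeholders for your own module and dataloader:

```python
import pytorch_lightning as pl

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func():
    model = MyLightningModule()  # placeholder: your own LightningModule
    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),  # or RayFSDPStrategy / RayDeepSpeedStrategy
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_checkpointing=False,  # RayTrainReportCallback handles checkpoints
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_loader)  # placeholder dataloader
```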

Hugging Face Transformers#

prepare_trainer

Prepares your Hugging Face Transformers Trainer for Ray Train.

RayTrainReportCallback

A simple callback to report checkpoints and metrics to Ray Train.
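
The pattern mirrors the Lightning one: build a standard `transformers.Trainer` inside the training function, attach the report callback, then wrap with `prepare_trainer` and run under a `TorchTrainer`. A sketch, where `model` and `train_ds` are placeholders for your own pretrained model and tokenized dataset:

```python
from transformers import Trainer, TrainingArguments

from ray.train.huggingface.transformers import (
    RayTrainReportCallback,
    prepare_trainer,
)


def train_func():
    args = TrainingArguments(output_dir="out", max_steps=100)
    trainer = Trainer(
        model=model,             # placeholder: your pretrained model
        args=args,
        train_dataset=train_ds,  # placeholder: your tokenized dataset
    )
    # Report metrics and checkpoints back to Ray Train on each save.
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()
```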

More Frameworks#

TensorFlow/Keras#

TensorflowTrainer

A Trainer for data parallel TensorFlow training.

TensorflowConfig

Configuration for TensorFlow process group setup.

prepare_dataset_shard

A utility function that overrides the default config for a TensorFlow Dataset.

ReportCheckpointCallback

Keras callback for Ray Train reporting and checkpointing.
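
A sketch of the Keras flow: build and compile the model under a `MultiWorkerMirroredStrategy` scope inside the training function, and let `ReportCheckpointCallback` report back to Ray Train. The random NumPy data is a stand-in; `prepare_dataset_shard` applies when you feed sharded `tf.data` pipelines instead.

```python
import numpy as np

from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
from ray.train.tensorflow.keras import ReportCheckpointCallback


def train_func():
    import tensorflow as tf

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = tf.keras.Sequential(
            [tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)]
        )
        model.compile(optimizer="adam", loss="mse")

    x, y = np.random.rand(64, 4), np.random.rand(64, 1)  # toy data
    model.fit(x, y, epochs=2, callbacks=[ReportCheckpointCallback()])


trainer = TensorflowTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```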

Horovod#

HorovodTrainer

A Trainer for data parallel Horovod training.

HorovodConfig

Configuration for Horovod setup.
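
The shape matches the other data parallel trainers, with Horovod-specific setup (`hvd.init()`, `DistributedOptimizer`, broadcasts) living inside the training function. A sketch with the Horovod details elided:

```python
from ray.train import ScalingConfig
from ray.train.horovod import HorovodConfig, HorovodTrainer


def train_func():
    import horovod.torch as hvd

    hvd.init()
    # Typical Horovod setup: wrap the optimizer with hvd.DistributedOptimizer
    # and broadcast initial parameters from rank 0.
    ...


trainer = HorovodTrainer(
    train_func,
    horovod_config=HorovodConfig(),
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```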

XGBoost#

XGBoostTrainer

A Trainer for distributed data-parallel XGBoost training.

RayTrainReportCallback

XGBoost callback to save checkpoints and report metrics.
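
Unlike the function-based trainers above, the V1 XGBoostTrainer takes a Ray Dataset and XGBoost params directly; `RayTrainReportCallback` is for when you drive `xgboost.train` in your own training loop. A sketch with a toy dataset:

```python
import ray.data
from ray.train import ScalingConfig
from ray.train.xgboost import XGBoostTrainer

# Toy dataset; in practice, load your own Ray Dataset.
train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(32)])

trainer = XGBoostTrainer(
    label_column="y",
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```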

LightGBM#

LightGBMTrainer

A Trainer for data parallel LightGBM training.

RayTrainReportCallback

Creates a callback that reports metrics and checkpoints the model.
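
LightGBMTrainer mirrors the XGBoost trainer, differing mainly in the params dict. A sketch:

```python
import ray.data
from ray.train import ScalingConfig
from ray.train.lightgbm import LightGBMTrainer

train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(32)])

trainer = LightGBMTrainer(
    label_column="y",
    params={"objective": "binary"},
    datasets={"train": train_ds},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```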

Ray Train Configuration#

ScalingConfig

Configuration for scaling training.

RunConfig

Runtime configuration for training and tuning runs.

FailureConfig

Configuration related to failure handling of each training/tuning run.

CheckpointConfig

Configurable parameters for defining the checkpointing strategy.

DataConfig

Class responsible for configuring Ray Train dataset preprocessing.

SyncConfig

Configuration object for Train/Tune file syncing to RunConfig(storage_path).
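
These objects compose in the trainer constructor. A sketch of a typical combination; the run name and storage path are placeholders:

```python
from ray.train import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig

run_config = RunConfig(
    name="my-train-run",                                # placeholder name
    storage_path="s3://my-bucket/ray-results",          # placeholder path
    failure_config=FailureConfig(max_failures=2),       # retry twice on worker failure
    checkpoint_config=CheckpointConfig(num_to_keep=3),  # keep the 3 latest checkpoints
)

# Passed to any trainer, e.g.:
# TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=4), run_config=run_config)
```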

Ray Train Utilities#

Classes

Checkpoint

A reference to data persisted as a directory in local or remote storage.

TrainContext

Context containing metadata that can be accessed within Ray Train workers.

Functions

get_checkpoint

Access the latest reported checkpoint to resume from if one exists.

get_context

Get or create a singleton training context.

get_dataset_shard

Returns the ray.data.DataIterator shard for this worker.

report

Report metrics and optionally save a checkpoint.
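
All of these utilities are called from inside the training function. A sketch of the usual resume-train-report cycle:

```python
import tempfile

import ray.train
from ray.train import Checkpoint


def train_func():
    # Resume from the latest reported checkpoint, if any.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            ...  # restore model state from ckpt_dir

    world_rank = ray.train.get_context().get_world_rank()  # e.g., gate logging to rank 0

    # Requires datasets={"train": ...} on the trainer.
    shard = ray.train.get_dataset_shard("train")
    for batch in shard.iter_batches(batch_size=32):
        ...  # one training step per batch

    with tempfile.TemporaryDirectory() as tmp:
        ...  # save model state into tmp
        ray.train.report({"loss": 0.1}, checkpoint=Checkpoint.from_directory(tmp))
```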

Ray Train Output#

Result

The final result of an ML training run or Tune trial.
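
`trainer.fit()` returns a Result. A sketch of the fields you typically inspect:

```python
result = trainer.fit()  # any trainer from the sections above

print(result.metrics)     # last reported metrics dict
print(result.checkpoint)  # latest reported Checkpoint, if any
print(result.path)        # storage location of the run
print(result.error)       # the raised exception if training failed, else None
```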

Ray Train Errors#

SessionMisuseError

Indicates a method or function was used outside of a session.

TrainingFailedError

An error indicating that training has failed.
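
TrainingFailedError surfaces from `trainer.fit()` once retries (see FailureConfig) are exhausted. A sketch, assuming the V1 import path under `ray.train.base_trainer`:

```python
from ray.train.base_trainer import TrainingFailedError

try:
    result = trainer.fit()
except TrainingFailedError as exc:
    # Inspect the underlying worker failure and decide whether to restore and retry.
    print(exc)
```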

Ray Train Developer APIs#

Trainer Base Classes#

BaseTrainer

Defines interface for distributed training on Ray.

DataParallelTrainer

A Trainer for data parallel training.

Train Backend Base Classes#

Backend

Singleton for distributed communication backend.

BackendConfig

Parent class for configurations of training backend.
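
A custom communication backend pairs a Backend subclass with a BackendConfig that points at it. A rough sketch of the pattern, assuming the V1 hook names in `ray.train.backend` (hook bodies elided):

```python
from dataclasses import dataclass

from ray.train.backend import Backend, BackendConfig


class MyBackend(Backend):
    # Driver-side hooks around the worker group lifecycle.
    def on_start(self, worker_group, backend_config):
        ...  # e.g., initialize a process group across the workers

    def on_shutdown(self, worker_group, backend_config):
        ...  # tear down any communication state


@dataclass
class MyBackendConfig(BackendConfig):
    @property
    def backend_cls(self):
        return MyBackend
```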