Ray Train API#

PyTorch Ecosystem#

TorchTrainer

A Trainer for data parallel PyTorch training.

TorchConfig

Configuration for torch process group setup.

TorchXLAConfig

Configuration for torch XLA setup.
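
A minimal sketch of a TorchTrainer launch, assuming Ray 2.x import paths; the linear model and loop body are placeholders:

    import torch
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    def train_loop_per_worker(config):
        # Runs on each distributed worker; see the PyTorch utilities below.
        model = torch.nn.Linear(4, 1)
        ...

    trainer = TorchTrainer(
        train_loop_per_worker,
        train_loop_config={"lr": 1e-3},
        scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
    )
    result = trainer.fit()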

PyTorch#

get_device

Gets the correct torch device configured for this process.

get_devices

Gets the correct torch device list configured for this process.

prepare_model

Prepares the model for distributed execution.

prepare_data_loader

Prepares the DataLoader for distributed execution.

enable_reproducibility

Limits sources of nondeterministic behavior.
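
A sketch of how these utilities compose inside a train_loop_per_worker; the toy tensors and model are illustrative only:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from ray.train.torch import (
        enable_reproducibility,
        get_device,
        prepare_data_loader,
        prepare_model,
    )

    def train_loop_per_worker(config):
        enable_reproducibility(seed=config.get("seed", 0))
        device = get_device()  # the device Ray assigned to this worker
        # Moves the model to that device and wraps it in DistributedDataParallel.
        model = prepare_model(torch.nn.Linear(4, 1))
        # Adds a DistributedSampler and automatic device placement.
        dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
        loader = prepare_data_loader(DataLoader(dataset, batch_size=8))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.MSELoss()
        for X, y in loader:  # batches already live on the assigned device
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()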

PyTorch Lightning#

prepare_trainer

Prepare the PyTorch Lightning Trainer for distributed execution.

RayLightningEnvironment

Sets up the Lightning DDP training environment for a Ray cluster.

RayDDPStrategy

Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

RayFSDPStrategy

Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

RayDeepSpeedStrategy

Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

RayTrainReportCallback

A simple callback that reports checkpoints to Ray on train epoch end.
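
A sketch putting the Lightning integrations together, roughly as the Ray docs arrange them; MyLightningModule and train_loader are hypothetical stand-ins:

    import pytorch_lightning as pl
    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
        prepare_trainer,
    )

    def train_loop_per_worker(config):
        model = MyLightningModule()  # hypothetical LightningModule
        trainer = pl.Trainer(
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
            enable_checkpointing=False,  # RayTrainReportCallback handles checkpoints
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(model, train_dataloaders=train_loader)  # hypothetical loader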

Hugging Face Transformers#

prepare_trainer

Prepare your Hugging Face Transformers Trainer for Ray Train.

RayTrainReportCallback

A simple callback to report checkpoints and metrics to Ray Train.
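
A sketch of the Transformers integration; model and train_ds are hypothetical placeholders for a real model and dataset:

    from transformers import Trainer, TrainingArguments
    from ray.train.huggingface.transformers import (
        RayTrainReportCallback,
        prepare_trainer,
    )

    def train_loop_per_worker(config):
        args = TrainingArguments(output_dir="checkpoints", max_steps=100)
        trainer = Trainer(model=model, args=args, train_dataset=train_ds)
        trainer.add_callback(RayTrainReportCallback())  # metrics + checkpoints to Ray
        trainer = prepare_trainer(trainer)
        trainer.train()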

More Frameworks#

TensorFlow/Keras#

TensorflowTrainer

A Trainer for data parallel TensorFlow training.

TensorflowConfig

Configuration for TensorFlow setup. This API is in beta and may change before becoming stable.

prepare_dataset_shard

A utility function that disables TensorFlow autosharding, since the dataset is already sharded across workers by Ray Train.

ReportCheckpointCallback

Keras callback for Ray Train reporting and checkpointing.
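
A sketch wiring the Keras pieces together; the dataset variable is a placeholder:

    import tensorflow as tf
    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer
    from ray.train.tensorflow.keras import ReportCheckpointCallback

    def train_loop_per_worker(config):
        # Ray sets up the multi-worker environment; build the model under the strategy.
        strategy = tf.distribute.MultiWorkerMirroredStrategy()
        with strategy.scope():
            model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
            model.compile(optimizer="sgd", loss="mse")
        model.fit(dataset, callbacks=[ReportCheckpointCallback()])  # placeholder dataset

    trainer = TensorflowTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )
    trainer.fit()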

Horovod#

HorovodTrainer

A Trainer for data parallel Horovod training.

HorovodConfig

Configuration for Horovod setup.
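
A sketch of a Horovod loop; Ray starts the workers, so the loop only calls hvd.init() and wraps the optimizer:

    import horovod.torch as hvd
    import torch
    from ray.train import ScalingConfig
    from ray.train.horovod import HorovodTrainer

    def train_loop_per_worker(config):
        hvd.init()  # the Ray backend has already set up the Horovod environment
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        # Average gradients across workers on each step.
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters()
        )
        ...

    trainer = HorovodTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
    )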

XGBoost#

XGBoostTrainer

A Trainer for data parallel XGBoost training.

RayTrainReportCallback

XGBoost callback to save checkpoints and report metrics.
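
A minimal sketch using the dataset-based XGBoostTrainer interface; the constructor has changed across Ray releases, so treat the exact arguments as version-dependent:

    from ray.data import from_items
    from ray.train import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_ds = from_items([{"x": float(i), "y": float(i % 2)} for i in range(32)])
    trainer = XGBoostTrainer(
        label_column="y",
        params={"objective": "binary:logistic"},
        datasets={"train": train_ds},
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()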

LightGBM#

LightGBMTrainer

A Trainer for data parallel LightGBM training.

RayTrainReportCallback

Creates a callback that reports metrics and checkpoints the model.
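
LightGBMTrainer mirrors the XGBoost interface above, with the same version caveat; a minimal regression sketch:

    from ray.data import from_items
    from ray.train import ScalingConfig
    from ray.train.lightgbm import LightGBMTrainer

    train_ds = from_items([{"x": float(i), "y": 2.0 * i} for i in range(32)])
    trainer = LightGBMTrainer(
        label_column="y",
        params={"objective": "regression"},
        datasets={"train": train_ds},
        scaling_config=ScalingConfig(num_workers=2),
    )
    result = trainer.fit()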

Ray Train Configuration#

CheckpointConfig

Configurable parameters for defining the checkpointing strategy.

DataConfig

Class responsible for configuring Train dataset preprocessing.

FailureConfig

Configuration related to failure handling of each training/tuning run.

RunConfig

Runtime configuration for training and tuning runs.

ScalingConfig

Configuration for scaling training.

SyncConfig

Configuration object for Train/Tune file syncing to RunConfig(storage_path).
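
A sketch composing these configs; the storage path and thresholds are arbitrary examples:

    from ray.train import (
        CheckpointConfig,
        FailureConfig,
        RunConfig,
        ScalingConfig,
    )

    run_config = RunConfig(
        name="my-train-run",
        storage_path="/tmp/ray_results",  # often an S3/NFS URI in practice
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,                      # keep only the two best checkpoints
            checkpoint_score_attribute="loss",  # rank them by reported loss
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=3),  # retry up to 3 failures
    )
    scaling_config = ScalingConfig(num_workers=4, use_gpu=True)
    # Both are passed to a Trainer, e.g.
    # TorchTrainer(..., scaling_config=scaling_config, run_config=run_config).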

Ray Train Utilities#

Classes

Checkpoint

A reference to data persisted as a directory in local or remote storage.

TrainContext

Context for Ray training executions.

Functions

get_checkpoint

Access the session's last checkpoint to resume from if applicable.

get_context

Get or create a singleton training context.

get_dataset_shard

Returns the ray.data.DataIterator shard for this worker.

report

Report metrics and optionally save a checkpoint.
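
These utilities are all called from inside a train loop. A sketch of the report/resume pattern, using a JSON file as a stand-in for real model state:

    import json
    import os
    import tempfile

    from ray import train
    from ray.train import Checkpoint

    def train_loop_per_worker(config):
        ctx = train.get_context()
        print(f"worker {ctx.get_world_rank()} of {ctx.get_world_size()}")
        # train.get_dataset_shard("train") would give this worker's DataIterator
        # if Ray Datasets were passed to the Trainer.

        # Resume from the last checkpoint after a restart, if one exists.
        start_epoch = 0
        checkpoint = train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as ckpt_dir:
                with open(os.path.join(ckpt_dir, "state.json")) as f:
                    start_epoch = json.load(f)["epoch"] + 1

        for epoch in range(start_epoch, config["epochs"]):
            loss = 1.0 / (epoch + 1)  # placeholder metric
            with tempfile.TemporaryDirectory() as tmp:
                with open(os.path.join(tmp, "state.json"), "w") as f:
                    json.dump({"epoch": epoch}, f)
                train.report(
                    {"loss": loss, "epoch": epoch},
                    checkpoint=Checkpoint.from_directory(tmp),
                )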

Ray Train Output#

Result

The final result of an ML training run or a Tune trial.
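
A sketch of inspecting a Result returned by trainer.fit():

    result = trainer.fit()  # any Trainer from the sections above
    print(result.metrics)   # the last metrics dict passed to train.report()
    if result.checkpoint:   # the latest reported Checkpoint, if any
        with result.checkpoint.as_directory() as ckpt_dir:
            ...  # load model state from ckpt_dir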

Ray Train Errors#

SessionMisuseError

Indicates a method or function was used outside of a session.

TrainingFailedError

An error indicating that training has failed.
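
A sketch of handling a failed run; the TrainingFailedError import path has moved between Ray versions, so treat it as version-dependent:

    from ray.train.base_trainer import TrainingFailedError

    try:
        result = trainer.fit()  # any Trainer from the sections above
    except TrainingFailedError as e:
        # Raised once FailureConfig(max_failures=...) retries are exhausted;
        # the underlying worker exception is chained as e.__cause__.
        print(f"Training failed: {e}")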

Ray Train Developer APIs#

Trainer Base Classes#

BaseTrainer

Defines the interface for distributed training on Ray.

DataParallelTrainer

A Trainer for data parallel training.

Train Backend Base Classes#

Backend

Singleton for distributed communication backend.

BackendConfig

Parent class for training backend configurations.
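
A conceptual sketch of this extension point: a BackendConfig whose backend_cls points at a custom Backend. The hook names follow the Ray source, but treat the exact signatures as version-dependent:

    from dataclasses import dataclass

    from ray.train.backend import Backend, BackendConfig

    class MyBackend(Backend):
        def on_start(self, worker_group, backend_config):
            pass  # e.g. set env vars or init a process group on each worker

        def on_shutdown(self, worker_group, backend_config):
            pass  # tear down whatever on_start created

    @dataclass
    class MyBackendConfig(BackendConfig):
        @property
        def backend_cls(self):
            return MyBackend

A DataParallelTrainer constructed with backend_config=MyBackendConfig() would then run its train loop on workers connected by this backend.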