Ray Train API#

PyTorch Ecosystem#

TorchTrainer(*args, **kwargs)

A Trainer for data parallel PyTorch training.

TorchConfig([backend, init_method, timeout_s])

Configuration for torch process group setup.



get_device()

Gets the correct torch device configured for this process.

prepare_model(model[, move_to_device, ...])

Prepares the model for distributed execution.

prepare_data_loader(data_loader[, ...])

Prepares DataLoader for distributed execution.


enable_reproducibility([seed])

Limits sources of nondeterministic behavior.

PyTorch Lightning#


prepare_trainer(trainer)

Prepare the PyTorch Lightning Trainer for distributed execution.

RayLightningEnvironment(*args, **kwargs)

Set up the Lightning DDP training environment for a Ray cluster.

RayDDPStrategy(*args, **kwargs)

Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

RayFSDPStrategy(*args, **kwargs)

Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

RayDeepSpeedStrategy(*args, **kwargs)

Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

RayTrainReportCallback(*args, **kwargs)

A simple callback that reports checkpoints to Ray on train epoch end.

Hugging Face Transformers#


prepare_trainer(trainer)

Prepare your HuggingFace Transformer Trainer for Ray Train.

RayTrainReportCallback(*args, **kwargs)

A simple callback to report checkpoints and metrics to Ray Train.

More Frameworks#


TensorflowTrainer(*args, **kwargs)

A Trainer for data parallel Tensorflow training.


TensorflowConfig()

PublicAPI (beta): This API is in beta and may change before becoming stable.


prepare_dataset_shard(tf_dataset_shard)

A utility function that overrides default config for Tensorflow Dataset.

ReportCheckpointCallback(*args, **kwargs)

Keras callback for Ray Train reporting and checkpointing.


HorovodTrainer(*args, **kwargs)

A Trainer for data parallel Horovod training.

HorovodConfig([nics, verbose, key, ...])

Configurations for Horovod setup.


XGBoostTrainer(*args, **kwargs)

A Trainer for data parallel XGBoost training.


LightGBMTrainer(*args, **kwargs)

A Trainer for data parallel LightGBM training.

Ray Train Configuration#

CheckpointConfig([num_to_keep, ...])

Configurable parameters for defining the checkpointing strategy.

DataConfig([datasets_to_split, ...])

Class responsible for configuring Train dataset preprocessing.

FailureConfig([max_failures, fail_fast])

Configuration related to failure handling of each training/tuning run.

RunConfig([name, storage_path, ...])

Runtime configuration for training and tuning runs.

ScalingConfig([trainer_resources, ...])

Configuration for scaling training.

SyncConfig([sync_period, sync_timeout, ...])

Configuration object for Train/Tune file syncing to RunConfig(storage_path).

Ray Train Utilities#


Checkpoint(path[, filesystem])

A reference to data persisted as a directory in local or remote storage.


TrainContext()

Context for Ray training executions.



get_checkpoint()

Access the session's last checkpoint to resume from if applicable.


get_context()

Get or create a singleton training context.


get_dataset_shard([dataset_name])

Returns the ray.data.DataIterator shard for this worker.

report(metrics, *[, checkpoint])

Report metrics and optionally save a checkpoint.

Ray Train Output#

Result(metrics, checkpoint, error[, ...])

The final result of an ML training run or a Tune trial.

Ray Train Developer APIs#

Trainer Base Classes#

BaseTrainer(*args, **kwargs)

Defines interface for distributed training on Ray.

DataParallelTrainer(*args, **kwargs)

A Trainer for data parallel training.

GBDTTrainer(*args, **kwargs)

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

Train Backend Base Classes#

Backend(*args, **kwargs)

Singleton for distributed communication backend.


BackendConfig()

Parent class for configurations of training backend.