Ray Train API#

PyTorch Ecosystem#

TorchTrainer(*args, **kwargs)

A Trainer for data parallel PyTorch training.

TorchConfig([backend, init_method, timeout_s])

Configuration for torch process group setup.
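
A minimal sketch of how these two fit together, assuming a user-defined `train_loop_per_worker` function:

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchConfig, TorchTrainer

def train_loop_per_worker(config):
    # Ordinary PyTorch training code runs here on each worker.
    ...

trainer = TorchTrainer(
    train_loop_per_worker,
    torch_config=TorchConfig(backend="gloo"),   # "nccl" is typical for GPU training
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```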

PyTorch#

get_device()

Gets the correct torch device configured for this process.

prepare_model(model[, move_to_device, ...])

Prepares the model for distributed execution.

prepare_data_loader(data_loader[, ...])

Prepares DataLoader for distributed execution.

enable_reproducibility([seed])

Limits sources of nondeterministic behavior.
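
Inside a `TorchTrainer` training loop these utilities are typically used together; a hedged sketch (the model and dataset below are placeholders):

```python
import torch
import ray.train.torch

def train_loop_per_worker(config):
    ray.train.torch.enable_reproducibility(seed=42)

    device = ray.train.torch.get_device()

    # Wraps the model in DistributedDataParallel and moves it to the right device.
    model = ray.train.torch.prepare_model(torch.nn.Linear(4, 1))

    # Placeholder dataset; prepare_data_loader adds a DistributedSampler
    # and moves batches to the correct device.
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=16)
    data_loader = ray.train.torch.prepare_data_loader(data_loader)
```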

PyTorch Lightning#

prepare_trainer(trainer)

Prepare the PyTorch Lightning Trainer for distributed execution.

RayLightningEnvironment(*args, **kwargs)

Sets up the Lightning DDP training environment for a Ray cluster.

RayDDPStrategy(*args, **kwargs)

Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

RayFSDPStrategy(*args, **kwargs)

Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

RayDeepSpeedStrategy(*args, **kwargs)

Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

RayTrainReportCallback(*args, **kwargs)

A simple callback that reports checkpoints to Ray on train epoch end.
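
These pieces are usually combined inside the training loop passed to a TorchTrainer; a sketch, assuming a user-defined LightningModule and DataModule:

```python
import pytorch_lightning as pl
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

def train_loop_per_worker(config):
    model = MyLightningModule()      # placeholder: your pl.LightningModule
    datamodule = MyDataModule()      # placeholder: your pl.LightningDataModule

    trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_checkpointing=False,  # checkpoint reporting is handled by the callback
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, datamodule=datamodule)
```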

Hugging Face Transformers#

prepare_trainer(trainer)

Prepare your HuggingFace Transformer Trainer for Ray Train.

RayTrainReportCallback(*args, **kwargs)

A simple callback to report checkpoints and metrics to Ray Train.
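
A sketch of the usual integration, assuming an already-constructed transformers.Trainer:

```python
from ray.train.huggingface.transformers import (
    RayTrainReportCallback,
    prepare_trainer,
)

def train_loop_per_worker(config):
    trainer = build_transformers_trainer()   # placeholder: your transformers.Trainer

    # Report metrics and checkpoints back to Ray Train, then patch the
    # trainer for distributed execution under Ray.
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()
```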

More Frameworks#

Tensorflow/Keras#

TensorflowTrainer(*args, **kwargs)

A Trainer for data parallel Tensorflow training.

TensorflowConfig()

Configuration for TensorFlow distributed training setup (beta API; may change before becoming stable).

prepare_dataset_shard(tf_dataset_shard)

A utility function that overrides the default config for a Tensorflow Dataset shard.

ReportCheckpointCallback(*args, **kwargs)

Keras callback for Ray Train reporting and checkpointing.
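
A sketch of the typical combination, with placeholder model- and dataset-building helpers:

```python
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard
from ray.train.tensorflow.keras import ReportCheckpointCallback

def train_loop_per_worker(config):
    import tensorflow as tf

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_and_compile_model()   # placeholder: your compiled Keras model

    dataset = build_tf_dataset()            # placeholder: your per-worker tf.data.Dataset shard
    model.fit(
        prepare_dataset_shard(dataset),
        callbacks=[ReportCheckpointCallback()],
    )

trainer = TensorflowTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```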

Horovod#

HorovodTrainer(*args, **kwargs)

A Trainer for data parallel Horovod training.

HorovodConfig([nics, verbose, key, ...])

Configurations for Horovod setup.
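
A minimal sketch; the training loop itself contains standard Horovod code:

```python
from ray.train import ScalingConfig
from ray.train.horovod import HorovodConfig, HorovodTrainer

def train_loop_per_worker(config):
    # Standard Horovod training code (hvd.init(), DistributedOptimizer, ...) goes here.
    ...

trainer = HorovodTrainer(
    train_loop_per_worker,
    horovod_config=HorovodConfig(verbose=1),
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```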

XGBoost#

XGBoostTrainer(*args, **kwargs)

A Trainer for data parallel XGBoost training.
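
A sketch with placeholder data, label column, and parameters:

```python
import ray
from ray.train import ScalingConfig
from ray.train.xgboost import XGBoostTrainer

train_ds = ray.data.read_csv("s3://my-bucket/train.csv")   # placeholder dataset path

trainer = XGBoostTrainer(
    label_column="target",                      # placeholder label column
    params={"objective": "binary:logistic"},
    datasets={"train": train_ds},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```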

LightGBM#

LightGBMTrainer(*args, **kwargs)

A Trainer for data parallel LightGBM training.
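
Usage mirrors XGBoostTrainer; a sketch with placeholder data and parameters:

```python
from ray.train import ScalingConfig
from ray.train.lightgbm import LightGBMTrainer

trainer = LightGBMTrainer(
    label_column="target",                      # placeholder label column
    params={"objective": "binary"},
    datasets={"train": train_ds},               # placeholder Ray Dataset
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```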

Ray Train Configuration#

CheckpointConfig([num_to_keep, ...])

Configurable parameters for defining the checkpointing strategy.

DataConfig([datasets_to_split, ...])

Class responsible for configuring Train dataset preprocessing.

FailureConfig([max_failures, fail_fast])

Configuration related to failure handling of each training/tuning run.

RunConfig([name, storage_path, ...])

Runtime configuration for training and tuning runs.

ScalingConfig([trainer_resources, ...])

Configuration for scaling training.

SyncConfig([sync_period, sync_timeout, ...])

Configuration object for Train/Tune file syncing to RunConfig(storage_path).
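
A sketch of how these configuration objects are typically composed (names and paths below are placeholders):

```python
from ray.train import (
    CheckpointConfig,
    FailureConfig,
    RunConfig,
    ScalingConfig,
)

scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

run_config = RunConfig(
    name="my-train-run",                               # placeholder run name
    storage_path="s3://my-bucket/ray-results",         # placeholder storage location
    checkpoint_config=CheckpointConfig(num_to_keep=2),
    failure_config=FailureConfig(max_failures=3),
)

# Both objects are then passed to a Trainer, e.g.
# TorchTrainer(..., scaling_config=scaling_config, run_config=run_config).
```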

Ray Train Utilities#

Classes

Checkpoint(path[, filesystem])

A reference to data persisted as a directory in local or remote storage.

TrainContext()

Context for Ray training executions.

Functions

get_checkpoint()

Access the session's last checkpoint to resume from if applicable.

get_context()

Get or create a singleton training context.

get_dataset_shard([dataset_name])

Returns the ray.data.DataIterator shard for this worker.

report(metrics, *[, checkpoint])

Report metrics and optionally save a checkpoint.
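
Inside a training loop, the utilities are typically used like this (a hedged sketch; the restore logic is a placeholder):

```python
import ray.train

def train_loop_per_worker(config):
    context = ray.train.get_context()
    rank = context.get_world_rank()

    # Per-worker shard of the dataset named "train" (if one was passed to the Trainer).
    shard = ray.train.get_dataset_shard("train")

    # Resume from the latest checkpoint, if any.
    checkpoint = ray.train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            ...  # load model/optimizer state from checkpoint_dir

    # Report metrics (and optionally a new checkpoint) back to Ray Train.
    ray.train.report({"loss": 0.1, "rank": rank})
```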

Ray Train Output#

Result(metrics, checkpoint, error[, ...])

The final result of an ML training run or a Tune trial.
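
A sketch of inspecting the Result returned by trainer.fit():

```python
result = trainer.fit()

print(result.metrics)      # last metrics dict passed to ray.train.report()
print(result.checkpoint)   # latest reported Checkpoint, or None
print(result.error)        # the exception raised by the run, if it failed
print(result.path)         # storage path of the run's outputs
```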

Ray Train Developer APIs#

Trainer Base Classes#

BaseTrainer(*args, **kwargs)

Defines the interface for distributed training on Ray.

DataParallelTrainer(*args, **kwargs)

A Trainer for data parallel training.

GBDTTrainer(*args, **kwargs)

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.
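
DataParallelTrainer can also be used directly for framework-agnostic training loops; a sketch:

```python
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    ...  # framework-agnostic per-worker training code

trainer = DataParallelTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```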

Train Backend Base Classes#

Backend(*args, **kwargs)

Singleton for distributed communication backend.

BackendConfig()

Parent class for configurations of training backend.
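
A custom communication backend is typically defined by pairing a Backend subclass with a BackendConfig subclass; a sketch of the pattern (the class names are placeholders):

```python
from dataclasses import dataclass

from ray.train.backend import Backend, BackendConfig

class MyBackend(Backend):
    def on_start(self, worker_group, backend_config):
        ...  # initialize the communication layer on each worker

    def on_shutdown(self, worker_group, backend_config):
        ...  # tear down the communication layer

@dataclass
class MyBackendConfig(BackendConfig):
    @property
    def backend_cls(self):
        return MyBackend
```

The resulting config object can then be passed as the backend configuration of a DataParallelTrainer.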