Ray Train API

This page covers framework-specific integrations with Ray Train and the Ray Train developer APIs.

For the core Ray AIR APIs, see the AIR package reference.

Ray Train Base Classes (Developer APIs)

Trainer Base Classes

BaseTrainer(*args, **kwargs)

Defines the interface for distributed training on Ray.

DataParallelTrainer(*args, **kwargs)

A Trainer for data parallel training.

GBDTTrainer(*args, **kwargs)

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

BaseTrainer Methods

fit()

Runs training.

setup()

Called during fit() to perform initial setup on the Trainer.

preprocess_datasets()

Called during fit() to preprocess dataset attributes with the preprocessor.

training_loop()

Loop called by fit() to run training and report results to Tune.

as_trainable()

Convert self to a tune.Trainable class.

Train Backend Base Classes

Backend(*args, **kwargs)

Singleton for distributed communication backend.

BackendConfig()

Parent class for training backend configurations.

Ray Train Integrations

PyTorch

TorchTrainer(*args, **kwargs)

A Trainer for data parallel PyTorch training.

TorchConfig([backend, init_method, timeout_s])

Configuration for torch process group setup.

TorchCheckpoint([local_path, data_dict, uri])

A Checkpoint with Torch-specific functionality.

PyTorch Training Loop Utilities

prepare_model(model[, move_to_device, ...])

Prepares the model for distributed execution.

prepare_optimizer(optimizer)

Wraps optimizer to support automatic mixed precision.

prepare_data_loader(data_loader[, ...])

Prepares DataLoader for distributed execution.

get_device()

Gets the correct torch device to use for training.

accelerate([amp])

Enables training optimizations.

backward(tensor)

Computes the gradient of the specified tensor w.r.t. the graph leaves.

enable_reproducibility([seed])

Limits sources of nondeterministic behavior.

TensorFlow/Keras

TensorflowTrainer(*args, **kwargs)

A Trainer for data parallel TensorFlow training.

TensorflowConfig()

Configuration for TensorFlow setup.

TensorflowCheckpoint(*args, **kwargs)

A Checkpoint with TensorFlow-specific functionality.

TensorFlow/Keras Training Loop Utilities

prepare_dataset_shard(tf_dataset_shard)

A utility function that overrides the default config for a TensorFlow Dataset.

ReportCheckpointCallback([checkpoint_on, ...])

Keras callback for Ray AIR reporting and checkpointing.

Horovod

HorovodTrainer(*args, **kwargs)

A Trainer for data parallel Horovod training.

HorovodConfig([nics, verbose, key, ...])

Configurations for Horovod setup.

XGBoost

XGBoostTrainer(*args, **kwargs)

A Trainer for data parallel XGBoost training.

XGBoostCheckpoint([local_path, data_dict, uri])

A Checkpoint with XGBoost-specific functionality.

LightGBM

LightGBMTrainer(*args, **kwargs)

A Trainer for data parallel LightGBM training.

LightGBMCheckpoint([local_path, data_dict, uri])

A Checkpoint with LightGBM-specific functionality.

HuggingFace

HuggingFaceTrainer(*args, **kwargs)

A Trainer for data parallel training with HuggingFace Transformers on PyTorch.

HuggingFaceCheckpoint([local_path, ...])

A Checkpoint with HuggingFace-specific functionality.

Scikit-Learn

SklearnTrainer(*args, **kwargs)

A Trainer for scikit-learn estimator training.

SklearnCheckpoint([local_path, data_dict, uri])

A Checkpoint with sklearn-specific functionality.

Mosaic

MosaicTrainer(*args, **kwargs)

A Trainer for data parallel training with Mosaic Composer on PyTorch.

Reinforcement Learning (RLlib)

RLTrainer(*args, **kwargs)

Reinforcement learning trainer.

RLCheckpoint([local_path, data_dict, uri])

A Checkpoint with RLlib-specific functionality.