Ray Train API#

This page covers framework-specific integrations with Ray Train and the Ray Train developer APIs.

For the core Ray AIR APIs, see the AIR package reference.

Ray Train Base Classes (Developer APIs)#

Trainer Base Classes#

BaseTrainer(*args, **kwargs)

Defines interface for distributed training on Ray.

DataParallelTrainer(*args, **kwargs)

A Trainer for data parallel training.

GBDTTrainer(*args, **kwargs)

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

BaseTrainer API#

fit()

Runs training.

setup()

Called during fit() to perform initial setup on the Trainer.

preprocess_datasets()

Called during fit() to preprocess dataset attributes with preprocessor.

training_loop()

Loop called by fit() to run training and report results to Tune.

as_trainable()

Convert self to a tune.Trainable class.

Train Backend Base Classes#

Backend(*args, **kwargs)

Singleton for distributed communication backend.

BackendConfig()

Parent class for configurations of training backend.

Ray Train Integrations#

PyTorch#

TorchTrainer(*args, **kwargs)

A Trainer for data parallel PyTorch training.

TorchConfig([backend, init_method, timeout_s])

Configuration for torch process group setup.

TorchCheckpoint([local_path, data_dict, uri])

A Checkpoint with Torch-specific functionality.

PyTorch Training Loop Utilities#

prepare_model(model[, move_to_device, ...])

Prepares the model for distributed execution.

prepare_optimizer(optimizer)

Wraps optimizer to support automatic mixed precision.

prepare_data_loader(data_loader[, ...])

Prepares DataLoader for distributed execution.

get_device()

Gets the correct torch device configured for this process.

accelerate([amp])

Enables training optimizations.

backward(tensor)

Computes the gradient of the specified tensor w.r.t. graph leaves.

enable_reproducibility([seed])

Limits sources of nondeterministic behavior.

PyTorch Lightning#

LightningTrainer(*args, **kwargs)

A Trainer for data parallel PyTorch Lightning training.

LightningConfigBuilder()

Configuration Class to pass into LightningTrainer.

LightningCheckpoint(*args, **kwargs)

A Checkpoint with Lightning-specific functionality.

TensorFlow/Keras#

TensorflowTrainer(*args, **kwargs)

A Trainer for data parallel TensorFlow training.

TensorflowConfig()

Configuration for TensorFlow distributed setup. PublicAPI (beta): this API is in beta and may change before becoming stable.

TensorflowCheckpoint(*args, **kwargs)

A Checkpoint with TensorFlow-specific functionality.

TensorFlow/Keras Training Loop Utilities#

prepare_dataset_shard(tf_dataset_shard)

A utility function that overrides the default config for a TensorFlow Dataset shard.

ReportCheckpointCallback([checkpoint_on, ...])

Keras callback for Ray AIR reporting and checkpointing.

Horovod#

HorovodTrainer(*args, **kwargs)

A Trainer for data parallel Horovod training.

HorovodConfig([nics, verbose, key, ...])

Configurations for Horovod setup.

XGBoost#

XGBoostTrainer(*args, **kwargs)

A Trainer for data parallel XGBoost training.

XGBoostCheckpoint([local_path, data_dict, uri])

A Checkpoint with XGBoost-specific functionality.

LightGBM#

LightGBMTrainer(*args, **kwargs)

A Trainer for data parallel LightGBM training.

LightGBMCheckpoint([local_path, data_dict, uri])

A Checkpoint with LightGBM-specific functionality.

HuggingFace#

HuggingFaceTrainer(*args, **kwargs)

A Trainer for data parallel training of HuggingFace Transformers models on PyTorch.

HuggingFaceCheckpoint([local_path, ...])

A Checkpoint with HuggingFace-specific functionality.

Scikit-Learn#

SklearnTrainer(*args, **kwargs)

A Trainer for scikit-learn estimator training.

SklearnCheckpoint([local_path, data_dict, uri])

A Checkpoint with sklearn-specific functionality.

Mosaic#

MosaicTrainer(*args, **kwargs)

A Trainer for data parallel training of Mosaic Composer models on PyTorch.

Reinforcement Learning (RLlib)#

RLTrainer(*args, **kwargs)

Reinforcement learning trainer.

RLCheckpoint([local_path, data_dict, uri])

A Checkpoint with RLlib-specific functionality.

Ray Train Experiment Restoration#

train.trainer.BaseTrainer.restore(path[, ...])

Restores a Train experiment from a previously interrupted/failed run.

Note

All trainer classes have a restore method that takes in a path pointing to the directory of the experiment to be restored. restore also exposes a subset of constructor arguments that can be re-specified. See Restoration API for Built-in Trainers below for details on the restore arguments for the different AIR trainer integrations.

Restoration API for Built-in Trainers#

train.data_parallel_trainer.DataParallelTrainer.restore(path)

Restores a DataParallelTrainer from a previously interrupted/failed run.

train.huggingface.HuggingFaceTrainer.restore(path)

Restores a HuggingFaceTrainer from a previously interrupted/failed run.

Note

TorchTrainer.restore, TensorflowTrainer.restore, and HorovodTrainer.restore accept the same parameters as DataParallelTrainer.restore, the method of their parent class.

Unless otherwise specified, other trainers will accept the same parameters as BaseTrainer.restore.

See also

See How do I restore a Ray Train experiment? for more details on when and how trainer restore should be used.