Ray Train API

Trainer

class ray.train.Trainer(backend: Union[str, ray.train.backend.BackendConfig], num_workers: int, use_gpu: bool = False, resources_per_worker: Optional[Dict[str, float]] = None, logdir: Optional[str] = None, max_retries: int = 3)[source]

A class for enabling seamless distributed deep learning.

Directory structure:

  • A logdir is created during instantiation. This will hold all the results/checkpoints for the lifetime of the Trainer. By default, it will be of the form ~/ray_results/train_<datestring>.

  • A run_dir is created for each run call. This will hold the checkpoints and results for a single trainer.run() or trainer.run_iterator() call. It will be of the form run_<run_id>.

Parameters
  • backend (Union[str, BackendConfig]) – The backend used for distributed communication. If configurations are needed, a subclass of BackendConfig can be passed in. Supported str values: {“torch”, “tensorflow”, “horovod”}.

  • num_workers (int) – The number of workers (Ray actors) to launch. Each worker will reserve 1 CPU by default. The number of CPUs reserved by each worker can be overridden with the resources_per_worker argument.

  • use_gpu (bool) – If True, training will be done on GPUs (1 per worker). Defaults to False. The number of GPUs reserved by each worker can be overridden with the resources_per_worker argument.

  • resources_per_worker (Optional[Dict]) – If specified, the resources defined in this Dict will be reserved for each worker. The CPU and GPU keys (case-sensitive) can be defined to override the number of CPU/GPUs used by each worker.

  • logdir (Optional[str]) – Path to the file directory where logs should be persisted. If this is not specified, one will be generated.

  • max_retries (int) – Number of retries when Ray actors fail. Defaults to 3. Set to -1 for unlimited retries.

PublicAPI (beta): This API is in beta and may change before becoming stable.
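For illustration, a minimal sketch of constructing a Trainer with the parameters above; the worker count and per-worker resources are placeholder values:

from ray.train import Trainer

# Two workers, each reserving 2 CPUs and 1 GPU (illustrative values).
trainer = Trainer(
    backend="torch",
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},
)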

create_logdir(log_dir: Union[str, pathlib.Path, None]) → pathlib.Path[source]

Create logdir for the Trainer.

create_run_dir()[source]

Create rundir for the particular training run.

start(initialization_hook: Optional[Callable[[], None]] = None)[source]

Starts the training execution service.

Parameters

initialization_hook (Optional[Callable]) – The function to call on each worker when it is instantiated.
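For example, a sketch of passing an initialization hook that runs on every worker before training, assuming trainer is a Trainer instance created as above; the environment variable set here is purely illustrative:

import os

def init_hook():
    # Executed once on each worker actor when it is instantiated.
    os.environ["OMP_NUM_THREADS"] = "1"

trainer.start(initialization_hook=init_hook)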

run(train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], config: Optional[Dict[str, Any]] = None, callbacks: Optional[List[ray.train.callbacks.callback.TrainingCallback]] = None, dataset: Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]], None] = None, checkpoint: Union[Dict, str, pathlib.Path, None] = None, checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy] = None) → List[T][source]

Runs a training function in a distributed manner.

Parameters
  • train_func (Callable) – The training function to execute. This can either take in no arguments or a config dict.

  • config (Optional[Dict]) – Configurations to pass into train_func. If None then an empty Dict will be created.

  • callbacks (Optional[List[TrainingCallback]]) – A list of Callbacks which will be executed during training. If this is not set, currently there are NO default Callbacks.

  • dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]) – Distributed Ray Dataset or DatasetPipeline to pass into the workers, which can be accessed from the training function via train.get_dataset_shard(). Sharding will automatically be handled by the Trainer. Multiple Datasets can be passed in as a Dict that maps each name key to a Dataset value, and each Dataset can be accessed from the training function by passing in a dataset_name argument to train.get_dataset_shard().

  • checkpoint (Optional[Dict|str|Path]) – The checkpoint data that should be loaded onto each worker and accessed by the training function via train.load_checkpoint(). If this is a str or Path then the value is expected to be a path to a file that contains a serialized checkpoint dict. If this is None then no checkpoint will be loaded.

  • checkpoint_strategy (Optional[CheckpointStrategy]) – The configurations for saving checkpoints.

Returns

A list of results from the training function. Each value in the list corresponds to the output of the training function from each worker.
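A minimal sketch of a run() call, assuming a Trainer has already been constructed and started; the training function and config keys below are illustrative:

from ray import train

def train_func(config):
    for epoch in range(config["num_epochs"]):
        # ... one epoch of training ...
        train.report(epoch=epoch)

results = trainer.run(train_func, config={"num_epochs": 3})
# One entry per worker; None here because train_func returns nothing.
print(results)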

run_iterator(train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], config: Optional[Dict[str, Any]] = None, dataset: Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]], None] = None, checkpoint: Union[Dict, str, pathlib.Path, None] = None, checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy] = None) → TrainingIterator[source]

Same as run except returns an iterator over the results.

This is useful if you want to have more customization of what to do with the intermediate results or how to use the Trainer with Ray Tune.

def train_func(config):
    ...
    for _ in range(config["epochs"]):
        train(...)               # placeholder training step
        metrics = validate(...)  # placeholder validation step
        ray.train.report(**metrics)
    return model

iterator = trainer.run_iterator(train_func, config=config)

for result in iterator:
    do_stuff(result)
    latest_ckpt = trainer.latest_checkpoint

assert iterator.is_finished()
model = iterator.get_final_results()[0]
Parameters
  • train_func (Callable) – The training function to execute. This can either take in no arguments or a config dict.

  • config (Optional[Dict]) – Configurations to pass into train_func. If None then an empty Dict will be created.

  • checkpoint (Optional[Dict|Path|str]) – The checkpoint data that should be loaded onto each worker and accessed by the training function via train.load_checkpoint(). If this is a str or Path then the value is expected to be a path to a file that contains a serialized checkpoint dict. If this is None then no checkpoint will be loaded.

  • checkpoint_strategy (Optional[CheckpointStrategy]) – The configurations for saving checkpoints.

Returns

An Iterator over the intermediate results from train.report().

property latest_run_dir

Path to the log directory for the latest call to run().

Returns None if run() has not been called.

property latest_checkpoint_dir

Path to the checkpoint directory.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func within the most recent call to run.

property best_checkpoint_path

Path to the best persisted checkpoint from the latest run.

“Best” is defined by the input CheckpointStrategy. Default behavior is to return the most recent checkpoint.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func within the most recent call to run.

property latest_checkpoint

The latest saved checkpoint.

This checkpoint may not be saved to disk.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func.
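For illustration, these properties are typically inspected after a run() call whose training function (train_func, assumed to call train.save_checkpoint()) has finished:

trainer.run(train_func)

print(trainer.latest_run_dir)        # run_<run_id> directory under the Trainer logdir
print(trainer.best_checkpoint_path)  # selected according to the CheckpointStrategy
ckpt = trainer.latest_checkpoint     # in-memory checkpoint dict, may not be on disk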

shutdown()[source]

Shuts down the training execution service.

to_tune_trainable(train_func: Callable[[Dict[str, Any]], T], dataset: Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]], None] = None) → Type[ray.tune.trainable.Trainable][source]

Creates a Tune Trainable from the input training function.

Parameters
  • train_func (Callable) – The function that should be executed on each training worker.

  • dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]) – Distributed Ray Dataset or DatasetPipeline to pass into the workers, which can be accessed from the training function via train.get_dataset_shard(). Sharding will automatically be handled by the Trainer. Multiple Datasets can be passed in as a Dict that maps each name key to a Dataset value, and each Dataset can be accessed from the training function by passing in a dataset_name argument to train.get_dataset_shard().

Returns

A Trainable that can directly be passed into tune.run().
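A usage sketch, assuming Ray Tune is installed and trainer and train_func are defined as above; the search space is illustrative:

from ray import tune

trainable = trainer.to_tune_trainable(train_func)
analysis = tune.run(trainable, config={"lr": tune.grid_search([0.01, 0.1])})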

to_worker_group(train_cls: Type, *args, **kwargs) → ray.train.trainer.TrainWorkerGroup[source]

Returns Ray actors with the provided class and the backend started.

This is useful if you want to provide your own class for training and have more control over execution, but still want to use Ray Train to setup the appropriate backend configurations (torch, tf, etc.).

import ray
from ray.train import Trainer

# User-defined training class (named MyTrainer to avoid shadowing ray.train.Trainer).
class MyTrainer:
    def __init__(self, config):
        self.config = config

    def train_epoch(self):
        ...
        return 1

config = {"lr": 0.1}
trainer = Trainer(backend="torch", num_workers=2)
workers = trainer.to_worker_group(train_cls=MyTrainer, config=config)
futures = [w.train_epoch.remote() for w in workers]
assert ray.get(futures) == [1, 1]
assert ray.get(workers[0].train_epoch.remote()) == 1
workers.shutdown()
Parameters
  • train_cls (Type) – The class definition to use for the Ray actors/workers.

  • *args, **kwargs – Arguments to pass into the __init__ of the provided train_cls.

TrainingIterator

class ray.train.TrainingIterator(backend_executor_actor: ray.actor.ActorHandle, backend_config: ray.train.backend.BackendConfig, train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], run_dir: pathlib.Path, dataset: Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]], None], checkpoint_manager: ray.train.checkpoint.CheckpointManager, checkpoint: Union[Dict, str, pathlib.Path, None], checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy])[source]

An iterator over Train results. Returned by trainer.run_iterator.

DeveloperAPI: This API may change across minor Ray releases.

get_final_results(force: bool = False) → List[T][source]

Gets the training func return values from each worker.

If force is True, then immediately finish training and return even if all the intermediate results have not been processed yet. Else, intermediate results must be processed before obtaining the final results. Defaults to False.

Backend Configurations

TorchConfig

class ray.train.torch.TorchConfig(backend: Optional[str] = None, init_method: str = 'env', timeout_s: int = 1800)[source]

Configuration for torch process group setup.

See https://pytorch.org/docs/stable/distributed.html for more info.

Parameters
  • backend (str) – The backend to use for training. See torch.distributed.init_process_group for more info and valid values. If set to None, nccl will be used if GPUs are requested, else gloo will be used.

  • init_method (str) – The initialization method to use. Either “env” for environment variable initialization or “tcp” for TCP initialization. Defaults to “env”.

  • timeout_s (int) – Seconds for process group operations to timeout.

PublicAPI (beta): This API is in beta and may change before becoming stable.
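For illustration, a TorchConfig can be passed in place of the "torch" string to customize process group setup; the gloo backend and timeout below are placeholder choices:

from ray.train import Trainer
from ray.train.torch import TorchConfig

trainer = Trainer(
    backend=TorchConfig(backend="gloo", timeout_s=600),
    num_workers=2,
)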

TensorflowConfig

class ray.train.tensorflow.TensorflowConfig[source]

PublicAPI (beta): This API is in beta and may change before becoming stable.

HorovodConfig

class ray.train.horovod.HorovodConfig(nics: Optional[Set[str]] = None, verbose: int = 1, key: Optional[str] = None, ssh_port: Optional[int] = None, ssh_identity_file: Optional[str] = None, ssh_str: Optional[str] = None, timeout_s: int = 300, placement_group_timeout_s: int = 100)[source]

Configurations for Horovod setup.

See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py for more info.

Parameters
  • nics (Optional[Set[str]]) – Network interfaces that can be used for communication.

  • verbose (int) – Horovod logging verbosity.

  • key (Optional[str]) – Secret used for communication between workers.

  • ssh_port (Optional[int]) – Port for SSH server running on worker nodes.

  • ssh_identity_file (Optional[str]) – Path to the identity file to ssh into different hosts on the cluster.

  • ssh_str (Optional[str]) – CAUTION WHEN USING THIS. Private key file contents. Writes the private key to ssh_identity_file.

  • timeout_s (int) – Timeout parameter for Gloo rendezvous.

  • placement_group_timeout_s (int) – Timeout parameter for Ray Placement Group creation. Currently unused.

PublicAPI (beta): This API is in beta and may change before becoming stable.
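Similarly, a sketch of passing a HorovodConfig to the Trainer; the timeout value is illustrative:

from ray.train import Trainer
from ray.train.horovod import HorovodConfig

trainer = Trainer(
    backend=HorovodConfig(timeout_s=600),
    num_workers=2,
)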

Callbacks

TrainingCallback

class ray.train.TrainingCallback[source]

Abstract Train callback class.

handle_result(results: List[Dict], **info)[source]

Called every time train.report() is called.

Parameters
  • results (List[Dict]) – List of results from the training function. Each value in the list corresponds to the output of the training function from each worker.

  • **info – kwargs dict for forward compatibility.

start_training(logdir: str, **info)[source]

Called once on training start.

Parameters
  • logdir (str) – Path to the file directory where logs should be persisted.

  • **info – kwargs dict for forward compatibility.

finish_training(error: bool = False, **info)[source]

Called once after training is over.

Parameters
  • error (bool) – If True, there was an exception during training.

  • **info – kwargs dict for forward compatibility.
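For illustration, a minimal custom callback that prints every batch of reported results (a sketch, not a production implementation), assuming trainer and train_func from the earlier examples:

from ray.train import TrainingCallback

class PrintingCallback(TrainingCallback):
    def handle_result(self, results, **info):
        # `results` contains one dict per worker for each train.report() call.
        print(results)

trainer.run(train_func, callbacks=[PrintingCallback()])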

JsonLoggerCallback

class ray.train.callbacks.JsonLoggerCallback(logdir: Optional[str] = None, filename: Optional[str] = None, workers_to_log: Union[int, List[int], None] = 0)[source]

Logs Train results in json format.

Parameters
  • logdir (Optional[str]) – Path to directory where the results file should be. If None, will be set by the Trainer.

  • filename (Optional[str]) – Filename in logdir to save results to.

  • workers_to_log (int|List[int]|None) – Worker indices to log. If None, will log all workers. By default, will log the worker with index 0.
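A usage sketch, assuming trainer and train_func are defined as in the examples above:

from ray.train.callbacks import JsonLoggerCallback

# Log results from all workers to a JSON file in the Trainer's run directory.
trainer.run(train_func, callbacks=[JsonLoggerCallback(workers_to_log=None)])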

TBXLoggerCallback

class ray.train.callbacks.TBXLoggerCallback(logdir: Optional[str] = None, worker_to_log: int = 0)[source]

Logs Train results in TensorboardX format.

Parameters
  • logdir (Optional[str]) – Path to directory where the results file should be. If None, will be set by the Trainer.

  • worker_to_log (int) – Worker index to log. By default, will log the worker with index 0.
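A usage sketch, again assuming trainer and train_func from the examples above:

from ray.train.callbacks import TBXLoggerCallback

# Log results from worker 0 in TensorboardX format.
trainer.run(train_func, callbacks=[TBXLoggerCallback()])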

Checkpointing

CheckpointStrategy

class ray.train.CheckpointStrategy(num_to_keep: Optional[int] = None, checkpoint_score_attribute: str = '_timestamp', checkpoint_score_order: str = 'max')[source]

Configurable parameters for defining the Train checkpointing strategy.

Default behavior is to persist all checkpoints to disk. If num_to_keep is set, the default retention policy is to keep the checkpoints with maximum timestamp, i.e. the most recent checkpoints.

Parameters
  • num_to_keep (Optional[int]) – The number of checkpoints to keep on disk for this run. If a checkpoint is persisted to disk after there are already this many checkpoints, then an existing checkpoint will be deleted. If this is None then checkpoints will not be deleted. If this is 0 then no checkpoints will be persisted to disk.

  • checkpoint_score_attribute (str) – The attribute that will be used to score checkpoints to determine which checkpoints should be kept on disk when there are greater than num_to_keep checkpoints. This attribute must be a key from the checkpoint dictionary which has a numerical value.

  • checkpoint_score_order (str) – If “max”, then checkpoints with highest values of checkpoint_score_attribute will be kept. If “min”, then checkpoints with lowest values of checkpoint_score_attribute will be kept.

PublicAPI (beta): This API is in beta and may change before becoming stable.
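A sketch of keeping only the two best checkpoints by a reported "loss" value; the metric name is illustrative and must be a numerical key in the saved checkpoint dict:

from ray.train import CheckpointStrategy

strategy = CheckpointStrategy(
    num_to_keep=2,
    checkpoint_score_attribute="loss",
    checkpoint_score_order="min",
)
trainer.run(train_func, checkpoint_strategy=strategy)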

Training Function Utilities

train.report

ray.train.report(**kwargs) → None[source]

Reports all keyword arguments to Train as intermediate results.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for _ in range(100):
        time.sleep(1)
        train.report(hello="world")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Parameters

**kwargs – Any key-value pair to be reported by Train. If callbacks are provided, they are executed on these intermediate results.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.load_checkpoint

ray.train.load_checkpoint() → Optional[Dict][source]

Loads checkpoint data onto the worker.

from ray import train
from ray.train import Trainer

def train_func():
    checkpoint = train.load_checkpoint()
    for epoch in range(checkpoint["epoch"], 5):
        print(epoch)

trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(train_func, checkpoint={"epoch": 3})
# 3
# 4
trainer.shutdown()

Returns

The most recently saved checkpoint if train.save_checkpoint() has been called. Otherwise, the checkpoint that the session was originally initialized with. None if neither exist.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.save_checkpoint

ray.train.save_checkpoint(**kwargs) → None[source]

Checkpoints all keyword arguments to Train as restorable state.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for epoch in range(100):
        time.sleep(1)
        train.save_checkpoint(epoch=epoch)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Parameters

**kwargs – Any key-value pair to be checkpointed by Train.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.world_rank

ray.train.world_rank() → int[source]

Get the world rank of this worker.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for _ in range(100):
        time.sleep(1)
        if train.world_rank() == 0:
            print("Worker 0")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.local_rank

ray.train.local_rank() → int[source]

Get the local rank of this worker (rank of the worker on its node).

import torch
from ray import train
from ray.train import Trainer

def train_func():
    if torch.cuda.is_available():
        torch.cuda.set_device(train.local_rank())
    ...

trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.world_size

ray.train.world_size() → int[source]

Get the current world size (i.e. total number of workers) for this run.

from ray import train
from ray.train import Trainer

def train_func():
    assert train.world_size() == 4

trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

PyTorch Training Function Utilities

train.torch.prepare_model

ray.train.torch.prepare_model(model: torch.nn.Module, move_to_device: bool = True, wrap_ddp: bool = True, ddp_kwargs: Optional[Dict[str, Any]] = None) → torch.nn.Module[source]

Prepares the model for distributed execution.

This allows you to use the same exact code regardless of number of workers or the device type being used (CPU, GPU).

Parameters
  • model (torch.nn.Module) – A torch model to prepare.

  • move_to_device (bool) – Whether to move the model to the correct device. If set to False, the model needs to manually be moved to the correct device.

  • wrap_ddp (bool) – Whether to wrap models in DistributedDataParallel.

  • ddp_kwargs (Dict[str, Any]) – Args to pass into DistributedDataParallel initialization if wrap_ddp is set to True.

PublicAPI (beta): This API is in beta and may change before becoming stable.
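For illustration, a minimal training function using prepare_model; the model itself is a placeholder:

import torch.nn as nn
from ray.train.torch import prepare_model

def train_func():
    model = nn.Linear(4, 1)  # placeholder model
    # Moves the model to the correct device and wraps it in DistributedDataParallel.
    model = prepare_model(model)
    # ... training loop ...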

train.torch.prepare_data_loader

ray.train.torch.prepare_data_loader(data_loader: torch.utils.data.DataLoader, add_dist_sampler: bool = True, move_to_device: bool = True) → torch.utils.data.DataLoader[source]

Prepares DataLoader for distributed execution.

This allows you to use the same exact code regardless of number of workers or the device type being used (CPU, GPU).

Parameters
  • data_loader (torch.utils.data.DataLoader) – The DataLoader to prepare.

  • add_dist_sampler (bool) – Whether to add a DistributedSampler to the provided DataLoader.

  • move_to_device (bool) – If set, automatically move the data returned by the data loader to the correct device.

PublicAPI (beta): This API is in beta and may change before becoming stable.
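Likewise, a sketch using prepare_data_loader with a small placeholder dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from ray.train.torch import prepare_data_loader

def train_func():
    dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
    data_loader = DataLoader(dataset, batch_size=8)
    # Adds a DistributedSampler and moves each batch to the correct device.
    data_loader = prepare_data_loader(data_loader)
    for X, y in data_loader:
        ...  # forward/backward pass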

train.torch.get_device

ray.train.torch.get_device() → torch.device[source]

Gets the correct torch device to use for training.
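A usage sketch, assuming the call happens inside the training function:

import torch
from ray.train.torch import get_device

def train_func():
    device = get_device()
    data = torch.ones(2, 2, device=device)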