Ray Train API

Trainer

class ray.train.Trainer(backend: Union[str, ray.train.backend.BackendConfig], num_workers: int, use_gpu: bool = False, resources_per_worker: Optional[Dict[str, float]] = None, logdir: Optional[str] = None, max_retries: int = 3)[source]

A class for enabling seamless distributed deep learning.

Directory structure:

  • A logdir is created during instantiation. This will hold all the results/checkpoints for the lifetime of the Trainer. By default, it will be of the form ~/ray_results/train_<datestring>.

  • A run_dir is created for each run call. This will hold the checkpoints and results for a single trainer.run() or trainer.run_iterator() call. It will be of the form run_<run_id>.

Parameters
  • backend (Union[str, BackendConfig]) – The backend used for distributed communication. If configurations are needed, a subclass of BackendConfig can be passed in. Supported str values: {“torch”, “tensorflow”, “horovod”}.

  • num_workers (int) – The number of workers (Ray actors) to launch. Each worker will reserve 1 CPU by default. The number of CPUs reserved by each worker can be overridden with the resources_per_worker argument.

  • use_gpu (bool) – If True, training will be done on GPUs (1 per worker). Defaults to False. The number of GPUs reserved by each worker can be overridden with the resources_per_worker argument.

  • resources_per_worker (Optional[Dict]) – If specified, the resources defined in this Dict will be reserved for each worker. The CPU and GPU keys (case-sensitive) can be defined to override the number of CPU/GPUs used by each worker.

  • logdir (Optional[str]) – Path to the file directory where logs should be persisted. If this is not specified, one will be generated.

  • max_retries (int) – Number of retries when Ray actors fail. Defaults to 3. Set to -1 for unlimited retries.

PublicAPI (beta): This API is in beta and may change before becoming stable.
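For example, a Trainer that overrides the per-worker resources might be constructed as follows (a minimal sketch; the backend choice and resource values are illustrative):

from ray.train import Trainer

# 4 workers, each reserving 2 CPUs and half a GPU.
trainer = Trainer(
    backend="torch",
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 0.5},
)
trainer.start()
# ... trainer.run(train_func) ...
trainer.shutdown()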

create_logdir(log_dir: Optional[Union[str, pathlib.Path]]) pathlib.Path[source]

Create logdir for the Trainer.

create_run_dir()[source]

Create rundir for the particular training run.

start(initialization_hook: Optional[Callable[[], None]] = None)[source]

Starts the training execution service.

Parameters

initialization_hook (Optional[Callable]) – The function to call on each worker when it is instantiated.

run(train_func: Union[Callable[[], ray.train.trainer.T], Callable[[Dict[str, Any]], ray.train.trainer.T]], config: Optional[Dict[str, Any]] = None, callbacks: Optional[List[ray.train.callbacks.callback.TrainingCallback]] = None, dataset: Optional[Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]]]] = None, checkpoint: Optional[Union[Dict, str, pathlib.Path]] = None, checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy] = None) List[ray.train.trainer.T][source]

Runs a training function in a distributed manner.

Parameters
  • train_func (Callable) – The training function to execute. This can either take in no arguments or a config dict.

  • config (Optional[Dict]) – Configurations to pass into train_func. If None then an empty Dict will be created.

  • callbacks (Optional[List[TrainingCallback]]) – A list of Callbacks which will be executed during training. If this is not set, no default Callbacks are currently used.

  • dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]) – Distributed Ray Dataset or DatasetPipeline to pass into the workers, which can be accessed from the training function via train.get_dataset_shard(). Sharding will automatically be handled by the Trainer. Multiple Datasets can be passed in as a Dict that maps each name key to a Dataset value, and each Dataset can be accessed from the training function by passing in a dataset_name argument to train.get_dataset_shard().

  • checkpoint (Optional[Dict|str|Path]) – The checkpoint data that should be loaded onto each worker and accessed by the training function via train.load_checkpoint(). If this is a str or Path then the value is expected to be a path to a file that contains a serialized checkpoint dict. If this is None then no checkpoint will be loaded.

  • checkpoint_strategy (Optional[CheckpointStrategy]) – The configurations for saving checkpoints.

Returns

A list of results from the training function. Each value in the list corresponds to the output of the training function from each worker.
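A minimal sketch of a run() call that loads an initial checkpoint, saves new checkpoints, and attaches a logging callback (the training loop body is illustrative):

from ray import train
from ray.train import Trainer
from ray.train.callbacks import JsonLoggerCallback

def train_func(config):
    # Resume from the checkpoint passed to run(), if any.
    checkpoint = train.load_checkpoint() or {"epoch": 0}
    for epoch in range(checkpoint["epoch"], config["num_epochs"]):
        train.save_checkpoint(epoch=epoch)
        train.report(epoch=epoch)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
results = trainer.run(
    train_func,
    config={"num_epochs": 3},
    callbacks=[JsonLoggerCallback()],
    checkpoint={"epoch": 1},
)
trainer.shutdown()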

run_iterator(train_func: Union[Callable[[], ray.train.trainer.T], Callable[[Dict[str, Any]], ray.train.trainer.T]], config: Optional[Dict[str, Any]] = None, dataset: Optional[Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]]]] = None, checkpoint: Optional[Union[Dict, str, pathlib.Path]] = None, checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy] = None) TrainingIterator[source]

Same as run except returns an iterator over the results.

This is useful if you want to have more customization of what to do with the intermediate results or how to use the Trainer with Ray Tune.

def train_func(config):
    ...
    for _ in range(config["epochs"]):
        metrics = train()
        metrics = validate(...)
        ray.train.report(**metrics)
    return model

iterator = trainer.run_iterator(train_func, config=config)

for result in iterator:
    do_stuff(result)
    latest_ckpt = trainer.get_latest_checkpoint()

assert iterator.is_finished()
model = iterator.get_final_results()[0]
Parameters
  • train_func (Callable) – The training function to execute. This can either take in no arguments or a config dict.

  • config (Optional[Dict]) – Configurations to pass into train_func. If None then an empty Dict will be created.

  • checkpoint (Optional[Dict|Path|str]) – The checkpoint data that should be loaded onto each worker and accessed by the training function via train.load_checkpoint(). If this is a str or Path then the value is expected to be a path to a file that contains a serialized checkpoint dict. If this is None then no checkpoint will be loaded.

  • checkpoint_strategy (Optional[CheckpointStrategy]) – The configurations for saving checkpoints.

Returns

An Iterator over the intermediate results from train.report().

property latest_run_dir: Optional[pathlib.Path]

Path to the log directory for the latest call to run().

Returns None if run() has not been called.

property latest_checkpoint_dir: Optional[pathlib.Path]

Path to the checkpoint directory.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func within the most recent call to run.

property best_checkpoint_path: Optional[pathlib.Path]

Path to the best persisted checkpoint from the latest run.

“Best” is defined by the input CheckpointStrategy. Default behavior is to return the most recent checkpoint.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func within the most recent call to run.

property latest_checkpoint: Optional[Dict]

The latest saved checkpoint.

This checkpoint may not be saved to disk.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func.

property best_checkpoint: Optional[Dict]

Best saved checkpoint from the latest run.

“Best” is defined by the input CheckpointStrategy. Default behavior is to return the most recent checkpoint.

Returns None if run() has not been called or if train.save_checkpoint() has not been called from train_func within the most recent call to run.

static load_checkpoint_from_path(checkpoint_file_path: Union[str, pathlib.Path]) Dict[source]

Convenience method to load a checkpoint from path.

An error will be raised if the provided path does not exist.

Parameters

checkpoint_file_path (Union[str, Path]) – The path to the checkpoint to load. If the checkpoint saved in this path has not been created by Ray Train, there is no guarantee that it can be loaded in successfully.

shutdown()[source]

Shuts down the training execution service.

to_tune_trainable(train_func: Callable[[Dict[str, Any]], ray.train.trainer.T], dataset: Optional[Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]]]] = None) Type[ray.tune.trainable.Trainable][source]

Creates a Tune Trainable from the input training function.

Parameters
  • train_func (Callable) – The training function that should be executed on each training worker.

  • dataset (Optional[Union[RayDataset, Dict[str, RayDataset]]]) – Distributed Ray Dataset or DatasetPipeline to pass into the workers, which can be accessed from the training function via train.get_dataset_shard(). Sharding will automatically be handled by the Trainer. Multiple Datasets can be passed in as a Dict that maps each name key to a Dataset value, and each Dataset can be accessed from the training function by passing in a dataset_name argument to train.get_dataset_shard().

Returns

A Trainable that can directly be passed into tune.run().
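A sketch of passing the resulting Trainable to Ray Tune (assumes Ray Tune is installed; the search space is illustrative):

from ray import train, tune
from ray.train import Trainer

def train_func(config):
    for i in range(config["num_epochs"]):
        train.report(epoch=i, lr=config["lr"])

trainer = Trainer(backend="torch", num_workers=2)
trainable = trainer.to_tune_trainable(train_func)
analysis = tune.run(
    trainable,
    config={"lr": tune.grid_search([0.01, 0.1]), "num_epochs": 2},
)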

to_worker_group(train_cls: Type, *args, **kwargs) ray.train.trainer.TrainWorkerGroup[source]

Returns Ray actors with the provided class and the backend started.

This is useful if you want to provide your own class for training and have more control over execution, but still want to use Ray Train to set up the appropriate backend configurations (torch, tf, etc.).

class Trainer:
    def __init__(self, config):
        self.config = config

    def train_epoch(self):
        ...
        return 1

config = {"lr": 0.1}
trainer = Trainer(num_workers=2, backend="torch")
workers = trainer.to_worker_group(train_cls=Trainer, config=config)
futures = [w.train_epoch.remote() for w in workers]
assert ray.get(futures) == [1, 1]
assert ray.get(workers[0].train_epoch.remote()) == 1
workers.shutdown()
Parameters
  • train_cls (Type) – The class definition to use for the Ray actors/workers.

  • args – Arguments to pass into the __init__ of the provided train_cls.

  • kwargs – Arguments to pass into the __init__ of the provided train_cls.

TrainingIterator

class ray.train.TrainingIterator(backend_executor: Union[ray.train.backend.BackendExecutor, ray.train.utils.ActorWrapper], backend_config: ray.train.backend.BackendConfig, train_func: Union[Callable[[], ray.train.trainer.T], Callable[[Dict[str, Any]], ray.train.trainer.T]], dataset: Optional[Union[Dataset, DatasetPipeline, Dict[str, Union[Dataset, DatasetPipeline]]]], checkpoint_manager: ray.train.checkpoint.CheckpointManager, checkpoint: Optional[Union[Dict, str, pathlib.Path]], checkpoint_strategy: Optional[ray.train.checkpoint.CheckpointStrategy], run_dir: Optional[pathlib.Path] = None)[source]

An iterator over Train results. Returned by trainer.run_iterator. DeveloperAPI: This API may change across minor Ray releases.

get_final_results(force: bool = False) List[ray.train.trainer.T][source]

Gets the training func return values from each worker.

If force is True, then immediately finish training and return even if all the intermediate results have not been processed yet. Else, intermediate results must be processed before obtaining the final results. Defaults to False.

Backend Configurations

TorchConfig

class ray.train.torch.TorchConfig(backend: Optional[str] = None, init_method: str = 'env', timeout_s: int = 1800)[source]

Configuration for torch process group setup.

See https://pytorch.org/docs/stable/distributed.html for more info.

Parameters
  • backend (str) – The backend to use for training. See torch.distributed.init_process_group for more info and valid values. If set to None, nccl will be used if GPUs are requested, else gloo will be used.

  • init_method (str) – The initialization method to use. Either “env” for environment variable initialization or “tcp” for TCP initialization. Defaults to “env”.

  • timeout_s (int) – Seconds for process group operations to timeout.

PublicAPI (beta): This API is in beta and may change before becoming stable.
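For example, to force the gloo backend and shorten the process group timeout, a TorchConfig can be passed in place of the "torch" string (a sketch; the values are illustrative):

from ray.train import Trainer
from ray.train.torch import TorchConfig

trainer = Trainer(
    backend=TorchConfig(backend="gloo", timeout_s=600),
    num_workers=2,
)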

TensorflowConfig

class ray.train.tensorflow.TensorflowConfig[source]

PublicAPI (beta): This API is in beta and may change before becoming stable.

HorovodConfig

class ray.train.horovod.HorovodConfig(nics: Optional[Set[str]] = None, verbose: int = 1, key: Optional[str] = None, ssh_port: Optional[int] = None, ssh_identity_file: Optional[str] = None, ssh_str: Optional[str] = None, timeout_s: int = 300, placement_group_timeout_s: int = 100)[source]

Configurations for Horovod setup.

See https://github.com/horovod/horovod/blob/master/horovod/runner/common/util/settings.py

Parameters
  • nics (Optional[Set[str]]) – Network interfaces that can be used for communication.

  • verbose (int) – Horovod logging verbosity.

  • key (Optional[str]) – Secret used for communication between workers.

  • ssh_port (Optional[int]) – Port for SSH server running on worker nodes.

  • ssh_identity_file (Optional[str]) – Path to the identity file to ssh into different hosts on the cluster.

  • ssh_str (Optional[str]) – CAUTION WHEN USING THIS. Private key file contents. Writes the private key to ssh_identity_file.

  • timeout_s (int) – Timeout parameter for Gloo rendezvous.

  • placement_group_timeout_s (int) – Timeout parameter for Ray Placement Group creation. Currently unused.

PublicAPI (beta): This API is in beta and may change before becoming stable.

Backend interfaces (for developers only)

Backend

class ray.train.backend.Backend(*args, **kwargs)[source]

Singleton for distributed communication backend.

share_cuda_visible_devices

If True, each worker process will have CUDA_VISIBLE_DEVICES set as the visible device IDs of all workers on the same node for this training instance. If False, each worker will have CUDA_VISIBLE_DEVICES set to the device IDs allocated by Ray for that worker.

Type

bool

DeveloperAPI: This API may change across minor Ray releases.

BackendConfig

class ray.train.backend.BackendConfig[source]

Parent class for configurations of training backend. DeveloperAPI: This API may change across minor Ray releases.

Callbacks

TrainingCallback

class ray.train.TrainingCallback[source]

Abstract Train callback class.

start_training(logdir: str, config: Dict, **info)[source]

Called once on training start.

Parameters
  • logdir (str) – Path to the file directory where logs should be persisted.

  • config (Dict) – The config dict passed into trainer.run().

  • **info – kwargs dict for forward compatibility.

process_results(results: List[Dict], **info)[source]

Called every time train.report() is called.

  1. Preprocesses results. Subclasses can implement preprocessing by defining a ResultsPreprocessor.

  2. Handles preprocessed results. Subclasses can implement handling by overriding the handle_result method.

Parameters
  • results (List[Dict]) – List of results from the training function. Each value in the list corresponds to the output of the training function from each worker.

  • **info – kwargs dict for forward compatibility.

handle_result(results: List[Dict], **info)[source]

Called every time train.report() is called after preprocessing.

For more information, see process_results.

Parameters
  • results (List[Dict]) – List of results from the training function. Each value in the list corresponds to the output of the training function from each worker.

  • **info – kwargs dict for forward compatibility.

finish_training(error: bool = False, **info)[source]

Called once after training is over.

Parameters
  • error (bool) – If True, there was an exception during training.

  • **info – kwargs dict for forward compatibility.
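
A minimal sketch of a custom callback that overrides handle_result (the "loss" metric name is an assumption about what the training function reports):

from ray.train import TrainingCallback

class AverageLossCallback(TrainingCallback):
    """Tracks the mean of a reported "loss" metric across workers."""

    def __init__(self):
        self.history = []

    def handle_result(self, results, **info):
        # `results` contains one dict per worker for each train.report() call.
        losses = [r["loss"] for r in results if "loss" in r]
        if losses:
            self.history.append(sum(losses) / len(losses))

    def finish_training(self, error=False, **info):
        if self.history and not error:
            print(f"Mean reported loss: {sum(self.history) / len(self.history):.4f}")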

PrintCallback

class ray.train.callbacks.PrintCallback[source]

A callback that prints training results to STDOUT.

Example

>>> from ray import train
>>> from ray.train import Trainer
>>> from ray.train.callbacks import PrintCallback
>>> def train_func():
...    for i in range(2):
...        train.report(worker_idx=train.world_rank(), epoch=i)
>>> trainer = Trainer(num_workers=2, backend="torch")
>>> trainer.start()
>>> trainer.run(train_func, callbacks=[PrintCallback()])
[
    {
        "worker_idx": 0,
        "epoch": 0,
        "_timestamp": 1641232964,
        "_time_this_iter_s": 0.0021491050720214844,
        "_training_iteration": 1
    },
    {
        "worker_idx": 1,
        "epoch": 0,
        "_timestamp": 1641232964,
        "_time_this_iter_s": 0.0013790130615234375,
        "_training_iteration": 1
    }
]
[
    {
        "worker_idx": 0,
        "epoch": 1,
        "_timestamp": 1641232964,
        "_time_this_iter_s": 0.0025370121002197266,
        "_training_iteration": 2
    },
    {
        "worker_idx": 1,
        "epoch": 1,
        "_timestamp": 1641232964,
        "_time_this_iter_s": 0.002299070358276367,
        "_training_iteration": 2
    }
]

JsonLoggerCallback

class ray.train.callbacks.JsonLoggerCallback(logdir: Optional[str] = None, filename: Optional[str] = None, workers_to_log: Optional[Union[int, List[int]]] = 0)[source]

Logs Train results in json format.

Parameters
  • logdir (Optional[str]) – Path to directory where the results file should be. If None, will be set by the Trainer.

  • filename (Optional[str]) – Filename in logdir to save results to.

  • workers_to_log (int|List[int]|None) – Worker indices to log. If None, will log all workers. By default, will log the worker with index 0.
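
A usage sketch (the filename and worker indices are illustrative):

from ray import train
from ray.train import Trainer
from ray.train.callbacks import JsonLoggerCallback

def train_func():
    for i in range(3):
        train.report(step=i)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(
    train_func,
    callbacks=[JsonLoggerCallback(filename="results.json", workers_to_log=[0, 1])],
)
trainer.shutdown()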

TBXLoggerCallback

class ray.train.callbacks.TBXLoggerCallback(logdir: Optional[str] = None, worker_to_log: int = 0)[source]

Logs Train results in TensorboardX format.

Parameters
  • logdir (Optional[str]) – Path to directory where the results file should be. If None, will be set by the Trainer.

  • worker_to_log (int) – Worker index to log. By default, will log the worker with index 0.

MLflowLoggerCallback

class ray.train.callbacks.MLflowLoggerCallback(tracking_uri: Optional[str] = None, registry_uri: Optional[str] = None, experiment_id: Optional[str] = None, experiment_name: Optional[str] = None, tags: Optional[Dict] = None, save_artifact: bool = False, logdir: Optional[str] = None, worker_to_log: int = 0)[source]

MLflow Logger to automatically log Train results and config to MLflow.

MLflow (https://mlflow.org) Tracking is an open source library for recording and querying experiments. This Ray Train callback sends information (config parameters, training results & metrics, and artifacts) to MLflow for automatic experiment tracking.

Parameters
  • tracking_uri (Optional[str]) – The tracking URI for where to manage experiments and runs. This can either be a local file path or a remote server. If None is passed in, the logdir of the trainer will be used as the tracking URI. This arg gets passed directly to mlflow initialization.

  • registry_uri (Optional[str]) – The registry URI that gets passed directly to mlflow initialization. If None is passed in, the logdir of the trainer will be used as the registry URI.

  • experiment_id (Optional[str]) – The experiment id of an already existing experiment. If not passed in, experiment_name will be used.

  • experiment_name (Optional[str]) – The experiment name to use for this Train run. If the experiment with the name already exists with MLflow, it will be used. If not, a new experiment will be created with this name. At least one of experiment_id or experiment_name must be passed in.

  • tags (Optional[Dict]) – An optional dictionary of string keys and values to set as tags on the run.

  • save_artifact (bool) – If set to True, automatically save the entire contents of the Train local_dir as an artifact to the corresponding run in MLflow.

  • logdir (Optional[str]) – Path to directory where the results file should be. If None, will be set by the Trainer. If no tracking uri or registry uri are passed in, the logdir will be used for both.

  • worker_to_log (int) – Worker index to log. By default, will log the worker with index 0.
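
A usage sketch (assumes mlflow is installed; the experiment name and tags are illustrative):

from ray import train
from ray.train import Trainer
from ray.train.callbacks import MLflowLoggerCallback

def train_func(config):
    for i in range(config["num_epochs"]):
        train.report(epoch=i, loss=1.0 / (i + 1))

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(
    train_func,
    config={"num_epochs": 3},
    callbacks=[
        MLflowLoggerCallback(
            experiment_name="train_experiment",  # illustrative name
            tags={"framework": "torch"},
            save_artifact=True,
        )
    ],
)
trainer.shutdown()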

TorchTensorboardProfilerCallback

class ray.train.callbacks.TorchTensorboardProfilerCallback(logdir: Optional[str] = None, workers_to_log: Optional[Union[int, List[int]]] = None)[source]

Synchronizes PyTorch Profiler traces onto disk.

This should typically be used in conjunction with TorchWorkerProfiler, though the actual requirement is for the _train_torch_profiler key to be populated in the results from train.report().

Parameters
  • logdir (Optional[str]) – The directory to store traces. If None, this will use a default temporary dir.

  • workers_to_log (Optional[int|List[int]]) – Worker indices to log. If None (the default), all workers will be logged.

ResultsPreprocessors

ResultsPreprocessor

class ray.train.callbacks.results_preprocessors.ResultsPreprocessor[source]

Abstract class for preprocessing Train results. DeveloperAPI: This API may change across minor Ray releases.

abstract preprocess(results: List[Dict]) List[Dict][source]
Parameters

results (List[Dict]) – List of results from the training function. Each value in the list corresponds to the output of the training function from each worker.

Returns

A list of dictionaries. Each item in the list does not need to correspond to a single worker, and it is expected for the corresponding caller to understand the semantics of the preprocessed results.
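A minimal sketch of a custom preprocessor (rounding float metrics is an illustrative transformation):

from typing import Dict, List

from ray.train.callbacks.results_preprocessors import ResultsPreprocessor

class RoundingResultsPreprocessor(ResultsPreprocessor):
    def preprocess(self, results: List[Dict]) -> List[Dict]:
        # Round every float metric to 4 decimal places.
        return [
            {k: round(v, 4) if isinstance(v, float) else v for k, v in result.items()}
            for result in results
        ]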

SequentialResultsPreprocessor

class ray.train.callbacks.results_preprocessors.SequentialResultsPreprocessor(preprocessors: List[ray.train.callbacks.results_preprocessors.preprocessor.ResultsPreprocessor])[source]

A processor that sequentially runs a series of preprocessing steps.

  • preprocessors: [A, B, C]

  • preprocess: C.preprocess(B.preprocess(A.preprocess(results)))

Parameters

preprocessors (List[ResultsPreprocessor]) – The preprocessors that will be run in sequence.

DeveloperAPI: This API may change across minor Ray releases.

IndexedResultsPreprocessor

class ray.train.callbacks.results_preprocessors.IndexedResultsPreprocessor(indices: Optional[Union[int, List[int]]])[source]

Preprocesses results by filtering by index.

Example:

  • indices: [0, 2]

  • input: [a, b, c, d]

  • output: [a, c]

Parameters

indices (Optional[int|List[int]]) – The indices of the results to return. If None, then all results will be returned (no-op).

DeveloperAPI: This API may change across minor Ray releases.

ExcludedKeysResultsPreprocessor

class ray.train.callbacks.results_preprocessors.ExcludedKeysResultsPreprocessor(excluded_keys: Optional[Iterable[str]] = None)[source]

Preprocesses each result dictionary by excluding specified keys.

Example:

  • excluded_keys: ["a"]

  • input: [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

  • output: [{"b": 2}, {"b": 4}]

Parameters

excluded_keys (Optional[Iterable[str]]) – The keys to remove. If None then no keys will be removed.

DeveloperAPI: This API may change across minor Ray releases.

Checkpointing

CheckpointStrategy

class ray.train.CheckpointStrategy(num_to_keep: Optional[int] = None, checkpoint_score_attribute: str = '_timestamp', checkpoint_score_order: str = 'max')[source]

Configurable parameters for defining the Train checkpointing strategy.

Default behavior is to persist all checkpoints to disk. If num_to_keep is set, the default retention policy is to keep the checkpoints with maximum timestamp, i.e. the most recent checkpoints.

Parameters
  • num_to_keep (Optional[int]) – The number of checkpoints to keep on disk for this run. If a checkpoint is persisted to disk after there are already this many checkpoints, then an existing checkpoint will be deleted. If this is None then checkpoints will not be deleted. If this is 0 then no checkpoints will be persisted to disk.

  • checkpoint_score_attribute (str) – The attribute that will be used to score checkpoints to determine which checkpoints should be kept on disk when there are greater than num_to_keep checkpoints. This attribute must be a key from the checkpoint dictionary which has a numerical value.

  • checkpoint_score_order (str) – If “max”, then checkpoints with highest values of checkpoint_score_attribute will be kept. If “min”, then checkpoints with lowest values of checkpoint_score_attribute will be kept.

PublicAPI (beta): This API is in beta and may change before becoming stable.
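For example, to keep only the two checkpoints with the lowest reported loss (a sketch; the "loss" attribute name must match a key saved in the checkpoint dict):

from ray import train
from ray.train import CheckpointStrategy, Trainer

def train_func():
    for epoch in range(5):
        loss = 1.0 / (epoch + 1)  # illustrative metric
        train.save_checkpoint(epoch=epoch, loss=loss)
        train.report(loss=loss)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(
    train_func,
    checkpoint_strategy=CheckpointStrategy(
        num_to_keep=2,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    ),
)
print(trainer.best_checkpoint_path)
trainer.shutdown()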

Training Function Utilities

train.report

ray.train.report(**kwargs) None[source]

Reports all keyword arguments to Train as intermediate results.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for iter in range(100):
        time.sleep(1)
        train.report(hello="world")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Parameters

**kwargs – Any key value pair to be reported by Train. If callbacks are provided, they are executed on these intermediate results.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.load_checkpoint

ray.train.load_checkpoint() Optional[Dict][source]

Loads checkpoint data onto the worker.

from ray import train
from ray.train import Trainer

def train_func():
    checkpoint = train.load_checkpoint()
    for iter in range(checkpoint["epoch"], 5):
        print(iter)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, checkpoint={"epoch": 3})
# 3
# 4
trainer.shutdown()

Returns

The most recently saved checkpoint if train.save_checkpoint() has been called. Otherwise, the checkpoint that the session was originally initialized with. None if neither exist.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.save_checkpoint

ray.train.save_checkpoint(**kwargs) None[source]

Checkpoints all keyword arguments to Train as restorable state.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for iter in range(100):
        time.sleep(1)
        train.save_checkpoint(epoch=iter)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
Parameters

**kwargs – Any key value pair to be checkpointed by Train.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.get_dataset_shard

ray.train.get_dataset_shard(dataset_name: Optional[str] = None) Optional[Union[Dataset, DatasetPipeline]][source]

Returns the Ray Dataset or DatasetPipeline shard for this worker.

You should call to_torch() or to_tf() on this shard to convert it to the appropriate framework-specific Dataset.

import ray
from ray import train
from ray.train import Trainer

def train_func():
    model = Net()
    for iter in range(100):
        data_shard = train.get_dataset_shard().to_torch()
        model.train(data_shard)
    return model

dataset = ray.data.read_csv("train.csv")
dataset = dataset.filter(...).repeat().random_shuffle()

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
# Trainer will automatically handle sharding.
train_model = trainer.run(train_func, dataset=dataset)
trainer.shutdown()
Parameters

dataset_name (Optional[str]) – If a Dictionary of Datasets was passed to Trainer, then specifies which dataset shard to return.

Returns

The Dataset or DatasetPipeline shard to use for this worker. If no dataset is passed into Trainer, then return None.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.world_rank

ray.train.world_rank() int[source]

Get the world rank of this worker.

import time
from ray import train
from ray.train import Trainer

def train_func():
    for iter in range(100):
        time.sleep(1)
        if train.world_rank() == 0:
            print("Worker 0")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.local_rank

ray.train.local_rank() int[source]

Get the local rank of this worker (rank of the worker on its node).

import torch
from ray import train
from ray.train import Trainer

def train_func():
    if torch.cuda.is_available():
        torch.cuda.set_device(train.local_rank())
    ...

trainer = Trainer(backend="torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.world_size

ray.train.world_size() int[source]

Get the current world size (i.e. total number of workers) for this run.

from ray import train
from ray.train import Trainer

def train_func():
    assert train.world_size() == 4

trainer = Trainer(backend="torch", num_workers=4)
trainer.start()
trainer.run(train_func)
trainer.shutdown()

PublicAPI (beta): This API is in beta and may change before becoming stable.

PyTorch Training Function Utilities

train.torch.prepare_model

ray.train.torch.prepare_model(model: torch.nn.Module, move_to_device: bool = True, wrap_ddp: bool = True, ddp_kwargs: Optional[Dict[str, Any]] = None) torch.nn.Module[source]

Prepares the model for distributed execution.

This allows you to use the same exact code regardless of number of workers or the device type being used (CPU, GPU).

Parameters
  • model (torch.nn.Module) – A torch model to prepare.

  • move_to_device (bool) – Whether to move the model to the correct device. If set to False, the model needs to manually be moved to the correct device.

  • wrap_ddp (bool) – Whether to wrap models in DistributedDataParallel.

  • ddp_kwargs (Dict[str, Any]) – Args to pass into DistributedDataParallel initialization if wrap_ddp is set to True.

PublicAPI (beta): This API is in beta and may change before becoming stable.
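A sketch of prepare_model inside a training function (the model is illustrative):

import torch.nn as nn
from ray.train.torch import prepare_model

def train_func():
    model = nn.Linear(10, 1)  # illustrative model
    # Moves the model to the appropriate device and wraps it in
    # DistributedDataParallel when more than one worker is used.
    model = prepare_model(model)
    # ... training loop using `model` ...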

train.torch.prepare_data_loader

ray.train.torch.prepare_data_loader(data_loader: torch.utils.data.DataLoader, add_dist_sampler: bool = True, move_to_device: bool = True, auto_transfer: bool = True) torch.utils.data.DataLoader[source]

Prepares DataLoader for distributed execution.

This allows you to use the same exact code regardless of number of workers or the device type being used (CPU, GPU).

Parameters
  • data_loader (torch.utils.data.DataLoader) – The DataLoader to prepare.

  • add_dist_sampler (bool) – Whether to add a DistributedSampler to the provided DataLoader.

  • move_to_device (bool) – If set, automatically move the data returned by the data loader to the correct device.

  • auto_transfer (bool) – If set and device is GPU, another CUDA stream is created to automatically copy data from host (CPU) memory to device (GPU) memory (the default CUDA stream still runs the training procedure). If device is CPU, it will be disabled regardless of the setting. This configuration will be ignored if move_to_device is False.

PublicAPI (beta): This API is in beta and may change before becoming stable.
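A sketch pairing prepare_data_loader with prepare_model (the dataset and model are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
from ray.train.torch import prepare_data_loader, prepare_model

def train_func():
    dataset = TensorDataset(torch.randn(128, 10), torch.randn(128, 1))  # illustrative data
    # Adds a DistributedSampler and moves each batch to the correct device.
    data_loader = prepare_data_loader(DataLoader(dataset, batch_size=16))
    model = prepare_model(torch.nn.Linear(10, 1))
    loss_fn = torch.nn.MSELoss()
    for X, y in data_loader:
        loss = loss_fn(model(X), y)
        # ... backward pass and optimizer step ...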

train.torch.prepare_optimizer

ray.train.torch.prepare_optimizer(optimizer: torch.optim.Optimizer) torch.optim.Optimizer[source]

Wraps optimizer to support automatic mixed precision.

Parameters

optimizer (torch.optim.Optimizer) – The optimizer to prepare.

Returns

A wrapped optimizer.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.torch.backward

ray.train.torch.backward(tensor: torch.Tensor) None[source]

Computes the gradient of the specified tensor w.r.t. graph leaves.

Parameters

tensor (torch.Tensor) – Tensor of which the derivative will be computed.

PublicAPI (beta): This API is in beta and may change before becoming stable.

train.torch.get_device

ray.train.torch.get_device() torch.device[source]

Gets the correct torch device to use for training. PublicAPI (beta): This API is in beta and may change before becoming stable.

train.torch.enable_reproducibility

ray.train.torch.enable_reproducibility(seed: int = 0) None[source]

Limits sources of nondeterministic behavior.

This function:

  • Seeds PyTorch, Python, and NumPy.

  • Disables CUDA convolution benchmarking.

  • Configures PyTorch to use deterministic algorithms.

  • Seeds workers spawned for multi-process data loading.

Parameters

seed (int) – The number to seed libraries and data workers with.

Warning

train.torch.enable_reproducibility() can’t guarantee completely reproducible results across executions. To learn more, read the PyTorch notes on randomness.
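
It is intended to be called at the start of the training function, for example (a sketch; the seed value is illustrative):

from ray.train.torch import enable_reproducibility

def train_func():
    enable_reproducibility(seed=42)
    # ... model setup and training loop ...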

train.torch.accelerate

ray.train.torch.accelerate(amp: bool = False) None[source]

Enables training optimizations.

Parameters

amp (bool) – If True, perform training with automatic mixed precision. Otherwise, use full precision.

Warning

train.torch.accelerate cannot be called more than once, and it must be called before any other train.torch utility function.

PublicAPI (beta): This API is in beta and may change before becoming stable.
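A sketch of a mixed-precision training loop that combines accelerate, prepare_optimizer, and backward (the model, data, and hyperparameters are illustrative):

import torch
from ray.train.torch import (
    accelerate,
    backward,
    get_device,
    prepare_model,
    prepare_optimizer,
)

def train_func():
    # Must be called before any other train.torch utility function.
    accelerate(amp=True)
    model = prepare_model(torch.nn.Linear(10, 1))  # illustrative model
    optimizer = prepare_optimizer(torch.optim.SGD(model.parameters(), lr=0.01))
    loss_fn = torch.nn.MSELoss()
    device = get_device()
    X, y = torch.randn(16, 10).to(device), torch.randn(16, 1).to(device)  # illustrative batch
    for _ in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        # Scales the loss appropriately when mixed precision is enabled.
        backward(loss)
        optimizer.step()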

train.torch.TorchWorkerProfiler

class ray.train.torch.TorchWorkerProfiler(trace_dir: Optional[str] = None)[source]

Utility class for running PyTorch Profiler on a Train worker.

Parameters

trace_dir (Optional[str]) – The directory to store traces on the worker node. If None, this will use a default temporary dir.

trace_handler(p: torch.profiler.profile)[source]

A stateful PyTorch Profiler trace handler.

This will export the Chrome trace to a file on disk.

These exported traces can then be fetched by calling get_and_clear_profile_traces.

Parameters

p (profile) – A PyTorch Profiler profile.

get_and_clear_profile_traces()[source]

Reads unread Profiler traces from this worker.

Returns

The traces in a format consumable by TorchTensorboardProfilerCallback.
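
A sketch of wiring TorchWorkerProfiler to the PyTorch profiler and TorchTensorboardProfilerCallback (the profiled work is illustrative):

import torch
from torch.profiler import profile
from ray import train
from ray.train import Trainer
from ray.train.callbacks import TorchTensorboardProfilerCallback
from ray.train.torch import TorchWorkerProfiler

def train_func():
    twp = TorchWorkerProfiler()
    with profile(on_trace_ready=twp.trace_handler) as p:
        for epoch in range(2):
            torch.randn(100, 100) @ torch.randn(100, 100)  # illustrative work
            p.step()
            # Hand the exported traces to the callback via train.report().
            profile_results = twp.get_and_clear_profile_traces()
            train.report(epoch=epoch, **profile_results)

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, callbacks=[TorchTensorboardProfilerCallback()])
trainer.shutdown()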

TensorFlow Training Function Utilities

train.tensorflow.prepare_dataset_shard

ray.train.tensorflow.prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset)[source]

A utility function that disables Tensorflow autosharding.

This should be used on a TensorFlow Dataset created by calling to_tf() on a ray.data.Dataset returned by ray.train.get_dataset_shard() since the dataset has already been sharded across the workers.

Parameters

tf_dataset_shard (tf.data.Dataset) – A TensorFlow Dataset.

Returns

A TensorFlow Dataset with autosharding turned off.

PublicAPI (beta): This API is in beta and may change before becoming stable.
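
A sketch of prepare_dataset_shard inside a TensorFlow training function (the to_tf() arguments are illustrative and depend on your dataset's schema):

from ray import train
from ray.train.tensorflow import prepare_dataset_shard

def train_func():
    dataset_shard = train.get_dataset_shard()
    # Disable TF autosharding: the Ray Dataset is already sharded per worker.
    tf_dataset = prepare_dataset_shard(
        dataset_shard.to_tf(
            label_column="y",  # illustrative label column
            output_signature=...,  # fill in to match your schema
            batch_size=32,
        )
    )
    # ... model.fit(tf_dataset) ...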