ray.tune.TuneContext.get_local_rank#

TuneContext.get_local_rank() -> int

Get the local rank of this worker (rank of the worker on its node).

import torch

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    """Example per-worker training loop: pin this worker to the GPU
    matching its local rank (its index among workers on the same node)."""
    if torch.cuda.is_available():
        # Local rank is unique per node, so on multi-GPU nodes each
        # worker claims a distinct CUDA device.
        torch.cuda.set_device(train.get_context().get_local_rank())
    ...

# Example dataset: the Iris CSV hosted in Ray's public example-data bucket.
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
    train_loop_per_worker,
    # Two distributed workers, each requesting one GPU (use_gpu=True).
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    datasets={"train": train_dataset}
)
# Launches the workers and blocks until training finishes.
trainer.fit()

PublicAPI (beta): This API is in beta and may change before becoming stable.

Warning

DEPRECATED: This API is deprecated and may be removed in future Ray releases. get_local_rank is deprecated for Ray Tune because there is no concept of worker ranks in Ray Tune; this method only makes sense to use in the context of a Ray Train worker.