ray.train.context.TrainContext.get_local_rank

TrainContext.get_local_rank() -> int

Get the local rank of this worker (rank of the worker on its node).

import torch

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    """Run the per-worker training loop.

    When CUDA is available, this worker's current CUDA device is set to the
    index returned by ``get_local_rank()`` (the worker's rank on its node).

    Args:
        config: Training-loop configuration (unused in this example).
    """
    if torch.cuda.is_available():
        # Local rank (not global rank) picks the device, since device
        # indices are per node.
        local_rank = train.get_context().get_local_rank()
        torch.cuda.set_device(local_rank)
    # Placeholder for the actual training logic.
    ...

# Load the example Iris dataset from the public Ray example bucket.
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")

# Two workers, each requesting a GPU.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
)
trainer.fit()

PublicAPI (beta): This API is in beta and may change before becoming stable.