ray.train.context.TrainContext.get_local_rank
- TrainContext.get_local_rank() -> int [source]
Get the local rank of this worker, i.e. its rank among the workers running on the same node.
import torch

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    """Training loop executed by each Ray Train worker.

    When CUDA is available, select the CUDA device whose index equals this
    worker's local rank (its rank on its node). The rest of the training
    logic is elided in this example.
    """
    if torch.cuda.is_available():
        worker_context = train.get_context()
        torch.cuda.set_device(worker_context.get_local_rank())
    ...


train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
)
trainer.fit()
PublicAPI (beta): This API is in beta and may change before becoming stable.