ray.train.TrainContext.get_local_rank
- abstract TrainContext.get_local_rank() -> int [source]
Get the local rank of this worker, i.e., the rank of this worker among the workers on the same node.

Example:
import ray.train
from ray.train.torch import TorchTrainer


def train_fn_per_worker(config):
    """Training function executed by each Ray Train worker.

    Prints a message only on the worker whose local rank is 0.
    """
    # get_local_rank() is queried through the per-worker train context.
    if ray.train.get_context().get_local_rank() == 0:
        print("Local rank 0 worker")


trainer = TorchTrainer(
    train_fn_per_worker,
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
trainer.fit()