ray.train.TrainContext.get_local_rank

abstract TrainContext.get_local_rank() -> int [source]

Get the local rank of this worker, i.e. its rank among the workers running on the same node. Example:

import ray.train
from ray.train.torch import TorchTrainer

def train_fn_per_worker(config):
    """Print a message only on the worker whose local rank is 0.

    Args:
        config: Per-worker configuration argument; not used by this example.
    """
    # Local rank = this worker's rank among the workers on its own node.
    local_rank = ray.train.get_context().get_local_rank()
    if local_rank != 0:
        return
    print("Local rank 0 worker")

# Request two workers and launch training, running train_fn_per_worker on them.
scaling_config = ray.train.ScalingConfig(num_workers=2)
trainer = TorchTrainer(train_fn_per_worker, scaling_config=scaling_config)
trainer.fit()