ray.train.v2.api.context.TrainContext.get_world_rank

TrainContext.get_world_rank() -> int

Get the world rank of this worker.

Example (only the worker with world rank 0 prints a message):

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def train_loop_per_worker(config):
    """Print a message only from the worker whose world rank is 0.

    Args:
        config: Training-loop configuration; not used by this example.
    """
    world_rank = train.get_context().get_world_rank()
    if world_rank != 0:
        return
    print("Worker 0")

# Run the training function on two workers; each one checks its own world rank.
scaling_config = ScalingConfig(num_workers=2)
trainer = TensorflowTrainer(
    train_loop_per_worker,
    scaling_config=scaling_config,
)
trainer.fit()