ray.train.v2.api.context.TrainContext.get_node_rank#

TrainContext.get_node_rank() → int[source]#

Get the rank of this node.

Example

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker():
    """Print the node rank reported by the current Ray Train context."""
    context = train.get_context()
    node_rank = context.get_node_rank()
    print(node_rank)

# Configure a TorchTrainer with the loop function defined above and a
# scaling config that requests a single worker (num_workers=1).
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
)
# Start the run. NOTE(review): presumably train_loop_per_worker is invoked
# once per worker and, with one worker, prints a single node rank — confirm.
trainer.fit()