ray.train.TrainContext.get_world_size

abstract TrainContext.get_world_size() -> int

Get the current world size (i.e. total number of workers) for this run.

import ray.train
from ray.train.torch import TorchTrainer

# Number of workers requested through ScalingConfig below; the training
# function asserts that the reported world size equals this value.
NUM_WORKERS = 2

def train_fn_per_worker(config):
    """Assert that the world size seen by this worker equals NUM_WORKERS.

    Args:
        config: Training configuration passed in by the trainer; unused here.
    """
    context = ray.train.get_context()
    world_size = context.get_world_size()
    assert world_size == NUM_WORKERS

# Request NUM_WORKERS workers, hand the training function to the trainer,
# and start the run.
scaling_config = ray.train.ScalingConfig(num_workers=NUM_WORKERS)
trainer = TorchTrainer(train_fn_per_worker, scaling_config=scaling_config)
trainer.fit()