ray.train.TrainContext.get_world_size
- abstract TrainContext.get_world_size() -> int [source]
Get the current world size (i.e., the total number of workers) for this run.
import ray.train
from ray.train.torch import TorchTrainer

# Number of training workers requested for this run.
NUM_WORKERS = 2


def train_fn_per_worker(config):
    """Per-worker training function.

    Checks that the world size reported by the current train context
    equals the number of workers requested in the ScalingConfig below.

    Args:
        config: Per-worker configuration passed in by the trainer
            (unused in this example).
    """
    # get_world_size() returns the total number of workers in this run.
    assert ray.train.get_context().get_world_size() == NUM_WORKERS


# Run train_fn_per_worker on NUM_WORKERS workers.
trainer = TorchTrainer(
    train_fn_per_worker,
    scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
)
trainer.fit()