ray.train.context.TrainContext.get_world_size
TrainContext.get_world_size() -> int [source]
Get the current world size (i.e., the total number of workers) for this run.
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# Number of training workers requested via ScalingConfig below.
NUM_WORKERS = 2


def train_loop_per_worker(config):
    """Per-worker training function run by TensorflowTrainer.

    Checks that the world size reported by the train context matches the
    number of workers requested in ``ScalingConfig``.

    Args:
        config: Per-worker config passed in by the trainer (unused in this
            example).
    """
    assert train.get_context().get_world_size() == NUM_WORKERS


train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TensorflowTrainer(
    train_loop_per_worker,
    # The world size seen inside train_loop_per_worker equals num_workers.
    scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
    datasets={"train": train_dataset},
)
trainer.fit()
PublicAPI (beta): This API is in beta and may change before becoming stable.