ray.train.context.TrainContext.get_world_size

TrainContext.get_world_size() -> int

Get the current world size (i.e., the total number of workers) for this run.

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# Worker count for this example. It is passed to ScalingConfig when the
# trainer is built, and the training loop checks that the world size it
# observes matches it.
NUM_WORKERS: int = 2

def train_loop_per_worker(config):
    """Verify, from inside the training loop, the world size Ray Train reports.

    Reads the world size from ``train.get_context()`` and asserts that it
    equals ``NUM_WORKERS``, the worker count the trainer is configured with.

    Args:
        config: Training-loop configuration argument; not used by this check.
    """
    # Same semantics as a bare ``assert``: under ``python -O`` the whole
    # check, including the context lookup, is skipped.
    if __debug__:
        reported_world_size = train.get_context().get_world_size()
        assert reported_world_size == NUM_WORKERS

# Load the Iris example CSV. The loop above never reads it; it is passed
# only to show how datasets are attached to a trainer.
train_dataset = ray.data.read_csv(
    "s3://anonymous@ray-example-data/iris.csv"
)

# Ask for exactly NUM_WORKERS workers, which is the world size the
# training loop expects to observe.
scaling_config = ScalingConfig(num_workers=NUM_WORKERS)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
)

# Launch the training run.
trainer.fit()

PublicAPI (beta): This API is in beta and may change before becoming stable.