Pattern: Using a supervisor actor to manage a tree of actors

Actor supervision is a pattern in which a supervising actor manages a collection of worker actors. The supervisor delegates tasks to its workers and handles their failures. This pattern simplifies the driver, which manages only a few supervisors and never deals with worker failures directly. Multiple supervisors can also run in parallel, so more work proceeds concurrently.

Figure: Tree of actors

Note

  • If the supervisor (or the driver) dies, the worker actors are automatically terminated thanks to actor reference counting.

  • Actors can be nested to multiple levels to form a tree.

Example use case

You want to do data parallel training and train the same model with different hyperparameters in parallel. For each hyperparameter, you can launch a supervisor actor to orchestrate the training; the supervisor creates one worker actor per data shard to do the actual training.

Note

For data parallel training and hyperparameter tuning, it’s recommended to use Ray Train’s DataParallelTrainer and Ray Tune’s Tuner, which apply this pattern under the hood.

Code example

import ray


@ray.remote(num_cpus=1)
class Trainer:
    def __init__(self, hyperparameter, data):
        self.hyperparameter = hyperparameter
        self.data = data

    # Train the model on the given training data shard.
    # The multiplication below is a stand-in for real training.
    def fit(self):
        return self.data * self.hyperparameter


@ray.remote(num_cpus=1)
class Supervisor:
    def __init__(self, hyperparameter, data):
        self.trainers = [Trainer.remote(hyperparameter, d) for d in data]

    def fit(self):
        # Train on all data shards in parallel.
        return ray.get([trainer.fit.remote() for trainer in self.trainers])


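# Note: each actor reserves one CPU for its lifetime, so this example
# needs 8 CPUs in total (2 supervisors + 6 trainers).
data = [1, 2, 3]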
supervisor1 = Supervisor.remote(1, data)
supervisor2 = Supervisor.remote(2, data)
# Train with different hyperparameters in parallel.
model1 = supervisor1.fit.remote()
model2 = supervisor2.fit.remote()
assert ray.get(model1) == [1, 2, 3]
assert ray.get(model2) == [2, 4, 6]