ray.rllib.evaluation.rollout_worker.RolloutWorker

class ray.rllib.evaluation.rollout_worker.RolloutWorker(*, env_creator: Callable[[EnvContext], Any | gymnasium.Env | None], validate_env: Callable[[Any | gymnasium.Env, EnvContext], None] | None = None, config: AlgorithmConfig | None = None, worker_index: int = 0, num_workers: int | None = None, recreated_worker: bool = False, log_dir: str | None = None, spaces: Dict[str, Tuple[gymnasium.spaces.Space, gymnasium.spaces.Space]] | None = None, default_policy_class: Type[Policy] | None = None, dataset_shards: List[Dataset] | None = None, **kwargs)

Bases: ParallelIteratorWorker, EnvRunner

Common experience collection class.

This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training.

This class supports vectorized and multi-agent policy evaluation (e.g., with VectorEnv or MultiAgentEnv).

# Create a rollout worker and use it to collect experiences.
import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
worker = RolloutWorker(
  env_creator=lambda _: gym.make("CartPole-v1"),
  default_policy_class=PPOTF1Policy)
print(worker.sample())
# SampleBatch({
#     "obs": [[...]], "actions": [[...]], "rewards": [[...]],
#     "terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]}
# )

# Create a multi-agent rollout worker.
from gymnasium.spaces import Discrete, Box
import random
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
MultiAgentTrafficGrid = ...  # placeholder for a user-defined MultiAgentEnv
PGTFPolicy = ...  # placeholder for a TF policy-gradient Policy class
worker = RolloutWorker(
  env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
  config=AlgorithmConfig().multi_agent(
    policies={
      # Use an ensemble of two policies for car agents.
      "car_policy1":
        (PGTFPolicy, Box(...), Discrete(...),
         AlgorithmConfig.overrides(gamma=0.99)),
      "car_policy2":
        (PGTFPolicy, Box(...), Discrete(...),
         AlgorithmConfig.overrides(gamma=0.95)),
      # Use a single shared policy for all traffic lights.
      "traffic_light_policy":
        (PGTFPolicy, Box(...), Discrete(...), {}),
    },
    policy_mapping_fn=(
      lambda agent_id, episode, **kwargs:
      random.choice(["car_policy1", "car_policy2"])
      if agent_id.startswith("car_") else "traffic_light_policy"),
    ),
)
print(worker.sample())
# MultiAgentBatch({
#     "car_policy1": SampleBatch(...),
#     "car_policy2": SampleBatch(...),
#     "traffic_light_policy": SampleBatch(...)}
# )
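The policy_mapping_fn in the multi-agent example decides which policy acts for each agent ID. The sketch below calls that same lambda on its own, outside RolloutWorker, so the routing is easy to check. The agent IDs (car_0, light_0, ...) are made up for illustration, and episode=None stands in for the episode object RLlib would pass.

```python
import random

# Same routing rule as in the multi-agent example.
policy_mapping_fn = (
    lambda agent_id, episode, **kwargs:
    random.choice(["car_policy1", "car_policy2"])
    if agent_id.startswith("car_") else "traffic_light_policy")

# Hypothetical agent IDs; None replaces the real episode object.
agent_ids = ["car_0", "car_1", "light_0", "light_1"]
mapping = {aid: policy_mapping_fn(aid, episode=None) for aid in agent_ids}

for aid, pid in mapping.items():
    print(aid, "->", pid)
```

Because random.choice is evaluated on every call, different cars can land on different car policies, which is how the two-policy ensemble gets its share of the car agents; every non-car agent always maps to traffic_light_policy.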

Methods

__init__

Initializes a RolloutWorker instance.

add_policy

Adds a new policy to this RolloutWorker.

apply

Calls the given function with this Actor instance.

apply_gradients

Applies the given gradients to this worker's models.

compute_gradients

Returns a gradient computed w.r.t. the specified samples.

creation_args

Returns the kwargs dict used to create this worker.

find_free_port

Finds a free port on the node that this worker runs on.

for_policy

Calls the given function with the specified policy as first arg.

foreach_env

Calls the given function with each sub-environment as arg.

foreach_env_with_context

Calls the given function with each sub-environment plus its env context (env_ctx) as args.

foreach_policy

Calls the given function with each (policy, policy_id) tuple.

foreach_policy_to_train

Calls the given function with each (policy, policy_id) tuple, limited to the policies that are to be trained.

get_filters

Returns a snapshot of filters.

get_global_vars

Returns the current self.global_vars dict of this RolloutWorker.

get_host

Returns the hostname of the process running this worker.

get_metrics

Returns the thus-far collected metrics from this worker's rollouts.

get_node_ip

Returns the IP address of the node that this worker runs on.

get_policies_to_train

Returns all policies-to-train, given an optional batch.

get_policy

Returns the policy for the specified ID, or None if no such policy exists.

get_weights

Returns the model weights of each policy in this worker.

learn_on_batch

Updates policies based on the given batch.

lock

Locks this RolloutWorker via its own threading.Lock.

par_iter_init

Implements ParallelIterator worker init.

par_iter_next

Implements ParallelIterator worker item fetch.

par_iter_next_batch

Batches par_iter_next.

par_iter_slice

Iterates in increments of step starting from start.

par_iter_slice_batch

Batches par_iter_slice.

ping

Pings the actor.

remove_policy

Removes a policy from this RolloutWorker.

sample

Returns a batch of experience sampled from this worker.

sample_and_learn

Samples a batch and learns on it.

sample_with_count

Same as sample() but returns the count as a separate value.

set_global_vars

Updates this worker's and all its policies' global vars.

set_is_policy_to_train

Sets self.is_policy_to_train() to a new callable.

set_policy_mapping_fn

Sets self.policy_mapping_fn to a new callable (if provided).

set_weights

Sets the model weights of each policy in this worker.

setup_torch_data_parallel

Joins a torch process group for distributed SGD.

stop

Releases all resources used by this RolloutWorker.

sync_filters

Replaces this worker's filters with the given ones and rebases any accumulated delta.

unlock

Unlocks this RolloutWorker via its own threading.Lock.
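compute_gradients and apply_gradients split one learning step into two halves, so gradients can be computed on one worker and applied on another, while learn_on_batch performs both halves in one call. The toy example below illustrates that split with a one-parameter linear model and a squared-error loss; the ToyWorker class, the model, the loss, and the learning rate are all hypothetical, and RLlib's own compute_gradients additionally returns a stats dict, which the toy drops.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for a toy model y = w * x


class ToyWorker:
    """Hypothetical worker holding a single scalar weight w."""

    def __init__(self, w: float, lr: float = 0.1):
        self.w = w
        self.lr = lr

    def compute_gradients(self, batch: Batch) -> float:
        # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        return sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)

    def apply_gradients(self, grad: float) -> None:
        # Plain gradient-descent step on this worker's weight.
        self.w -= self.lr * grad

    def learn_on_batch(self, batch: Batch) -> None:
        # Both halves on a single worker.
        self.apply_gradients(self.compute_gradients(batch))


batch = [(1.0, 2.0), (2.0, 4.0)]  # generated by y = 2x
remote = ToyWorker(w=0.0)
local = ToyWorker(w=0.0)

grad = remote.compute_gradients(batch)  # computed on one worker...
local.apply_gradients(grad)             # ...applied on another
print(grad, local.w)  # -10.0 1.0
```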
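find_free_port asks the operating system for an unused port on the worker's node. A common standard-library technique is to bind a socket to port 0 and read back the port the OS assigned; the sketch below shows that technique and is an assumption about the general approach, not a copy of RolloutWorker's code.

```python
import socket
from contextlib import closing


def find_free_port() -> int:
    """Return a TCP port that was free at the time of the call."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Binding to port 0 lets the OS pick any currently unused port.
        s.bind(("", 0))
        return s.getsockname()[1]


port = find_free_port()
print(port)
```

The port is only known to be free at the moment of the call; another process can still claim it before it is used.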
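foreach_policy and foreach_policy_to_train both call a function with each (policy, policy_id) pair; the difference is that the second one only visits policies that the worker's is_policy_to_train predicate accepts. The pure-Python mock-up below illustrates that filtering with a hypothetical MockWorker class and plain strings standing in for Policy objects. It is not RLlib code, and the real predicate also receives an optional batch, which is omitted here.

```python
from typing import Any, Callable, Dict, List


class MockWorker:
    """Hypothetical stand-in for a RolloutWorker's policy bookkeeping."""

    def __init__(self, policies: Dict[str, Any], is_policy_to_train: Callable[[str], bool]):
        self.policy_map = policies
        self.is_policy_to_train = is_policy_to_train

    def foreach_policy(self, func: Callable[[Any, str], Any]) -> List[Any]:
        # Visit every policy, trainable or not.
        return [func(policy, pid) for pid, policy in self.policy_map.items()]

    def foreach_policy_to_train(self, func: Callable[[Any, str], Any]) -> List[Any]:
        # Visit only the policies that the predicate marks as trainable.
        return [
            func(policy, pid)
            for pid, policy in self.policy_map.items()
            if self.is_policy_to_train(pid)
        ]


worker = MockWorker(
    policies={"car_policy1": "p1", "car_policy2": "p2", "traffic_light_policy": "p3"},
    is_policy_to_train=lambda pid: pid.startswith("car_"),
)
all_ids = worker.foreach_policy(lambda policy, pid: pid)
train_ids = worker.foreach_policy_to_train(lambda policy, pid: pid)
print(all_ids)    # ['car_policy1', 'car_policy2', 'traffic_light_policy']
print(train_ids)  # ['car_policy1', 'car_policy2']
```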
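get_weights and set_weights work with per-policy weights, which makes them a natural pair for pushing a learner's latest weights out to sampling workers. The sketch below shows that broadcast with a hypothetical MockWorker whose "weights" are lists of floats keyed by policy ID; the optional policies filter is modeled loosely on RolloutWorker's but simplified, and in real code the weights come from each policy's model and remote calls go through Ray.

```python
from typing import Dict, List, Optional

Weights = Dict[str, List[float]]  # policy ID -> flat list of parameters (toy stand-in)


class MockWorker:
    """Hypothetical stand-in that only tracks per-policy weights."""

    def __init__(self, weights: Weights):
        self._weights = weights

    def get_weights(self, policies: Optional[List[str]] = None) -> Weights:
        # Return copies so callers cannot mutate this worker's state by accident.
        ids = policies if policies is not None else list(self._weights)
        return {pid: list(self._weights[pid]) for pid in ids}

    def set_weights(self, weights: Weights) -> None:
        # Overwrite only the policies present in the given dict.
        for pid, w in weights.items():
            self._weights[pid] = list(w)


learner = MockWorker({"default_policy": [0.5, -1.0]})
samplers = [MockWorker({"default_policy": [0.0, 0.0]}) for _ in range(3)]

# Broadcast the learner's weights to every sampling worker.
latest = learner.get_weights()
for w in samplers:
    w.set_weights(latest)

print([w.get_weights() for w in samplers])
```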
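lock and unlock acquire and release the worker's own threading.Lock, which lets a caller guard a sequence of operations against concurrent access from other threads. Pairing them in try/finally ensures the lock is released even if the guarded code raises. The sketch below uses a hypothetical LockableWorker with a counter to show the pattern.

```python
import threading


class LockableWorker:
    """Hypothetical worker exposing lock()/unlock() around a threading.Lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self.counter = 0

    def lock(self) -> None:
        self._lock.acquire()

    def unlock(self) -> None:
        self._lock.release()


worker = LockableWorker()


def bump(n: int) -> None:
    for _ in range(n):
        worker.lock()
        try:
            # Read-modify-write guarded by the worker's lock.
            worker.counter += 1
        finally:
            worker.unlock()


threads = [threading.Thread(target=bump, args=(1000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(worker.counter)  # 4000
```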
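par_iter_slice(step, start) hands out every step-th item of the worker's local iterator, beginning at position start. The index pattern can be pictured with itertools.islice; this is an illustration of which items each start value receives, not the method's implementation, and the integer range stands in for real items such as sample batches.

```python
from itertools import islice
from typing import Iterable, List


def slice_items(stream: Iterable[int], step: int, start: int) -> List[int]:
    """Items at positions start, start + step, start + 2 * step, ..."""
    return list(islice(stream, start, None, step))


stream = range(10)  # stand-in for the items a worker's local iterator produces
parts = {start: slice_items(stream, step=3, start=start) for start in range(3)}
print(parts)  # {0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
```

With step=3, the three start values 0, 1 and 2 together cover every item exactly once.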