ray.rllib.evaluation.rollout_worker.RolloutWorker#
- class ray.rllib.evaluation.rollout_worker.RolloutWorker(*, env_creator: Callable[[EnvContext], Any | gymnasium.Env | None], validate_env: Callable[[Any | gymnasium.Env, EnvContext], None] | None = None, config: AlgorithmConfig | None = None, worker_index: int = 0, num_workers: int | None = None, recreated_worker: bool = False, log_dir: str | None = None, spaces: Dict[str, Tuple[gymnasium.spaces.Space, gymnasium.spaces.Space]] | None = None, default_policy_class: Type[Policy] | None = None, dataset_shards: List[Dataset] | None = None, **kwargs)[source]#
Bases: ParallelIteratorWorker, EnvRunner
Common experience collection class.
This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training (see the scale-out sketch after the example output below).
This class supports vectorized and multi-agent policy evaluation (e.g., VectorEnv, MultiAgentEnv).
# Create a rollout worker and use it to collect experiences.
import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy

worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v1"),
    default_policy_class=PPOTF1Policy)
print(worker.sample())

# Create a multi-agent rollout worker.
import random
from gymnasium.spaces import Box, Discrete
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

# `MultiAgentTrafficGrid` and `PGTFPolicy` are placeholders for a
# user-defined multi-agent env class and policy class, respectively.
MultiAgentTrafficGrid = ...
worker = RolloutWorker(
    env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
    config=AlgorithmConfig().multi_agent(
        policies={
            # Use an ensemble of two policies for car agents.
            "car_policy1": (
                PGTFPolicy, Box(...), Discrete(...),
                AlgorithmConfig.overrides(gamma=0.99)),
            "car_policy2": (
                PGTFPolicy, Box(...), Discrete(...),
                AlgorithmConfig.overrides(gamma=0.95)),
            # Use a single shared policy for all traffic lights.
            "traffic_light_policy": (
                PGTFPolicy, Box(...), Discrete(...), {}),
        },
        policy_mapping_fn=(
            lambda agent_id, episode, **kwargs:
                random.choice(["car_policy1", "car_policy2"])
                if agent_id.startswith("car_") else "traffic_light_policy"),
    ),
)
print(worker.sample())
SampleBatch({
    "obs": [[...]], "actions": [[...]], "rewards": [[...]],
    "terminateds": [[...]], "truncateds": [[...]],
    "new_obs": [[...]]})

MultiAgentBatch({
    "car_policy1": SampleBatch(...),
    "car_policy2": SampleBatch(...),
    "traffic_light_policy": SampleBatch(...)})
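As noted above, many replicas of this class can be created as Ray actors to scale experience collection. The following is a minimal scale-out sketch of that pattern, not a prescribed setup: it assumes a running Ray instance, reuses the single-agent setup from the first example, and the worker count and CPU assignment are arbitrary choices.

import gymnasium as gym
import ray
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
from ray.rllib.evaluation.rollout_worker import RolloutWorker

ray.init()

# Turn RolloutWorker into a Ray actor class and start several remote replicas.
RemoteRolloutWorker = ray.remote(num_cpus=1)(RolloutWorker)
workers = [
    RemoteRolloutWorker.remote(
        env_creator=lambda _: gym.make("CartPole-v1"),
        default_policy_class=PPOTF1Policy,
        worker_index=i + 1,
        num_workers=3,
    )
    for i in range(3)
]

# Collect one batch from each replica in parallel.
batches = ray.get([w.sample.remote() for w in workers])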
Methods
Initializes a RolloutWorker instance.
Adds a new policy to this RolloutWorker (see the add/remove sketch after this list).
Calls the given function with this Actor instance.
Applies the given gradients to this worker's models.
Returns a gradient computed w.r.t the specified samples.
Returns the kwargs dict used to create this worker.
Finds a free port on the node that this worker runs on.
Calls the given function with the specified policy as first arg.
Calls the given function with each sub-environment as arg.
Calls given function with each sub-env plus env_ctx as args.
Calls the given function with each (policy, policy_id) tuple (see the foreach sketch after this list).
Calls the given function with each (policy, policy_id) tuple of the policies currently marked to be trained.
Returns a snapshot of filters.
Returns the current self.global_vars dict of this RolloutWorker.
Returns the hostname of the process running this evaluator.
Returns the thus-far collected metrics from this worker's rollouts.
Returns the IP address of the node that this worker runs on.
Returns all policies-to-train, given an optional batch.
Return policy for the specified id, or None.
Returns each policy's model weights of this worker.
Update policies based on the given batch.
Locks this RolloutWorker via its own threading.Lock.
Creates the RLModule for this EnvRunner and assigns it to self.module.
Implements ParallelIterator worker init.
Implements ParallelIterator worker item fetch.
Batches par_iter_next.
Iterates in increments of step starting from start.
Batches par_iter_slice.
Ping the actor.
Removes a policy from this RolloutWorker.
Returns a batch of experience sampled from this worker (see the training-loop sketch after this list).
Sample a batch and learn on it.
Same as sample() but returns the count as a separate value.
Updates this worker's and all its policies' global vars.
Sets self.is_policy_to_train() to a new callable.
Sets self.policy_mapping_fn to a new callable (if provided).
Sets each policy's model weights of this worker.
Join a torch process group for distributed SGD.
Releases all resources used by this RolloutWorker.
Changes self's filter to given and rebases any accumulated delta.
Unlocks this RolloutWorker via its own threading.Lock.
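For the per-policy and per-environment helpers listed above (for_policy, foreach_policy, foreach_env), a minimal foreach sketch, reusing the single-agent worker from the first example; the lambdas are arbitrary:

# Call a function on every (policy, policy_id) pair held by this worker.
policy_ids = worker.foreach_policy(lambda policy, policy_id: policy_id)

# Call a function on one specific policy (the policy is passed as first arg).
num_weight_tensors = worker.for_policy(
    lambda policy: len(policy.get_weights()), policy_id="default_policy")

# Call a function on every sub-environment held by this worker.
obs_spaces = worker.foreach_env(lambda env: env.observation_space)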
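The core sample/learn/weight-sync methods can be combined into a minimal training-loop sketch. It reuses the local worker from the first example and the remote workers from the scale-out sketch above; metrics handling and stopping conditions are omitted.

for _ in range(10):
    # Collect a batch of experience from the local worker's environment(s).
    batch = worker.sample()

    # Update the local worker's policies on that batch; returns per-policy results.
    results = worker.learn_on_batch(batch)

    # Broadcast the updated weights to the remote replicas created above.
    weights = worker.get_weights()
    for remote_worker in workers:
        remote_worker.set_weights.remote(weights)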
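Finally, an add/remove sketch for changing the policy set at runtime; the policy id, the spaces, and the reuse of PPOTF1Policy here are illustrative assumptions, not requirements:

from gymnasium.spaces import Box, Discrete
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy

# Add a second policy under a new id ...
worker.add_policy(
    policy_id="extra_policy",
    policy_cls=PPOTF1Policy,
    observation_space=Box(-1.0, 1.0, (4,)),
    action_space=Discrete(2),
)

# ... and remove it again by id.
worker.remove_policy(policy_id="extra_policy")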