Note

Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.

Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack”, and both continue to run with the old APIs by default. You can continue to use the existing custom (old stack) classes.

See here for more details on how to use the new API stack.

Sampling the Environment or offline data#

RLlib ingests data, via either environment rollouts or other data-generating methods (e.g. reading from offline files), through EnvRunner instances. These sit inside an EnvRunnerGroup (together with other parallel EnvRunners) in the RLlib Algorithm (under the self.env_runner_group property):

Figure (rollout_worker_class_overview.svg): A typical RLlib EnvRunnerGroup setup inside an RLlib Algorithm. Each EnvRunnerGroup contains exactly one local EnvRunner object and N remote EnvRunners (Ray actors). The workers contain a policy map (with one or more policies) and, if a simulator (env) is available, a vectorized BaseEnv (containing M sub-environments) and a SamplerInput (either synchronous or asynchronous), which controls the environment data collection loop. In both the online case (an environment is available) and the offline case (no environment), Algorithm uses the sample() method to get SampleBatch objects for training.
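The layout described in the caption can be illustrated with a minimal, pure-Python sketch. It does not use Ray or RLlib: `ToyEnvRunner` and `ToyEnvRunnerGroup` are hypothetical stand-ins, the "remote" runners are ordinary in-process objects, and the sampled data is random rather than coming from a real environment.

```python
import random
from typing import Callable, Dict, List


class ToyEnvRunner:
    """Hypothetical stand-in for an EnvRunner (not RLlib's class)."""

    def __init__(self, worker_index: int, num_envs: int = 2, seed: int = 0):
        self.worker_index = worker_index
        self.num_envs = num_envs  # M sub-environments per runner
        self._rng = random.Random(seed + worker_index)

    def sample(self, steps_per_env: int = 3) -> Dict[str, List]:
        """Collects `steps_per_env` fake transitions from each sub-env."""
        batch: Dict[str, List] = {"obs": [], "rewards": [], "worker": []}
        for _ in range(self.num_envs * steps_per_env):
            batch["obs"].append(self._rng.random())
            batch["rewards"].append(1.0)
            batch["worker"].append(self.worker_index)
        return batch


class ToyEnvRunnerGroup:
    """One local runner (index 0) plus N 'remote' runners (indices 1..N)."""

    def __init__(self, num_remote: int):
        self.local_runner = ToyEnvRunner(worker_index=0)
        self.remote_runners = [ToyEnvRunner(i) for i in range(1, num_remote + 1)]

    def foreach_worker(self, func: Callable[[ToyEnvRunner], object]) -> list:
        # Assumption for this sketch: the local runner is always included.
        return [func(r) for r in [self.local_runner] + self.remote_runners]


group = ToyEnvRunnerGroup(num_remote=2)
batches = group.foreach_worker(lambda r: r.sample())
total_steps = sum(len(b["obs"]) for b in batches)
print(total_steps)  # 3 runners * 2 sub-envs * 3 steps each
```

In a real setup the remote runners would be Ray actors and the calls would run in parallel; the sketch only shows how one local and N remote runners each contribute a batch.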

RolloutWorker API#

Constructor#

RolloutWorker

Common experience collection class.

Multi agent#

add_policy

Adds a new policy to this RolloutWorker.

remove_policy

Removes a policy from this RolloutWorker.

get_policy

Return policy for the specified id, or None.

set_is_policy_to_train

Sets self.is_policy_to_train() to a new callable.

set_policy_mapping_fn

Sets self.policy_mapping_fn to a new callable (if provided).

for_policy

Calls the given function with the specified policy as first arg.

foreach_policy

Calls the given function with each (policy, policy_id) tuple.

foreach_policy_to_train

Calls the given function with each (policy, policy_id) tuple for which self.is_policy_to_train() returns True.
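The difference between foreach_policy and foreach_policy_to_train can be sketched as follows. This is a simplified illustration, not RLlib's implementation: `ToyMultiAgentWorker` is a hypothetical class, policies are plain dicts, and the trainability check takes only a policy ID (an assumption made to keep the sketch short).

```python
from typing import Callable, Dict, List


class ToyMultiAgentWorker:
    """Hypothetical stand-in for a worker holding several policies."""

    def __init__(self, policies: Dict[str, dict]):
        self.policy_map = dict(policies)
        # By default, every policy is considered trainable.
        self.is_policy_to_train: Callable[[str], bool] = lambda pid: True

    def set_is_policy_to_train(self, to_train) -> None:
        """Accepts a callable or a collection of policy IDs."""
        if callable(to_train):
            self.is_policy_to_train = to_train
        else:
            ids = set(to_train)
            self.is_policy_to_train = lambda pid: pid in ids

    def foreach_policy(self, func) -> List:
        # Visits every policy in the map.
        return [func(policy, pid) for pid, policy in self.policy_map.items()]

    def foreach_policy_to_train(self, func) -> List:
        # Visits only the policies that pass the trainability check.
        return [
            func(policy, pid)
            for pid, policy in self.policy_map.items()
            if self.is_policy_to_train(pid)
        ]


worker = ToyMultiAgentWorker(
    {"learner": {"lr": 0.001}, "frozen_opponent": {"lr": 0.0}}
)
worker.set_is_policy_to_train(["learner"])
all_ids = worker.foreach_policy(lambda policy, pid: pid)
trainable_ids = worker.foreach_policy_to_train(lambda policy, pid: pid)
print(all_ids, trainable_ids)
```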

Setter and getter methods#

get_filters

Returns a snapshot of filters.

get_global_vars

Returns the current self.global_vars dict of this RolloutWorker.

set_global_vars

Updates this worker's and all its policies' global vars.

get_host

Returns the hostname of the process running this worker.

get_metrics

Returns the thus-far collected metrics from this worker's rollouts.

get_node_ip

Returns the IP address of the node that this worker runs on.

get_weights

Returns each policy's model weights of this worker.

set_weights

Sets each policy's model weights of this worker.

get_state

set_state

Threading#

lock

Locks this RolloutWorker via its own threading.Lock.

unlock

Unlocks this RolloutWorker via its own threading.Lock.

Sampling API#

sample

Returns a batch of experience sampled from this worker.

sample_with_count

Same as sample() but returns the count as a separate value.

sample_and_learn

Samples a batch and learns on it.

Training API#

learn_on_batch

Update policies based on the given batch.

setup_torch_data_parallel

Join a torch process group for distributed SGD.

compute_gradients

Returns a gradient computed w.r.t. the specified samples.

apply_gradients

Applies the given gradients to this worker's models.
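How learn_on_batch relates to the compute_gradients / apply_gradients pair can be sketched with a toy one-parameter model. Everything here is an assumption for illustration: `ToyLinearWorker` is hypothetical, the "policy" is a single weight fit by mean squared error, and gradients are plain floats rather than framework tensors.

```python
from typing import Dict, List


class ToyLinearWorker:
    """Hypothetical worker that fits y = w * x by gradient descent."""

    def __init__(self, lr: float = 0.05):
        self.w = 0.0
        self.lr = lr

    def compute_gradients(self, batch: Dict[str, List[float]]) -> float:
        # Gradient of the mean squared error with respect to w.
        xs, ys = batch["x"], batch["y"]
        return sum(2.0 * (self.w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

    def apply_gradients(self, grad: float) -> None:
        # Plain SGD step.
        self.w -= self.lr * grad

    def learn_on_batch(self, batch: Dict[str, List[float]]) -> Dict[str, float]:
        # One update is simply "compute, then apply" on the same worker.
        self.apply_gradients(self.compute_gradients(batch))
        xs, ys = batch["x"], batch["y"]
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
        return {"loss": loss}


batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]}  # true w is 2.0
worker = ToyLinearWorker()
for _ in range(50):
    stats = worker.learn_on_batch(batch)
print(round(worker.w, 4), stats["loss"] < 1e-9)
```

Splitting the update into two calls is what allows a gradient computed on one worker to be applied on another; the sketch keeps both steps on a single object for brevity.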

Environment API#

foreach_env

Calls the given function with each sub-environment as arg.

foreach_env_with_context

Calls given function with each sub-env plus env_ctx as args.

Miscellaneous#

stop

Releases all resources used by this RolloutWorker.

apply

Calls the given function with this Actor instance.

sync_filters

Changes self's filters to the given ones and rebases any accumulated delta.

find_free_port

Finds a free port on the node that this worker runs on.

creation_args

Returns the kwargs dict used to create this worker.

assert_healthy

Checks that self.__init__() has been completed properly.

EnvRunner API#

EnvRunner

Base class for distributed RL-style data collection from an environment.

EnvRunnerGroup API#

Constructor#

EnvRunnerGroup

Set of EnvRunners with n @ray.remote workers and zero or one local worker.

EnvRunnerGroup.stop

Calls stop on all rollout workers (including the local one).

EnvRunnerGroup.reset

Hard overrides the remote EnvRunners in this set with the provided ones.

Worker Orchestration#

add_workers

Creates and adds a number of remote workers to this worker set.

foreach_worker

Calls the given function with each EnvRunner as its argument.

foreach_worker_with_id

Calls the given function with each EnvRunner and its ID as its arguments.

foreach_worker_async

Calls the given function asynchronously with each worker as the argument.

fetch_ready_async_reqs

Gets results from outstanding asynchronous requests that are ready.

num_in_flight_async_reqs

Returns the number of in-flight async requests.

local_worker

remote_workers

num_healthy_remote_workers

Returns the number of healthy remote workers.

num_healthy_workers

Returns the number of all healthy workers, including the local worker.

num_remote_worker_restarts

Total number of times managed remote workers have been restarted.

probe_unhealthy_workers

Checks for unhealthy workers and tries restoring their states.
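The asynchronous pattern behind foreach_worker_async, fetch_ready_async_reqs, and num_in_flight_async_reqs can be sketched with the standard library. This is an analogy only: `ToyAsyncGroup` is hypothetical, and a thread pool stands in for remote Ray actors.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, List, Tuple


class ToyAsyncGroup:
    """Hypothetical group that tracks in-flight requests per worker ID."""

    def __init__(self, worker_ids: List[int]):
        self._worker_ids = list(worker_ids)
        self._pool = ThreadPoolExecutor(max_workers=len(self._worker_ids))
        self._in_flight: List[Tuple[int, Future]] = []

    def foreach_worker_async(self, func: Callable[[int], object]) -> int:
        """Submits func once per worker; returns the number of requests sent."""
        for wid in self._worker_ids:
            self._in_flight.append((wid, self._pool.submit(func, wid)))
        return len(self._worker_ids)

    def num_in_flight_async_reqs(self) -> int:
        return len(self._in_flight)

    def fetch_ready_async_reqs(self) -> List[Tuple[int, object]]:
        """Returns (worker_id, result) pairs for finished requests only."""
        ready, pending = [], []
        for wid, fut in self._in_flight:
            if fut.done():
                ready.append((wid, fut.result()))
            else:
                pending.append((wid, fut))
        self._in_flight = pending
        return ready

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)


group = ToyAsyncGroup([1, 2, 3])
num_sent = group.foreach_worker_async(lambda wid: wid * 10)
results = []
while group.num_in_flight_async_reqs() > 0:
    results.extend(group.fetch_ready_async_reqs())
    time.sleep(0.01)
group.shutdown()
print(sorted(results))
```

The caller never blocks on a single slow worker: each poll returns whatever is ready and leaves the rest in flight.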

Pass-through methods#

add_policy

Adds a policy to this EnvRunnerGroup's workers or a specific list of workers.

foreach_env

Calls func with all workers' sub-environments as args.

foreach_env_with_context

Calls func with all workers' sub-environments and env_ctx as args.

foreach_policy

Calls func with each worker's (policy, PolicyID) tuple.

foreach_policy_to_train

Apply func to all workers' Policies iff in policies_to_train.

sync_weights

Syncs model weights from the given weight source to all remote workers.

Sampler API#

InputReader instances are used to collect and return experiences from the envs. For more details on InputReader used for offline RL (e.g. reading files of pre-recorded data), see the offline RL API reference here.

Input Reader API#

InputReader

API for collecting and returning experiences during policy evaluation.

InputReader.next

Returns the next batch of read experiences.

Input Sampler API#

SamplerInput

Reads input experiences from an existing sampler.

SamplerInput.get_data

Called by self.next() to return the next batch of data.

SamplerInput.get_extra_batches

Returns list of extra batches since the last call to this method.

SamplerInput.get_metrics

Returns list of episode metrics since the last call to this method.

Synchronous Sampler API#

SyncSampler

Sync SamplerInput that collects experiences when get_data() is called.

Offline Sampler API#

The InputReader API is used by an individual RolloutWorker to produce batches of experiences either from a simulator or from an offline source (e.g. a file).

Here are some example extensions of the InputReader API:

JSON reader API#

JsonReader

Reader object that loads experiences from JSON file chunks.

JsonReader.read_all_files

Reads through all files and yields one SampleBatchType per line.
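The "one batch per line" behavior can be sketched with the standard library. The function `read_all_batches` and the column names are hypothetical, and the sketch assumes each line is a plain JSON object; the actual on-disk format used by JsonReader may differ.

```python
import io
import json
from typing import Dict, Iterator, List, TextIO


def read_all_batches(files: List[TextIO]) -> Iterator[Dict[str, list]]:
    """Yields one dict-of-columns batch per non-empty line, file by file."""
    for file_obj in files:
        for line in file_obj:
            line = line.strip()
            if not line:
                continue  # Skip blank lines between records.
            yield json.loads(line)


# In-memory stand-ins for two JSON file chunks.
chunk_a = io.StringIO('{"obs": [0.1, 0.2], "rewards": [1.0, 0.0]}\n\n')
chunk_b = io.StringIO('{"obs": [0.3], "rewards": [1.0]}\n')
batches = list(read_all_batches([chunk_a, chunk_b]))
print(len(batches), batches[1]["obs"])
```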

Mixed input reader#

MixedInput

Mixes input from a number of other input sources.
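One way to picture mixing input sources is probabilistic selection on each call to next(). The sketch below assumes that interpretation; `ToyMixedInput`, the source functions, and the weights are hypothetical and do not reproduce MixedInput's constructor.

```python
import random
from typing import Callable, Dict, List, Tuple


class ToyMixedInput:
    """Hypothetical reader that picks one source per call, by weight."""

    def __init__(
        self, sources: List[Tuple[Callable[[], Dict], float]], seed: int = 0
    ):
        self._readers = [reader for reader, _ in sources]
        self._weights = [weight for _, weight in sources]
        self._rng = random.Random(seed)

    def next(self) -> Dict:
        # Draw one reader according to the configured weights.
        reader = self._rng.choices(self._readers, weights=self._weights, k=1)[0]
        return reader()


def online_source() -> Dict:
    return {"source": "sampler", "obs": [0.0]}


def offline_source() -> Dict:
    return {"source": "file", "obs": [1.0]}


mixed = ToyMixedInput([(online_source, 0.5), (offline_source, 0.5)])
drawn = [mixed.next()["source"] for _ in range(1000)]

offline_only = ToyMixedInput([(online_source, 0.0), (offline_source, 1.0)])
offline_drawn = [offline_only.next()["source"] for _ in range(50)]
print(sorted(set(drawn)), set(offline_drawn))
```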

D4RL reader#

D4RLReader

Reader object that loads the dataset from the D4RL dataset.

IOContext#

IOContext

Class containing attributes to pass to input/output class constructors.

IOContext.default_sampler_input

Returns the RolloutWorker's SamplerInput object, if any.

Policy Map API#

PolicyMap

Maps policy IDs to Policy objects.

PolicyMap.items

Iterates over all policies, even the stashed ones.

PolicyMap.keys

Returns all valid keys, even the stashed ones.

PolicyMap.values

Returns all valid values, even the stashed ones.
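The notion of "stashed" policies can be illustrated with a small least-recently-used cache. This is a sketch under assumptions, not PolicyMap's implementation: `ToyPolicyMap` is hypothetical, its capacity parameter is invented for the example, and a second dict stands in for wherever stashed policies are kept.

```python
from collections import OrderedDict
from typing import Dict, List


class ToyPolicyMap:
    """Hypothetical map keeping at most `capacity` policies 'in memory'."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._cache: "OrderedDict[str, dict]" = OrderedDict()  # MRU is last
        self._stash: Dict[str, dict] = {}  # stand-in for stashed policies

    def _evict(self) -> None:
        # Move least recently used policies into the stash.
        while len(self._cache) > self.capacity:
            old_id, old_policy = self._cache.popitem(last=False)
            self._stash[old_id] = old_policy

    def __setitem__(self, policy_id: str, policy: dict) -> None:
        self._stash.pop(policy_id, None)
        self._cache[policy_id] = policy
        self._cache.move_to_end(policy_id)
        self._evict()

    def __getitem__(self, policy_id: str) -> dict:
        if policy_id in self._stash:
            # Bring a stashed policy back, possibly stashing another one.
            self._cache[policy_id] = self._stash.pop(policy_id)
            self._evict()
        self._cache.move_to_end(policy_id)  # KeyError if the ID is unknown
        return self._cache[policy_id]

    def keys(self) -> List[str]:
        # Includes stashed policies, mirroring "even the stashed ones".
        return list(self._cache) + list(self._stash)

    def in_memory(self) -> List[str]:
        return list(self._cache)


pmap = ToyPolicyMap(capacity=2)
for i in range(3):
    pmap[f"p{i}"] = {"index": i}
before = pmap.in_memory()   # p0 was stashed when p2 arrived
_ = pmap["p0"]              # accessing p0 restores it and stashes p1
after = pmap.in_memory()
print(before, after, sorted(pmap.keys()))
```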

Sample batch API#

SampleBatch

Wrapper around a dictionary with string keys and array-like values.

SampleBatch.set_get_interceptor

Sets a function to be called on every getitem.

SampleBatch.is_training

SampleBatch.set_training

Sets the is_training flag for this SampleBatch.

SampleBatch.as_multi_agent

Returns the respective MultiAgentBatch.

SampleBatch.get

Returns one column (by key) from the data or a default value.

SampleBatch.to_device

TODO: transfer batch to given device as framework tensor.

SampleBatch.right_zero_pad

Right (adding zeros at end) zero-pads this SampleBatch in-place.

SampleBatch.slice

Returns a slice of the row data of this batch (w/o copying).

SampleBatch.split_by_episode

Splits by eps_id column and returns list of new batches.

SampleBatch.shuffle

Shuffles the rows of this batch in-place.

SampleBatch.columns

Returns a list of the batch-data in the specified columns.

SampleBatch.rows

Returns an iterator over data rows, i.e. dicts with column values.

SampleBatch.copy

Creates a deep or shallow copy of this SampleBatch and returns it.

SampleBatch.is_single_trajectory

Returns True if this SampleBatch only contains one trajectory.

SampleBatch.is_terminated_or_truncated

Returns True if self is either terminated or truncated at idx -1.

SampleBatch.env_steps

Returns the same as len(self) (number of steps in this batch).

SampleBatch.agent_steps

Returns the same as len(self) (number of steps in this batch).
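The "dictionary with string keys and array-like values" idea, along with slicing, row iteration, and splitting by episode, can be sketched on plain dicts of lists. The helper functions below are hypothetical and operate on Python lists; they are not SampleBatch methods, and only the `eps_id` column name is taken from the description above.

```python
from typing import Dict, Iterator, List

Batch = Dict[str, List]


def batch_len(batch: Batch) -> int:
    """Number of rows, assuming all columns have equal length."""
    return len(next(iter(batch.values()))) if batch else 0


def slice_batch(batch: Batch, start: int, end: int) -> Batch:
    """Returns rows [start, end) of every column."""
    return {key: column[start:end] for key, column in batch.items()}


def rows(batch: Batch) -> Iterator[Dict[str, object]]:
    """Iterates over rows as dicts of column values."""
    for i in range(batch_len(batch)):
        yield {key: column[i] for key, column in batch.items()}


def split_by_episode(batch: Batch) -> List[Batch]:
    """Cuts the batch wherever the eps_id column changes value."""
    ids = batch["eps_id"]
    pieces, start = [], 0
    for i in range(1, len(ids) + 1):
        if i == len(ids) or ids[i] != ids[start]:
            pieces.append(slice_batch(batch, start, i))
            start = i
    return pieces


batch = {
    "obs": [0.1, 0.2, 0.3, 0.4, 0.5],
    "rewards": [1.0, 1.0, 0.0, 1.0, 1.0],
    "eps_id": [7, 7, 8, 8, 8],
}
episodes = split_by_episode(batch)
first_row = next(rows(batch))
print([batch_len(e) for e in episodes], first_row)
```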

MultiAgent batch API#

MultiAgentBatch

A batch of experiences from multiple agents in the environment.

MultiAgentBatch.env_steps

The number of env steps (there are >= 1 agent steps per env step).

MultiAgentBatch.agent_steps

The number of agent steps (there are >= 1 agent steps per env step).
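The distinction between env steps and agent steps can be made concrete with a small sketch. `ToyMultiAgentBatch` is a hypothetical dataclass, and the sketch assumes the env step count is stored alongside the per-policy batches while agent steps are summed over them.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyMultiAgentBatch:
    """Hypothetical container: per-policy batches plus an env step count."""

    policy_batches: Dict[str, Dict[str, List]]
    env_step_count: int

    def env_steps(self) -> int:
        # One env step may produce experiences for several agents.
        return self.env_step_count

    def agent_steps(self) -> int:
        # Sum of rows over all per-policy batches.
        return sum(len(b["obs"]) for b in self.policy_batches.values())


# Three env steps: one agent acts in all three, the other in only two.
ma_batch = ToyMultiAgentBatch(
    policy_batches={
        "pol_a": {"obs": [0.1, 0.2, 0.3]},
        "pol_b": {"obs": [1.1, 1.2]},
    },
    env_step_count=3,
)
print(ma_batch.env_steps(), ma_batch.agent_steps())
```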