Sampling the Environment or offline data#

In RLlib, RolloutWorker instances handle data ingest, whether from environment rollouts or from other data-generating methods (e.g. reading offline files). Together with other parallel RolloutWorkers, they sit inside a WorkerSet in the RLlib Algorithm, which you can access through the Algorithm's self.workers property:


A typical RLlib WorkerSet setup inside an RLlib Algorithm: each WorkerSet contains one local RolloutWorker object and N remote RolloutWorkers (Ray actors). Each worker holds a policy map with one or more policies. If a simulator (env) is available, the worker also holds a vectorized BaseEnv containing M sub-environments, plus a SamplerInput (either synchronous or asynchronous) that controls the environment data collection loop. The Algorithm calls the workers' sample() method to get SampleBatch objects for training. This works the same way in the online case (an environment is available) and in the offline case (no environment).
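The following sketch mirrors the layout described above: one local worker, N "remote" workers, and a call that gathers sample() output from them into one batch. All class and method names here (ToyRolloutWorker, ToyWorkerSet, sample_all) are hypothetical. The sketch uses no Ray, and it is not RLlib's implementation. Falling back to the local worker when there are no remote workers is an assumption made for illustration.

```python
from typing import Dict, List

# Hypothetical toy classes; these are NOT RLlib's RolloutWorker/WorkerSet.
Batch = Dict[str, List[int]]


class ToyRolloutWorker:
    def __init__(self, worker_index: int, rollout_length: int = 4) -> None:
        self.worker_index = worker_index
        self.rollout_length = rollout_length

    def sample(self) -> Batch:
        # Pretend we stepped an environment `rollout_length` times.
        return {
            "worker": [self.worker_index] * self.rollout_length,
            "t": list(range(self.rollout_length)),
        }


class ToyWorkerSet:
    def __init__(self, num_remote_workers: int) -> None:
        # Index 0 is the local worker; 1..N stand in for remote (Ray actor) workers.
        self.local_worker = ToyRolloutWorker(0)
        self.remote_workers = [
            ToyRolloutWorker(i) for i in range(1, num_remote_workers + 1)
        ]

    def sample_all(self) -> Batch:
        # Sample from every remote worker (or from the local worker if there are
        # none) and concatenate the results column by column.
        workers = self.remote_workers or [self.local_worker]
        merged: Batch = {}
        for worker in workers:
            for key, column in worker.sample().items():
                merged.setdefault(key, []).extend(column)
        return merged


batch = ToyWorkerSet(num_remote_workers=3).sample_all()
print(len(batch["t"]))  # 12 rows: 3 workers x 4 steps
```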

RolloutWorker API#

Constructor#

RolloutWorker

Common experience collection class.

Multi agent#

add_policy

Adds a new policy to this RolloutWorker.

remove_policy

Removes a policy from this RolloutWorker.

get_policy

Returns the policy for the specified ID, or None.

set_is_policy_to_train

Sets self.is_policy_to_train() to a new callable.

set_policy_mapping_fn

Sets self.policy_mapping_fn to a new callable (if provided).

for_policy

Calls the given function with the specified policy as first arg.

foreach_policy

Calls the given function with each (policy, policy_id) tuple.

foreach_policy_to_train

Calls the given function with each (policy, policy_id) tuple, but only for those policies for which is_policy_to_train() returns True.
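The multi-agent methods above work together: a mapping function routes agents to policy IDs, and a predicate decides which policies are trained. The following toy sketch shows that interplay. Policies are plain dicts here. policy_mapping_fn is simplified to take only an agent ID, which is an assumption, since RLlib's real callable receives more arguments. The names bump and policy_map are hypothetical.

```python
from typing import Callable, Dict, List

# Hypothetical toy "policies": each just counts how often it was updated.
policy_map: Dict[str, Dict[str, int]] = {
    "shared": {"updates": 0},
    "opponent": {"updates": 0},
}


def policy_mapping_fn(agent_id: str) -> str:
    # Routes "opp_*" agents to the opponent policy, all others to the shared one.
    return "opponent" if agent_id.startswith("opp_") else "shared"


def is_policy_to_train(policy_id: str) -> bool:
    # In this toy setup, the opponent policy stays frozen.
    return policy_id != "opponent"


def foreach_policy(fn: Callable[[Dict[str, int], str], object]) -> List[object]:
    return [fn(policy, pid) for pid, policy in policy_map.items()]


def foreach_policy_to_train(fn: Callable[[Dict[str, int], str], object]) -> List[object]:
    # Same as foreach_policy, but filtered by is_policy_to_train().
    return [
        fn(policy, pid)
        for pid, policy in policy_map.items()
        if is_policy_to_train(pid)
    ]


def bump(policy: Dict[str, int], pid: str) -> str:
    policy["updates"] += 1
    return pid


trained = foreach_policy_to_train(bump)
print(trained)  # ['shared']
```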

Setter and getter methods#

get_filters

Returns a snapshot of filters.

get_global_vars

Returns the current self.global_vars dict of this RolloutWorker.

set_global_vars

Updates this worker's and all its policies' global vars.

get_host

Returns the hostname of the process running this worker.

get_metrics

Returns the thus-far collected metrics from this worker's rollouts.

get_node_ip

Returns the IP address of the node that this worker runs on.

get_weights

Returns the model weights of each policy in this worker.

set_weights

Sets the model weights of each policy in this worker.

get_state

Returns this EnvRunner's (possibly serialized) current state as a dict.

set_state

Restores this EnvRunner's state from the given state dict.

Threading#

lock

Locks this RolloutWorker via its own threading.Lock.

unlock

Unlocks this RolloutWorker via its own threading.Lock.
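lock() and unlock() wrap a private threading.Lock. The sketch below shows the pattern and how callers would bracket a critical section with try/finally. LockableWorker, increment and hammer are hypothetical names, not RLlib code.

```python
import threading


class LockableWorker:
    """Toy worker exposing lock()/unlock() around its own threading.Lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.counter = 0

    def lock(self) -> None:
        self._lock.acquire()

    def unlock(self) -> None:
        self._lock.release()

    def increment(self) -> None:
        self.counter += 1


def hammer(worker: LockableWorker, n: int) -> None:
    for _ in range(n):
        worker.lock()
        try:
            worker.increment()  # Critical section.
        finally:
            worker.unlock()  # Always release, even on error.


w = LockableWorker()
threads = [threading.Thread(target=hammer, args=(w, 1000)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(w.counter)  # 4000
```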

Sampling API#

sample

Returns a batch of experience sampled from this worker.

sample_with_count

Same as sample() but returns the count as a separate value.

sample_and_learn

Samples a batch and learns on it.
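The difference between sample() and sample_with_count() is only in the return shape. The second also returns the number of collected steps, so callers do not have to compute it. This is a minimal sketch. The class name CountingSampler is hypothetical, and rollout_fragment_length is borrowed as an illustrative parameter name.

```python
from typing import Dict, List, Tuple


class CountingSampler:
    """Toy sampler; NOT RLlib's RolloutWorker."""

    def __init__(self, rollout_fragment_length: int) -> None:
        self.rollout_fragment_length = rollout_fragment_length
        self._t = 0  # Global step counter across calls.

    def sample(self) -> Dict[str, List[int]]:
        # Collect the next `rollout_fragment_length` timesteps.
        start = self._t
        self._t += self.rollout_fragment_length
        return {"t": list(range(start, self._t))}

    def sample_with_count(self) -> Tuple[Dict[str, List[int]], int]:
        # Same as sample(), but also return the number of rows collected.
        batch = self.sample()
        return batch, len(batch["t"])


sampler = CountingSampler(rollout_fragment_length=5)
batch, count = sampler.sample_with_count()
print(count)  # 5
```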

Training API#

learn_on_batch

Update policies based on the given batch.

setup_torch_data_parallel

Joins a torch process group for distributed SGD.

compute_gradients

Returns gradients computed w.r.t. the specified samples.

apply_gradients

Applies the given gradients to this worker's models.
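Because compute_gradients and apply_gradients are separate methods, one worker can compute gradients that a different worker applies. The sketch below illustrates this on a one-parameter linear model (y_hat = w * x) with mean-squared-error loss. ToyLinearWorker is hypothetical. It loosely follows the (gradients, info) return shape, but it is not RLlib's code.

```python
from typing import Dict, List, Tuple


class ToyLinearWorker:
    """y_hat = w * x, trained with mean-squared error and plain SGD."""

    def __init__(self, w: float = 0.0, lr: float = 0.1) -> None:
        self.w = w
        self.lr = lr

    def compute_gradients(
        self, batch: Dict[str, List[float]]
    ) -> Tuple[float, Dict[str, float]]:
        xs, ys = batch["x"], batch["y"]
        n = len(xs)
        # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        return grad, {"loss": loss}

    def apply_gradients(self, grad: float) -> None:
        self.w -= self.lr * grad


batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}  # True slope is 2.
worker = ToyLinearWorker()
grad, info = worker.compute_gradients(batch)
print(grad, info["loss"])  # -10.0 10.0
for _ in range(100):
    g, _ = worker.compute_gradients(batch)
    worker.apply_gradients(g)
```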

Environment API#

foreach_env

Calls the given function with each sub-environment as arg.

foreach_env_with_context

Calls the given function with each sub-environment plus env_ctx as args.

Miscellaneous#

stop

Releases all resources used by this RolloutWorker.

apply

Calls the given function with this Actor instance.

sync_filters

Changes this worker's filters to the given ones and rebases any accumulated delta.

find_free_port

Finds a free port on the node that this worker runs on.

creation_args

Returns the kwargs dict used to create this worker.

assert_healthy

Checks that self.__init__() has been completed properly.

WorkerSet API#

Constructor#

WorkerSet

Set of EnvRunners with n @ray.remote workers and zero or one local worker.

WorkerSet.stop

Calls stop on all rollout workers (including the local one).

WorkerSet.reset

Hard-overrides the remote workers in this set with the given ones.

Worker Orchestration#

add_workers

Creates and adds a number of remote workers to this worker set.

foreach_worker

Calls the given function with each EnvRunner as its argument.

foreach_worker_with_id

Calls the given function with each EnvRunner and its ID as its arguments.

foreach_worker_async

Calls the given function asynchronously with each worker as the argument.

fetch_ready_async_reqs

Gets results from outstanding asynchronous requests that are ready.

num_in_flight_async_reqs

Returns the number of in-flight async requests.

local_worker

Returns the local rollout worker.

remote_workers

Returns the list of remote rollout workers.

num_healthy_remote_workers

Returns the number of healthy remote workers.

num_healthy_workers

Returns the number of all healthy workers, including the local worker.

num_remote_worker_restarts

Returns the total number of times the managed remote workers have been restarted.

probe_unhealthy_workers

Checks for unhealthy workers and tries restoring their states.
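The orchestration methods above share one idea: the set tracks which remote workers are healthy, calls functions only on healthy ones, and restarts failed ones when probed. This toy sketch models that bookkeeping without Ray. ToyWorker and ToyWorkerManager are hypothetical. "Restart" here just means replacing the object with a fresh instance, which is an assumption made for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyWorker:
    def __init__(self, idx: int) -> None:
        self.idx = idx
        self.healthy = True


class ToyWorkerManager:
    def __init__(self, num_workers: int) -> None:
        self.workers: Dict[int, ToyWorker] = {
            i: ToyWorker(i) for i in range(1, num_workers + 1)
        }
        self.num_restarts = 0

    def foreach_worker_with_id(
        self, fn: Callable[[int, ToyWorker], Any]
    ) -> List[Tuple[int, Any]]:
        # Skip unhealthy workers; return (worker_id, result) pairs.
        return [(i, fn(i, w)) for i, w in self.workers.items() if w.healthy]

    def num_healthy_remote_workers(self) -> int:
        return sum(w.healthy for w in self.workers.values())

    def probe_unhealthy_workers(self) -> List[int]:
        restored = []
        for i, w in list(self.workers.items()):
            if not w.healthy:
                self.workers[i] = ToyWorker(i)  # "Restart" with a fresh instance.
                self.num_restarts += 1
                restored.append(i)
        return restored


m = ToyWorkerManager(3)
m.workers[2].healthy = False  # Simulate a crashed worker.
print(m.num_healthy_remote_workers())  # 2
```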

Pass-through methods#

add_policy

Adds a policy to this WorkerSet's workers or a specific list of workers.

foreach_env

Calls func with all workers' sub-environments as args.

foreach_env_with_context

Calls func with all workers' sub-environments and env_ctx as args.

foreach_policy

Calls func with each worker's (policy, PolicyID) tuple.

foreach_policy_to_train

Applies func to each worker's policies, but only to those in policies_to_train.

sync_weights

Syncs model weights from the given weight source to all remote workers.
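Conceptually, sync_weights reads the weights once from a source and pushes the same snapshot to every remote worker. The toy sketch below keys weights by policy ID, an assumption based on the per-policy get_weights/set_weights descriptions. It deep-copies snapshots so that later changes at the source do not leak into the targets. WeightHolder and the function body are hypothetical, not RLlib's code.

```python
import copy
from typing import Dict, List

Weights = Dict[str, List[float]]  # policy_id -> flat weight list (toy format)


class WeightHolder:
    """Toy worker exposing get_weights/set_weights."""

    def __init__(self, weights: Weights) -> None:
        self.weights = weights

    def get_weights(self) -> Weights:
        return copy.deepcopy(self.weights)

    def set_weights(self, weights: Weights) -> None:
        self.weights = copy.deepcopy(weights)


def sync_weights(source: WeightHolder, targets: List[WeightHolder]) -> None:
    # Take one snapshot from the source, then broadcast it.
    snapshot = source.get_weights()
    for target in targets:
        target.set_weights(snapshot)


local = WeightHolder({"default_policy": [1.0, 2.0]})
remotes = [WeightHolder({"default_policy": [0.0, 0.0]}) for _ in range(2)]
sync_weights(local, remotes)
```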

Sampler API#

InputReader instances collect and return experiences from the environments. For more details on the InputReaders used for offline RL (e.g. reading files of pre-recorded data), see the offline RL API reference.

Input Reader API#

InputReader

API for collecting and returning experiences during policy evaluation.

InputReader.next

Returns the next batch of read experiences.
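The reader contract is small: subclasses implement next(), and each call returns one batch. The sketch below mimics that contract with a hypothetical abstract base (ToyInputReader) and a trivial subclass that cycles through an in-memory list. It is not RLlib's InputReader class.

```python
import abc
from typing import Dict, List


class ToyInputReader(abc.ABC):
    """Hypothetical stand-in for a reader interface with a single next()."""

    @abc.abstractmethod
    def next(self) -> Dict[str, List]:
        """Returns the next batch of experiences."""


class ListReader(ToyInputReader):
    """Cycles endlessly through a fixed list of pre-built batches."""

    def __init__(self, batches: List[Dict[str, List]]) -> None:
        self._batches = batches
        self._i = 0

    def next(self) -> Dict[str, List]:
        batch = self._batches[self._i % len(self._batches)]
        self._i += 1
        return batch


reader = ListReader([{"t": [0]}, {"t": [1]}])
firsts = [reader.next()["t"][0] for _ in range(3)]
print(firsts)  # [0, 1, 0]
```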

Input Sampler API#

SamplerInput

Reads input experiences from an existing sampler.

SamplerInput.get_data

Called by self.next() to return the next batch of data.

SamplerInput.get_extra_batches

Returns list of extra batches since the last call to this method.

SamplerInput.get_metrics

Returns list of episode metrics since the last call to this method.

Synchronous Sampler API#

SyncSampler

Sync SamplerInput that collects experiences when get_data() is called.

Offline Sampler API#

An individual RolloutWorker uses the InputReader API to produce batches of experiences, either from a simulator or from an offline source (e.g. a file).

Here are some example extensions of the InputReader API:

JSON reader API#

JsonReader

Reader object that loads experiences from JSON file chunks.

JsonReader.read_all_files

Reads through all files and yields one SampleBatchType per line.
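The "one batch per line" idea behind read_all_files is the JSON-lines layout. The sketch below writes a tiny JSON-lines file and yields one dict per non-empty line. It does not reproduce RLlib's actual on-disk format, such as column encodings or metadata. The file name batches.json and the column names are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict, Iterator, List


def read_all_files(paths: List[str]) -> Iterator[Dict[str, list]]:
    """Yields one batch dict per non-empty line across all given files."""
    for path in paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # Skip blank lines.
                    yield json.loads(line)


# Write a small JSON-lines file to read back.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "batches.json")
with open(path, "w", encoding="utf-8") as f:
    f.write(json.dumps({"obs": [0, 1], "rewards": [1.0, 0.5]}) + "\n")
    f.write(json.dumps({"obs": [2], "rewards": [0.0]}) + "\n")

batches = list(read_all_files([path]))
print(len(batches))  # 2
```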

Mixed input reader#

MixedInput

Mixes input from a number of other input sources.
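One way to mix several input sources is to pick a source at random on every next() call, weighted by a per-source probability. The toy sketch below does this. The class ToyMixedInput and its requirement that the weights sum to 1 are assumptions made for illustration. The sketch does not describe MixedInput's exact constructor or validation rules.

```python
import random
from typing import Callable, Dict, Optional


class ToyMixedInput:
    """Picks one source per next() call, weighted by its probability."""

    def __init__(
        self,
        sources: Dict[str, Callable[[], dict]],
        weights: Dict[str, float],
        seed: Optional[int] = None,
    ) -> None:
        if set(sources) != set(weights):
            raise ValueError("Every source needs exactly one weight.")
        if abs(sum(weights.values()) - 1.0) > 1e-6:
            raise ValueError("Weights must sum to 1.")
        self._names = list(sources)
        self._sources = sources
        self._probs = [weights[name] for name in self._names]
        self._rng = random.Random(seed)

    def next(self) -> dict:
        name = self._rng.choices(self._names, weights=self._probs, k=1)[0]
        return self._sources[name]()


mixed = ToyMixedInput(
    {"sampler": lambda: {"src": "sampler"}, "offline": lambda: {"src": "offline"}},
    {"sampler": 0.3, "offline": 0.7},
    seed=0,
)
draws = [mixed.next()["src"] for _ in range(10_000)]
```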

D4RL reader#

D4RLReader

Reader object that loads data from a D4RL dataset.

IOContext#

IOContext

Class containing attributes to pass to input/output class constructors.

IOContext.default_sampler_input

Returns the RolloutWorker's SamplerInput object, if any.

Policy Map API#

PolicyMap

Maps policy IDs to Policy objects.

PolicyMap.items

Iterates over all policies, even the stashed ones.

PolicyMap.keys

Returns all valid keys, even the stashed ones.

PolicyMap.values

Returns all valid values, even the stashed ones.
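Since items(), keys() and values() also cover "stashed" policies, a PolicyMap keeps only some policies active and parks the rest. The toy sketch below assumes a fixed capacity and a least-recently-used eviction rule, and it stashes evicted policies in a plain dict. Both the capacity parameter and the LRU rule are assumptions here, and ToyPolicyMap is hypothetical.

```python
from collections import OrderedDict
from typing import Any, Dict, List, Tuple


class ToyPolicyMap:
    """Keeps up to `capacity` policies active; older ones get stashed (LRU)."""

    def __init__(self, capacity: int) -> None:
        assert capacity >= 1
        self.capacity = capacity
        self._cache: "OrderedDict[str, Any]" = OrderedDict()
        self._stash: Dict[str, Any] = {}

    def _evict(self) -> None:
        # Move least-recently-used policies from the cache into the stash.
        while len(self._cache) > self.capacity:
            old_id, old_policy = self._cache.popitem(last=False)
            self._stash[old_id] = old_policy

    def __setitem__(self, policy_id: str, policy: Any) -> None:
        self._stash.pop(policy_id, None)
        self._cache[policy_id] = policy
        self._cache.move_to_end(policy_id)
        self._evict()

    def __getitem__(self, policy_id: str) -> Any:
        if policy_id in self._stash:
            # Restore a stashed policy; it becomes the most recently used.
            self._cache[policy_id] = self._stash.pop(policy_id)
            self._evict()
        self._cache.move_to_end(policy_id)
        return self._cache[policy_id]

    def keys(self) -> List[str]:
        return list(self._cache) + list(self._stash)

    def values(self) -> List[Any]:
        return list(self._cache.values()) + list(self._stash.values())

    def items(self) -> List[Tuple[str, Any]]:
        return list(self._cache.items()) + list(self._stash.items())


pm = ToyPolicyMap(capacity=2)
pm["a"], pm["b"], pm["c"] = 1, 2, 3  # "a" gets stashed.
```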

Sample batch API#

SampleBatch

Wrapper around a dictionary with string keys and array-like values.

SampleBatch.set_get_interceptor

Sets a function to be called on every getitem.

SampleBatch.is_training

Returns the is_training flag of this SampleBatch.

SampleBatch.set_training

Sets the is_training flag for this SampleBatch.

SampleBatch.as_multi_agent

Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.

SampleBatch.get

Returns one column (by key) from the data or a default value.

SampleBatch.to_device

Transfers this batch to the given device as framework tensors.

SampleBatch.right_zero_pad

Right-zero-pads this SampleBatch in place, i.e. appends zeros at the end.

SampleBatch.slice

Returns a slice of the row data of this batch (w/o copying).

SampleBatch.split_by_episode

Splits by eps_id column and returns list of new batches.

SampleBatch.shuffle

Shuffles the rows of this batch in-place.

SampleBatch.columns

Returns a list of the batch-data in the specified columns.

SampleBatch.rows

Returns an iterator over data rows, i.e. dicts with column values.

SampleBatch.copy

Creates a deep or shallow copy of this SampleBatch and returns it.

SampleBatch.is_single_trajectory

Returns True if this SampleBatch only contains one trajectory.

SampleBatch.is_terminated_or_truncated

Returns True if self is either terminated or truncated at idx -1.

SampleBatch.env_steps

Returns the same as len(self) (number of steps in this batch).

SampleBatch.agent_steps

Returns the same as len(self) (number of steps in this batch).
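Several of the SampleBatch operations above can be sketched on a plain dict of equal-length lists. The sketch covers slicing rows, right-zero-padding in place, and splitting wherever the eps_id column changes. It assumes that each episode's rows are contiguous. Unlike SampleBatch.slice, Python list slicing copies the data. All function names are hypothetical and are not RLlib's implementation.

```python
from typing import Dict, List

Batch = Dict[str, List]  # column name -> equal-length list of values


def batch_len(b: Batch) -> int:
    return len(next(iter(b.values())))


def slice_batch(b: Batch, start: int, end: int) -> Batch:
    # Note: list slicing copies (SampleBatch.slice avoids copying).
    return {k: v[start:end] for k, v in b.items()}


def right_zero_pad(b: Batch, max_len: int) -> Batch:
    # Appends zeros to every column in place until it has max_len rows.
    n = batch_len(b)
    if n > max_len:
        raise ValueError("Batch is longer than max_len.")
    for column in b.values():
        column.extend([0] * (max_len - n))
    return b


def split_by_episode(b: Batch) -> List[Batch]:
    # Assumes rows of each episode are contiguous; splits where eps_id changes.
    eps = b["eps_id"]
    out, start = [], 0
    for i in range(1, len(eps) + 1):
        if i == len(eps) or eps[i] != eps[start]:
            out.append(slice_batch(b, start, i))
            start = i
    return out


b = {"eps_id": [7, 7, 8, 8, 8], "obs": [0, 1, 2, 3, 4]}
parts = split_by_episode(b)
print([p["obs"] for p in parts])  # [[0, 1], [2, 3, 4]]
```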

MultiAgent batch API#

MultiAgentBatch

A batch of experiences from multiple agents in the environment.

MultiAgentBatch.env_steps

The number of env steps (there are >= 1 agent steps per env step).

MultiAgentBatch.agent_steps

The number of agent steps (there are >= 1 agent steps per env step).
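In a multi-agent batch, one env step can involve several agent steps, because every agent that acts in that step contributes one row. The toy sketch below builds per-policy batches from a step-by-step log of which agents acted. It then counts env steps (log length) and agent steps (total rows). ToyMultiAgentBatch and from_step_log are hypothetical names.

```python
from typing import Dict, List


class ToyMultiAgentBatch:
    """Per-policy batches plus the number of env steps they span."""

    def __init__(self, policy_batches: Dict[str, Dict[str, List]], env_steps: int) -> None:
        self.policy_batches = policy_batches
        self._env_steps = env_steps

    def env_steps(self) -> int:
        return self._env_steps

    def agent_steps(self) -> int:
        # Every row in every policy batch is one agent step.
        return sum(len(b["t"]) for b in self.policy_batches.values())


def from_step_log(step_log: List[Dict[str, str]]) -> ToyMultiAgentBatch:
    # step_log[t] maps each agent that acted at env step t to its policy ID.
    policy_batches: Dict[str, Dict[str, List]] = {}
    for t, acting in enumerate(step_log):
        for agent_id, policy_id in acting.items():
            pb = policy_batches.setdefault(policy_id, {"t": [], "agent": []})
            pb["t"].append(t)
            pb["agent"].append(agent_id)
    return ToyMultiAgentBatch(policy_batches, env_steps=len(step_log))


log = [{"a": "p1", "b": "p2"}, {"a": "p1"}, {"a": "p1", "b": "p2"}]
mab = from_step_log(log)
print(mab.env_steps(), mab.agent_steps())  # 3 5
```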