Sampling the Environment or offline data#

Data ingestion, via either environment rollouts or other data-generating methods (e.g. reading from offline files), is done in RLlib by RolloutWorker instances, which sit inside a WorkerSet (together with other parallel RolloutWorkers) in the RLlib Algorithm (under the self.workers property):

[Figure: rollout_worker_class_overview.svg]

A typical RLlib WorkerSet setup inside an RLlib Algorithm: Each WorkerSet contains exactly one local RolloutWorker object and N Ray remote RolloutWorkers (Ray actors). The workers contain a policy map (with one or more policies), and - in case a simulator (env) is available - a vectorized BaseEnv (containing M sub-environments) and a SamplerInput (either synchronous or asynchronous), which controls the environment data collection loop. In both the online case (i.e. an environment is available) and the offline case (i.e. no environment), the Algorithm uses the sample() method to get SampleBatch objects for training.#
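As a rough usage sketch (assuming Ray RLlib 2.x, the Gymnasium CartPole-v1 environment, and the PPO algorithm, none of which are prescribed by this page), the WorkerSet and its workers can be reached directly from a built Algorithm:

    from ray.rllib.algorithms.ppo import PPOConfig

    # Build a PPO Algorithm with two remote RolloutWorkers.
    algo = (
        PPOConfig()
        .environment("CartPole-v1")
        .rollouts(num_rollout_workers=2)
        .build()
    )

    # The WorkerSet lives under `algo.workers`; its local worker can sample directly ...
    local_batch = algo.workers.local_worker().sample()
    print(local_batch.count)

    # ... and the remote workers can be sampled in parallel.
    remote_batches = algo.workers.foreach_worker(
        lambda worker: worker.sample(), local_worker=False
    )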

RolloutWorker API#

Constructor#

RolloutWorker(*, env_creator[, ...])

Common experience collection class.

Multi agent#

add_policy(policy_id[, policy_cls, policy, ...])

Adds a new policy to this RolloutWorker.

remove_policy(*[, policy_id, ...])

Removes a policy from this RolloutWorker.

get_policy([policy_id])

Returns the policy for the specified id, or None.

set_is_policy_to_train(is_policy_to_train)

Sets self.is_policy_to_train() to a new callable.

set_policy_mapping_fn([policy_mapping_fn])

Sets self.policy_mapping_fn to a new callable (if provided).

for_policy(func[, policy_id])

Calls the given function with the specified policy as first arg.

foreach_policy(func, **kwargs)

Calls the given function with each (policy, policy_id) tuple.

foreach_policy_to_train(func, **kwargs)

Calls the given function with each (policy, policy_id) tuple, but only for policies currently marked as trainable.
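For illustration, a minimal sketch of working with this worker's policies (assuming `worker` is a RolloutWorker, e.g. obtained via `algo.workers.local_worker()`):

    # Collect the IDs of all policies held by this worker.
    policy_ids = worker.foreach_policy(lambda policy, pid: pid)

    # Look up a single policy by ID ("default_policy" in single-agent setups).
    default_policy = worker.get_policy("default_policy")

    # Restrict training to a subset of policies via a new callable
    # (the callable is assumed to receive the policy ID and, optionally, a batch).
    worker.set_is_policy_to_train(lambda pid, batch=None: pid in policy_ids[:1])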

Setter and getter methods#

get_filters([flush_after])

Returns a snapshot of filters.

get_global_vars()

Returns the current self.global_vars dict of this RolloutWorker.

set_global_vars(global_vars[, policy_ids])

Updates this worker's and all its policies' global vars.

get_host()

Returns the hostname of the process running this evaluator.

get_metrics()

Returns the thus-far collected metrics from this worker's rollouts.

get_node_ip()

Returns the IP address of the node that this worker runs on.

get_weights([policies])

Returns the model weights of each of this worker's policies.

set_weights(weights[, global_vars, ...])

Sets the model weights of each of this worker's policies.

get_state()

Returns this EnvRunner's (possibly serialized) current state as a dict.

set_state(state)

Restores this EnvRunner's state from the given state dict.
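A sketch of how these getters and setters are typically combined to broadcast weights (reusing the `algo` object from the earlier example; WorkerSet.sync_weights() wraps the same pattern):

    local_worker = algo.workers.local_worker()

    # Snapshot the local worker's policy weights ...
    weights = local_worker.get_weights()

    # ... and push them to every remote worker.
    algo.workers.foreach_worker(
        lambda worker: worker.set_weights(weights), local_worker=False
    )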

Threading#

lock()

Locks this RolloutWorker via its own threading.Lock.

unlock()

Unlocks this RolloutWorker via its own threading.Lock.

Sampling API#

sample(**kwargs)

Returns a batch of experience sampled from this worker.

sample_with_count()

Same as sample() but returns the count as a separate value.

sample_and_learn(expected_batch_size, ...)

Samples a batch and learns on it.

Training API#

learn_on_batch(samples)

Update policies based on the given batch.

setup_torch_data_parallel(url, world_rank, ...)

Joins a torch process group for distributed SGD.

compute_gradients(samples[, single_agent])

Returns gradients computed w.r.t. the specified samples.

apply_gradients(grads)

Applies the given gradients to this worker's models.
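A minimal sketch of a manual sample-then-learn step on a single worker (assuming `worker` holds at least one trainable policy):

    # Collect a batch of experiences from the worker's environment(s).
    batch = worker.sample()

    # Option 1: update the worker's policies directly on the batch.
    learn_info = worker.learn_on_batch(batch)

    # Option 2: compute gradients here and apply them (possibly on another worker).
    grads, grad_info = worker.compute_gradients(batch)
    worker.apply_gradients(grads)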

Environment API#

foreach_env(func)

Calls the given function with each sub-environment as arg.

foreach_env_with_context(func)

Calls given function with each sub-env plus env_ctx as args.
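For example (a sketch assuming the sub-environments are Gymnasium-style envs), these calls can reset or inspect every sub-environment of the worker's vectorized BaseEnv:

    # Call reset() on every sub-environment and collect the results.
    reset_results = worker.foreach_env(lambda env: env.reset())

    # Inspect each sub-env together with its EnvContext (worker/vector indices).
    env_specs = worker.foreach_env_with_context(
        lambda env, ctx: (ctx.worker_index, ctx.vector_index, env.observation_space)
    )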

Miscellaneous#

stop()

Releases all resources used by this RolloutWorker.

apply(func, *args, **kwargs)

Calls the given function with this rollout worker instance.

sync_filters(new_filters)

Changes this worker's filters to the given ones and rebases any accumulated deltas.

find_free_port()

Finds a free port on the node that this worker runs on.

creation_args()

Returns the kwargs dict used to create this worker.

assert_healthy()

Checks that self.__init__() has been completed properly.

WorkerSet API#

Constructor#

WorkerSet(*[, env_creator, validate_env, ...])

Set of RolloutWorkers with n @ray.remote workers and zero or one local worker.

WorkerSet.stop()

Calls stop on all rollout workers (including the local one).

WorkerSet.reset(new_remote_workers)

Hard-overrides the remote workers in this set with the given ones.

Worker Orchestration#

add_workers(num_workers[, validate])

Creates and adds a number of remote workers to this worker set.

foreach_worker(func, *[, local_worker, ...])

Calls the given function with each worker instance as the argument.

foreach_worker_with_id(func, *[, ...])

Similar to foreach_worker(), but also passes the worker's ID to the function.

foreach_worker_async(func, *[, ...])

Calls the given function asynchronously with each worker as the argument.

fetch_ready_async_reqs(*[, timeout_seconds, ...])

Gets results from outstanding asynchronous requests that are ready.

num_in_flight_async_reqs()

Returns the number of in-flight async requests.

local_worker()

Returns the local rollout worker.

remote_workers(**kwargs)

Returns the list of remote rollout workers.

num_healthy_remote_workers()

Returns the number of healthy remote workers.

num_healthy_workers()

Returns the number of all healthy workers, including the local worker.

num_remote_worker_restarts()

Total number of times managed remote workers have been restarted.

probe_unhealthy_workers()

Checks for unhealthy workers and tries restoring their states.
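A hedged sketch of asynchronous sampling across the set (assuming `workers` is a WorkerSet, e.g. `algo.workers`):

    # Kick off non-blocking sample() calls on the healthy remote workers ...
    workers.foreach_worker_async(lambda worker: worker.sample())

    # ... then poll for whichever results are ready within the timeout.
    for worker_id, batch in workers.fetch_ready_async_reqs(timeout_seconds=0.1):
        print(f"worker {worker_id} returned {batch.count} timesteps")

    print("still in flight:", workers.num_in_flight_async_reqs())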

Pass-through methods#

add_policy(policy_id[, policy_cls, policy, ...])

Adds a policy to this WorkerSet's workers or a specific list of workers.

foreach_env(func)

Calls func with all workers' sub-environments as args.

foreach_env_with_context(func)

Calls func with all workers' sub-environments and env_ctx as args.

foreach_policy(func)

Calls func with each worker's (policy, PolicyID) tuple.

foreach_policy_to_train(func)

Apply func to all workers' Policies iff in policies_to_train.

sync_weights([policies, ...])

Syncs model weights from the given weight source to all remote workers.
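For instance, after a learning update on the local worker, sync_weights() broadcasts its weights to the remote workers (a sketch; the policy ID shown is only illustrative):

    # Push the local worker's current policy weights to all remote workers.
    workers.sync_weights()

    # Optionally restrict the sync to specific policies.
    workers.sync_weights(policies=["default_policy"])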

Sampler API#

InputReader instances are used to collect and return experiences from the envs. For more details on InputReaders used for offline RL (e.g. reading files of pre-recorded data), see the offline RL API reference here.

Input Reader API#

InputReader()

API for collecting and returning experiences during policy evaluation.

InputReader.next()

Returns the next batch of read experiences.
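As a rough sketch of the interface (the class name and dummy data below are made up for illustration), a custom InputReader only needs to implement next():

    import numpy as np

    from ray.rllib.offline import InputReader
    from ray.rllib.policy.sample_batch import SampleBatch


    class RandomBatchReader(InputReader):
        """Hypothetical reader that serves fixed-size batches of random data."""

        def __init__(self, batch_size: int = 32):
            self.batch_size = batch_size

        def next(self) -> SampleBatch:
            # Return one batch of (fake) experiences per call.
            return SampleBatch({
                SampleBatch.OBS: np.random.random((self.batch_size, 4)),
                SampleBatch.ACTIONS: np.random.randint(0, 2, size=(self.batch_size,)),
                SampleBatch.REWARDS: np.random.random((self.batch_size,)),
            })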

Input Sampler API#

SamplerInput()

Reads input experiences from an existing sampler.

SamplerInput.get_data()

Called by self.next() to return the next batch of data.

SamplerInput.get_extra_batches()

Returns list of extra batches since the last call to this method.

SamplerInput.get_metrics()

Returns list of episode metrics since the last call to this method.

Synchronous Sampler API#

SyncSampler(*, worker, env, clip_rewards, ...)

Sync SamplerInput that collects experiences when get_data() is called.

Asynchronous Sampler API#

AsyncSampler(*, worker, env, clip_rewards, ...)

Async SamplerInput that collects experiences in thread and queues them.

Offline Sampler API#

The InputReader API is used by an individual RolloutWorker to produce batches of experiences either from a simulator or from an offline source (e.g. a file).

Here are some example extensions of the InputReader API:

JSON reader API#

JsonReader(inputs[, ioctx])

Reader object that loads experiences from JSON file chunks.

JsonReader.read_all_files()

Reads through all files and yields one SampleBatchType per line.
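A small usage sketch (the input path is hypothetical and should point at JSON rollout files previously written by RLlib's output writers):

    from ray.rllib.offline import JsonReader

    # Read pre-recorded experiences from a directory of JSON files.
    reader = JsonReader("/tmp/cartpole-out")
    batch = reader.next()  # one batch of experiences per call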

Mixed input reader#

MixedInput(dist, ioctx)

Mixes input from a number of other input sources.

D4RL reader#

D4RLReader(inputs[, ioctx])

Reader object that loads experiences from a D4RL dataset.

IOContext#

IOContext([log_dir, config, worker_index, ...])

Class containing attributes to pass to input/output class constructors.

IOContext.default_sampler_input()

Returns the RolloutWorker's SamplerInput object, if any.

Policy Map API#

PolicyMap(*[, capacity, ...])

Maps policy IDs to Policy objects.

PolicyMap.items()

Iterates over all policies, even the stashed ones.

PolicyMap.keys()

Returns all valid keys, even the stashed ones.

PolicyMap.values()

Returns all valid values, even the stashed ones.
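In practice, each RolloutWorker exposes its PolicyMap as `worker.policy_map`, which behaves like a dict (a sketch; the policy ID is only illustrative):

    # Iterate over all (PolicyID, Policy) pairs held by this worker.
    for policy_id, policy in worker.policy_map.items():
        print(policy_id, type(policy).__name__)

    # Standard dict-style membership tests and lookups also work.
    if "default_policy" in worker.policy_map.keys():
        default_policy = worker.policy_map["default_policy"]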

Sample batch API#

SampleBatch(*args, **kwargs)

Wrapper around a dictionary with string keys and array-like values.

SampleBatch.set_get_interceptor(fn)

Sets a function to be called on every getitem.

SampleBatch.is_training

The is_training flag of this SampleBatch (True if the batch is used for training rather than inference).

SampleBatch.set_training([training])

Sets the is_training flag for this SampleBatch.

SampleBatch.as_multi_agent()

Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.

SampleBatch.get(key[, default])

Returns one column (by key) from the data or a default value.

SampleBatch.to_device(device[, framework])

Transfers this batch to the given device as framework tensors.

SampleBatch.right_zero_pad(max_seq_len[, ...])

Right (adding zeros at end) zero-pads this SampleBatch in-place.

SampleBatch.slice(start, end[, state_start, ...])

Returns a slice of the row data of this batch (w/o copying).

SampleBatch.split_by_episode([key])

Splits by eps_id column and returns list of new batches.

SampleBatch.shuffle()

Shuffles the rows of this batch in-place.

SampleBatch.columns(keys)

Returns a list of the batch-data in the specified columns.

SampleBatch.rows()

Returns an iterator over data rows, i.e. dicts with column values.

SampleBatch.copy([shallow])

Creates a deep or shallow copy of this SampleBatch and returns it.

SampleBatch.is_single_trajectory()

Returns True if this SampleBatch only contains one trajectory.

SampleBatch.is_terminated_or_truncated()

Returns True if self is either terminated or truncated at idx -1.

SampleBatch.env_steps()

Returns the same as len(self) (number of steps in this batch).

SampleBatch.agent_steps()

Returns the same as len(self) (number of steps in this batch).
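A small, self-contained sketch of constructing and inspecting a SampleBatch (the column values are arbitrary dummy data):

    import numpy as np

    from ray.rllib.policy.sample_batch import SampleBatch

    batch = SampleBatch({
        SampleBatch.OBS: np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
        SampleBatch.ACTIONS: np.array([0, 1, 0]),
        SampleBatch.REWARDS: np.array([1.0, -1.0, 0.5]),
    })

    print(len(batch))                  # 3 rows (= env steps)
    print(batch[SampleBatch.ACTIONS])  # column access by key
    print(list(batch.rows())[0])       # first row as a dict of column values

    # Wrap into a MultiAgentBatch under DEFAULT_POLICY_ID.
    ma_batch = batch.as_multi_agent()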

MultiAgent batch API#

MultiAgentBatch(policy_batches, env_steps)

A batch of experiences from multiple agents in the environment.

MultiAgentBatch.env_steps()

The number of env steps (there are >= 1 agent steps per env step).

MultiAgentBatch.agent_steps()

The number of agent steps (there are >= 1 agent steps per env step).
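For example (a sketch with two illustrative policy IDs), agent_steps() counts rows across all policy batches, while env_steps() counts the underlying environment steps:

    import numpy as np

    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch

    # One 3-row batch per policy, both collected over the same 3 env steps.
    per_policy = SampleBatch({
        SampleBatch.OBS: np.zeros((3, 2)),
        SampleBatch.ACTIONS: np.array([0, 1, 0]),
        SampleBatch.REWARDS: np.ones(3),
    })
    ma = MultiAgentBatch(
        {"policy_a": per_policy, "policy_b": per_policy}, env_steps=3
    )

    print(ma.env_steps())    # 3
    print(ma.agent_steps())  # 6 (sum of rows over all policy batches)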