LearnerGroup.update_from_batch(batch: MultiAgentBatch, *, timesteps: Dict[str, Any] | None = None, async_update: bool = False, return_state: bool = False, minibatch_size: int | None = None, num_iters: int = 1, reduce_fn=-1, **kwargs) → Dict[str, Any] | List[Dict[str, Any]] | List[List[Dict[str, Any]]]

Performs gradient-based update(s) on the Learner(s), based on the given batch.

  • batch – A data batch to use for the update. If there is more than one Learner worker, the batch is split among them and one shard is sent to each Learner.

  • async_update – Whether to send the update request(s) to the Learner workers asynchronously. If True, the call does not return the results of the update on the given data; instead, it returns all results from prior asynchronous update requests that have not been returned yet.

  • return_state – Whether to include the state of one of the Learner workers, taken after the update step, in the returned results dict (under the _rl_module_state_after_update key). Since all Learner workers’ states should be identical after an update, the first Learner’s state is used. This is useful for avoiding an extra get_weights() call, e.g., when synchronizing EnvRunner weights.

  • minibatch_size – The minibatch size to use for the update.

  • num_iters – The number of complete passes over all the sub-batches in the input multi-agent batch.
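
The interplay of batch, minibatch_size, and num_iters can be pictured with the minimal pure-Python sketch below. It is not RLlib's implementation: a plain list stands in for a MultiAgentBatch, and shard_batch and iterate_minibatches are hypothetical helpers. It assumes contiguous, near-equal shards and simply lets the last minibatch of each pass be smaller than minibatch_size; RLlib's actual splitting and minibatch iteration may differ.

```python
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def shard_batch(batch: Sequence[T], num_learners: int) -> List[List[T]]:
    """Split `batch` into `num_learners` contiguous, near-equal shards.

    Hypothetical helper: a plain list stands in for a MultiAgentBatch.
    """
    if num_learners < 1:
        raise ValueError("num_learners must be >= 1")
    base, extra = divmod(len(batch), num_learners)
    shards: List[List[T]] = []
    start = 0
    for i in range(num_learners):
        # The first `extra` shards each take one leftover item.
        end = start + base + (1 if i < extra else 0)
        shards.append(list(batch[start:end]))
        start = end
    return shards


def iterate_minibatches(
    shard: Sequence[T], minibatch_size: Optional[int], num_iters: int = 1
) -> Iterator[List[T]]:
    """Yield minibatches during `num_iters` complete passes over `shard`.

    With `minibatch_size=None`, each pass yields the whole shard at once.
    The final minibatch of a pass may be smaller than `minibatch_size`.
    """
    if minibatch_size is not None and minibatch_size < 1:
        raise ValueError("minibatch_size must be >= 1")
    if not shard:
        return
    size = len(shard) if minibatch_size is None else minibatch_size
    for _ in range(num_iters):
        for start in range(0, len(shard), size):
            yield list(shard[start:start + size])


# Usage: 10 samples, 3 Learners, minibatches of 2, two passes.
shards = shard_batch(list(range(10)), num_learners=3)
# shards → [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
first_learner_mbs = list(iterate_minibatches(shards[0], minibatch_size=2, num_iters=2))
# first_learner_mbs → [[0, 1], [2, 3], [0, 1], [2, 3]]
```

Each Learner in this sketch sees only its own shard, so num_iters counts passes over that shard rather than over the full input batch.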


If async_update is False, a dictionary with the reduced results of the updates from the Learner(s) or a list of dictionaries of results from the updates from the Learner(s). If async_update is True, a list of list of dictionaries of results, where the outer list corresponds to separate previous calls to this method, and the inner list corresponds to the results from each Learner(s). Or if the results are reduced, a list of dictionaries of the reduced results from each call to async_update that is ready.
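
The shape of the async_update=True results can be illustrated with a small thread-based sketch. Everything here is an assumption for illustration, not an RLlib API: AsyncUpdateSketch, fake_learner, the reduce flag, and the per-key mean used as the reduction. Each call first collects the results of earlier calls whose per-Learner work has finished, then submits its own shards without waiting, so it returns a list of lists (calls × Learners), or one reduced dict per ready call when reduce=True.

```python
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Callable, Dict, List, Union

Result = Dict[str, Any]


def _mean(results: List[Result]) -> Result:
    """Average each key across Learners (this sketch's own reduction)."""
    return {k: sum(r[k] for r in results) / len(results) for k in results[0]}


class AsyncUpdateSketch:
    """Hypothetical model of async_update=True; not RLlib code."""

    def __init__(self, learner_fns: List[Callable[[Any], Result]]):
        # One callable per simulated Learner worker.
        self._learner_fns = learner_fns
        self._pool = ThreadPoolExecutor(max_workers=len(learner_fns))
        # Each entry holds the per-Learner futures of one earlier call.
        self._pending: List[List[Future]] = []

    def _harvest(self) -> List[List[Result]]:
        """Pop every earlier call whose Learners have all finished."""
        ready, still_pending = [], []
        for futures in self._pending:
            if all(f.done() for f in futures):
                ready.append([f.result() for f in futures])
            else:
                still_pending.append(futures)
        self._pending = still_pending
        return ready

    def update(
        self, shards: List[Any], reduce: bool = False
    ) -> Union[List[List[Result]], List[Result]]:
        # 1) Collect finished results from *earlier* calls only.
        ready = self._harvest()
        # 2) Send this call's shards (one per Learner) without waiting.
        self._pending.append(
            [self._pool.submit(fn, s) for fn, s in zip(self._learner_fns, shards)]
        )
        if reduce:
            return [_mean(per_learner) for per_learner in ready]  # one dict per call
        return ready  # outer list: calls, inner list: Learners

    def wait_all(self) -> None:
        """Block until all submitted work is done, without returning it."""
        for futures in self._pending:
            wait(futures)

    def close(self) -> None:
        self._pool.shutdown(wait=True)


def fake_learner(shard: List[float]) -> Result:
    # Stand-in for one Learner's update step; "loss" is made up.
    return {"loss": float(sum(shard))}


group = AsyncUpdateSketch([fake_learner, fake_learner])
first = group.update([[1, 2], [3, 4]])  # → [] (no earlier calls)
group.wait_all()
second = group.update([[5], [6]])  # → [[{'loss': 3.0}, {'loss': 7.0}]]
group.wait_all()
third = group.update([[0], [0]], reduce=True)  # → [{'loss': 5.5}]
group.close()
```

The wait_all() calls only make the example deterministic; without them, a call may return an empty list simply because no earlier request has finished yet.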