ray.rllib.core.learner.learner_group.LearnerGroup.update_from_batch

LearnerGroup.update_from_batch(batch: MultiAgentBatch, *, timesteps: Dict[str, Any] | None = None, async_update: bool = False, return_state: bool = False, num_epochs: int = 1, minibatch_size: int | None = None, shuffle_batch_per_epoch: bool = False, **kwargs) → Dict[str, Any] | List[Dict[str, Any]] | List[List[Dict[str, Any]]]

Performs gradient-based update(s) on the Learner(s), based on the given batch.

Parameters:
  • batch – A data batch to use for the update. If there is more than one Learner worker, the batch is split among them and one shard is sent to each Learner.

  • async_update – Whether to send the update request(s) to the Learner workers asynchronously. If True, the call does NOT return the results of the update on the given data; instead, it returns all results from earlier asynchronous update requests that have not been returned yet.

  • return_state – Whether to include the state of one Learner worker, taken after the update step, in the returned results dict (under the _rl_module_state_after_update key). Since all Learner workers’ states should be identical after an update, the first Learner’s state is used. This is useful for avoiding an extra get_weights() call, e.g., when synchronizing EnvRunner weights.

  • num_epochs – The number of complete passes over the entire train batch. Each pass may be further split into n minibatches (if minibatch_size is provided).

  • minibatch_size – The size of the minibatches into which the train batch is further split. Each epoch then iterates over n minibatches, where n is len(batch) // minibatch_size.

  • shuffle_batch_per_epoch – Whether to shuffle the train batch once per epoch. If the train batch has a time rank (axis=1), shuffling only takes place along the batch axis so that intact (episode) trajectories are not broken up. Shuffling is always skipped if minibatch_size is None: in that case, the entire train batch is processed in each epoch, so shuffling is unnecessary.
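To make the interplay of the batch sharding, num_epochs, minibatch_size, and shuffle_batch_per_epoch parameters more concrete, here is a minimal pure-Python sketch. It is NOT RLlib's implementation: the helpers shard_batch and iterate_minibatches and the seed argument are hypothetical, a batch is modeled as a plain Python list whose elements are entries along the batch axis, and contiguous near-equal sharding is an assumption (the parameter description above only says the batch is split among the Learners).

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def shard_batch(rows: Sequence[T], num_learners: int) -> List[List[T]]:
    """Hypothetical helper: split rows into contiguous, near-equal shards."""
    base, extra = divmod(len(rows), num_learners)
    shards, start = [], 0
    for i in range(num_learners):
        # The first `extra` shards receive one additional row each.
        end = start + base + (1 if i < extra else 0)
        shards.append(list(rows[start:end]))
        start = end
    return shards


def iterate_minibatches(
    rows: Sequence[T],
    num_epochs: int = 1,
    minibatch_size: Optional[int] = None,
    shuffle_batch_per_epoch: bool = False,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Hypothetical helper: yield the sub-batches an update loop would see.

    Each element of `rows` is one entry along the batch axis (axis 0). If the
    batch has a time rank, an element is a whole trajectory, so shuffling
    reorders trajectories but never the timesteps inside them.
    """
    rng = random.Random(seed)
    for _ in range(num_epochs):
        if minibatch_size is None:
            # Whole batch per epoch: shuffling would change nothing, so skip it.
            yield list(rows)
            continue
        order = list(range(len(rows)))
        if shuffle_batch_per_epoch:
            rng.shuffle(order)  # permute along the batch axis only
        # n = len(batch) // minibatch_size; leftover rows are dropped in this sketch.
        n = len(rows) // minibatch_size
        for i in range(n):
            idx = order[i * minibatch_size:(i + 1) * minibatch_size]
            yield [rows[j] for j in idx]


# Each row is a short trajectory of 3 timesteps (time rank on axis 1).
batch = [[(episode, t) for t in range(3)] for episode in range(8)]
shards = shard_batch(batch, num_learners=2)  # two shards of 4 rows each
mbs = list(iterate_minibatches(shards[0], num_epochs=2, minibatch_size=2,
                               shuffle_batch_per_epoch=True))
print(len(mbs))  # → 4 (2 epochs x 4 // 2 minibatches)
```

In this sketch, every trajectory inside a minibatch keeps its timesteps in order, because only the indices along the batch axis are permuted.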

Returns:

  • If async_update is False: a dictionary with the reduced results of the updates from the Learner(s), or a list of result dictionaries from the updates on the Learner(s).

  • If async_update is True: a list of lists of result dictionaries, where the outer list corresponds to separate earlier calls to this method and the inner list holds the results from each Learner. If the results are reduced, it is instead a list of dictionaries holding the reduced results of each asynchronous update call that is ready.
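Because the return type varies with async_update and with whether results are reduced, calling code may want to normalize it. The following is a minimal sketch assuming only the three shapes listed above (a dict, a list of dicts, or a list of lists of dicts); flatten_update_results is a hypothetical helper, not part of RLlib, and the "loss" keys are made-up placeholder data.

```python
from typing import Any, Dict, List, Union

ResultDict = Dict[str, Any]
UpdateResults = Union[ResultDict, List[ResultDict], List[List[ResultDict]]]


def flatten_update_results(results: UpdateResults) -> List[ResultDict]:
    """Hypothetical helper: flatten any documented return shape into one list.

    - dict: reduced results of a synchronous call
    - list of dicts: per-Learner results, or reduced results per async call
    - list of lists of dicts: outer = earlier async calls, inner = per Learner
    """
    if isinstance(results, dict):
        return [results]
    flat: List[ResultDict] = []
    for item in results:
        if isinstance(item, dict):
            flat.append(item)
        else:
            # Inner list: one result dict per Learner for one earlier call.
            flat.extend(item)
    return flat


# Two earlier async calls, each with results from two Learners (placeholder data).
async_results = [[{"loss": 0.5}, {"loss": 0.7}], [{"loss": 0.4}, {"loss": 0.6}]]
print(len(flatten_update_results(async_results)))  # → 4
```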