ray.rllib.core.learner.learner_group.LearnerGroup.update
- LearnerGroup.update(*, batch: MultiAgentBatch | None = None, batches: List[MultiAgentBatch] | None = None, batch_refs: List[ray._raylet.ObjectRef] | None = None, episodes: List[SingleAgentEpisode | MultiAgentEpisode] | None = None, episodes_refs: List[ray._raylet.ObjectRef] | None = None, data_iterators: List[DataIterator] | None = None, training_data: TrainingData | None = None, timesteps: Dict[str, Any] | None = None, async_update: bool = False, return_state: bool = False, **kwargs) -> List[Dict[str, Any]]
Performs gradient-based updates on the Learners in parallel.
Updates are performed with data from any of the provided arguments (batch, batches, batch_refs, episodes, episodes_refs, data_iterators, training_data).
- Parameters:
batch – A data batch to use for the update. If there is more than one Learner worker, the batch is split among them and one shard is sent to each Learner.
batch_refs – A list of Ray ObjectRefs to the batches. If there is more than one Learner worker, the list of batch refs is split among them and one list shard is sent to each Learner.
episodes – A list of Episodes to process and perform the update for. If there is more than one Learner worker, the list of episodes is split among them and one list shard is sent to each Learner.
episodes_refs – A list of Ray ObjectRefs to the episodes. If there is more than one Learner worker, the list of episode refs is split among them and one list shard is sent to each Learner.
timesteps – A dictionary of timesteps to pass to the Learners’ update methods. This is usually used for learning-rate scheduling but can serve any other purpose.
training_data – A TrainingData object to use for the update. If not provided, a new TrainingData object will be created from the batch, batches, batch_refs, episodes, and episodes_refs.
async_update – Whether the update request(s) to the Learner workers should be sent asynchronously. If True, the call does not return the results of the update on the given data; instead, it returns all results from prior asynchronous update requests that have not been returned yet.
return_state – Whether to include the state of one of the Learner workers, taken after the update step, in the returned results dict (under the _rl_module_state_after_update key). Because all Learner workers’ states should be identical after an update, the first Learner’s state is used. This is useful for avoiding an extra get_weights() call, e.g. when synchronizing EnvRunner weights.
num_epochs – The number of complete passes over the entire train batch. Each pass might be further split into n minibatches (if minibatch_size is provided).
minibatch_size – The size of the minibatches used to further split the train batch into sub-batches. The batch is then iterated over n times, where n is len(batch) // minibatch_size.
shuffle_batch_per_epoch – Whether to shuffle the train batch once per epoch. If the train batch has a time rank (axis=1), shuffling only takes place along the batch axis so that intact (episode) trajectories are not disturbed. Shuffling is always skipped if minibatch_size is None, because the entire train batch is then processed each epoch and shuffling would have no effect.
**kwargs
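The sharding of data across Learner workers and the per-Learner epoch/minibatch loop described in the parameters above can be pictured with a small, framework-free sketch. This is not RLlib's implementation: a batch is modeled as a plain Python list, the helpers shard_across_learners and iterate_minibatches are hypothetical, and contiguous near-equal shards are an assumption (RLlib's actual splitting may differ). For a time-ranked batch, each list element is assumed to stand for one whole sequence, so shuffling reorders sequences without touching the steps inside them.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def shard_across_learners(items: Sequence[T], num_learners: int) -> List[List[T]]:
    """Hypothetical helper: split items into contiguous, near-equal shards, one per Learner."""
    if num_learners < 1:
        raise ValueError("num_learners must be >= 1")
    base, extra = divmod(len(items), num_learners)
    shards, start = [], 0
    for i in range(num_learners):
        # The first `extra` shards each receive one additional item.
        size = base + (1 if i < extra else 0)
        shards.append(list(items[start:start + size]))
        start += size
    return shards


def iterate_minibatches(
    batch: Sequence[T],
    num_epochs: int,
    minibatch_size: Optional[int],
    shuffle_batch_per_epoch: bool,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Hypothetical helper: yield minibatches for num_epochs passes over batch."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        if minibatch_size is None:
            # Whole batch per epoch: shuffling would have no effect, so skip it.
            yield list(batch)
            continue
        order = list(range(len(batch)))
        if shuffle_batch_per_epoch:
            # Reorder whole elements only (rows or whole sequences), never their contents.
            rng.shuffle(order)
        # n = len(batch) // minibatch_size; leftover elements are skipped in this sketch.
        n = len(batch) // minibatch_size
        for k in range(n):
            idx = order[k * minibatch_size:(k + 1) * minibatch_size]
            yield [batch[i] for i in idx]


if __name__ == "__main__":
    shards = shard_across_learners(list(range(10)), num_learners=3)
    print(shards)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
    for mb in iterate_minibatches(shards[0], num_epochs=2, minibatch_size=2,
                                  shuffle_batch_per_epoch=False):
        print(mb)  # [0, 1], [2, 3], [0, 1], [2, 3] on successive lines
```

With minibatch_size=None the sketch yields the whole shard once per epoch and ignores shuffle_batch_per_epoch, mirroring the skip rule stated above.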
- Returns:
If async_update is False, a dictionary with the reduced results of the updates from the Learner(s), or a list of dictionaries of results from the updates from the Learner(s). If async_update is True, a list of lists of dictionaries of results, where the outer list corresponds to separate previous calls to this method and the inner list corresponds to the results from each Learner. If the results are reduced, a list of dictionaries of the reduced results from each asynchronous call that is ready.
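The asynchronous return contract described above (a call returns results only from earlier requests that have finished, grouped per call and then per Learner) can be modeled with Python's concurrent.futures. This is a hypothetical stand-in, not RLlib code: AsyncUpdateQueue and fake_learner_update are illustrative names, and thread-pool tasks play the role of Learner workers.

```python
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List


def fake_learner_update(learner_id: int, shard: List[float]) -> Dict[str, Any]:
    """Hypothetical stand-in for one Learner's update; returns a results dict."""
    return {"learner_id": learner_id, "mean": sum(shard) / len(shard)}


class AsyncUpdateQueue:
    """Hypothetical model of in-flight update requests, one list of futures per call."""

    def __init__(self, num_learners: int) -> None:
        self.pool = ThreadPoolExecutor(max_workers=num_learners)
        self.in_flight: List[List[Future]] = []

    def update(self, shards: List[List[float]], async_update: bool) -> List[Any]:
        # One future per Learner shard for this call.
        futures = [self.pool.submit(fake_learner_update, i, s) for i, s in enumerate(shards)]
        if not async_update:
            # Synchronous: block and return this call's results, one dict per Learner.
            return [f.result() for f in futures]
        # Asynchronous: collect fully finished *earlier* calls.
        ready, still_pending = [], []
        for call_futures in self.in_flight:
            if all(f.done() for f in call_futures):
                ready.append([f.result() for f in call_futures])  # inner list: per Learner
            else:
                still_pending.append(call_futures)
        # Queue this call only after harvesting, so it never returns its own results.
        self.in_flight = still_pending + [futures]
        return ready  # outer list: per previous call

    def close(self) -> None:
        self.pool.shutdown(wait=True)


if __name__ == "__main__":
    q = AsyncUpdateQueue(num_learners=2)
    shards = [[1.0, 3.0], [5.0, 7.0]]
    print(q.update(shards, async_update=False))
    # [{'learner_id': 0, 'mean': 2.0}, {'learner_id': 1, 'mean': 6.0}]
    print(q.update(shards, async_update=True))  # [] (no earlier async requests yet)
    wait(q.in_flight[0])  # let the first async request finish
    print(len(q.update(shards, async_update=True)))  # 1
    q.close()
```

Reducing the per-Learner dicts into a single dict (the "reduced results" case) is left out of this sketch.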