ray.rllib.evaluation.rollout_worker.RolloutWorker.compute_gradients
RolloutWorker.compute_gradients(samples: Union[SampleBatch, MultiAgentBatch], single_agent: bool = None) -> Tuple[Union[List[Tuple[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor], Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]], List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]], dict]
Returns gradients computed w.r.t. the specified samples.
Uses the Policy's (or Policies') compute_gradients method(s) to perform the calculations. Skips policies that are not trainable as per self.is_policy_to_train().
Parameters
samples – The SampleBatch or MultiAgentBatch to compute gradients for using this worker’s trainable policies.
Returns
In the single-agent case, a tuple consisting of ModelGradients and an info dict of the worker's policy. In the multi-agent case, a tuple consisting of a dict mapping PolicyID to ModelGradients and a dict mapping PolicyID to extra metadata info. Note that the first return value (grads) can be applied as-is to a compatible worker using the worker's apply_gradients() method.
Examples
>>> import gymnasium as gym
>>> from ray.rllib.evaluation.rollout_worker import RolloutWorker
>>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy
>>> worker = RolloutWorker(
...     env_creator=lambda _: gym.make("CartPole-v1"),
...     default_policy_class=PGTF1Policy)
>>> batch = worker.sample()
>>> grads, info = worker.compute_gradients(batch)
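As noted under Returns, the computed gradients can be applied back to a compatible worker via apply_gradients(). A minimal sketch continuing the example above (assuming the single-agent setup, so grads is a plain ModelGradients list rather than a dict mapping PolicyID to gradients):
>>> # Apply the computed gradients to this worker's (trainable) policy.
>>> # On a separate but compatible worker, the same call would be used there.
>>> worker.apply_gradients(grads)
>>> # Sample a new batch with the updated policy weights.
>>> new_batch = worker.sample()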