ray.rllib.evaluation.rollout_worker.RolloutWorker.compute_gradients
- RolloutWorker.compute_gradients(samples: SampleBatch | MultiAgentBatch | Dict[str, Any], single_agent: bool = None) -> Tuple[List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], dict]
Returns gradients computed w.r.t. the specified samples.
Uses each Policy's compute_gradients() method to perform the calculations. Skips policies that are not trainable as per self.is_policy_to_train().
- Parameters:
samples – The SampleBatch or MultiAgentBatch to compute gradients for using this worker’s trainable policies.
- Returns:
In the single-agent case, a tuple consisting of ModelGradients and the info dict of the worker's policy. In the multi-agent case, a tuple consisting of a dict mapping PolicyID to ModelGradients and a dict mapping PolicyID to extra metadata info. Note that the first return value (grads) can be applied as-is to a compatible worker using the worker's apply_gradients() method.
import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy

# Build a local worker with a single PPO (TF1) policy on CartPole.
worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v1"),
    default_policy_class=PPOTF1Policy)

# Collect a batch of experiences by rolling out in the env.
batch = worker.sample()

# Compute gradients for the worker's policy w.r.t. that batch.
grads, info = worker.compute_gradients(batch)
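As noted above, the returned grads can be applied to a compatible worker. A minimal sketch of that round trip, assuming the single-agent worker and grads from the example above:

# Apply the computed gradients to this worker's policy in place.
worker.apply_gradients(grads)

# The updated weights could then be retrieved (e.g. to sync other workers
# via their set_weights() method).
weights = worker.get_weights()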