ray.rllib.policy.policy.Policy.compute_gradients
- Policy.compute_gradients(postprocessed_batch: SampleBatch) → Tuple[List[Tuple[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]] | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]]
Computes gradients given a batch of experiences.
Either this method in combination with apply_gradients(), or learn_on_batch(), must be implemented by subclasses.
- Parameters:
postprocessed_batch – The SampleBatch object to use for calculating gradients.
- Returns:
grads – List of gradient output values.
grad_info – Extra policy-specific info values.
- Return type:
A tuple (grads, grad_info); the full type is given in the signature.
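The contract described above (compute_gradients() returning (grads, grad_info), paired with apply_gradients(), or alternatively a single learn_on_batch()) can be illustrated without RLlib. The sketch below is a minimal toy under stated assumptions: ToyLinearPolicy, the plain-dict batch with "obs"/"targets" keys, the one-weight squared-error model, and the SGD update are all hypothetical and are not RLlib's implementation. A real subclass receives a SampleBatch and returns framework tensors rather than Python floats.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for a SampleBatch: a plain dict of columns.
Batch = Dict[str, List[float]]


class ToyLinearPolicy:
    """Toy illustration (not an RLlib class) of the gradient contract.

    Model: y = w * x, trained with mean squared error, so the gradient
    can be written out by hand.
    """

    def __init__(self, w: float = 0.0, lr: float = 0.1):
        self.w = w
        self.lr = lr

    def compute_gradients(self, batch: Batch) -> Tuple[List[float], Dict[str, float]]:
        xs, ys = batch["obs"], batch["targets"]
        n = len(xs)
        errors = [self.w * x - y for x, y in zip(xs, ys)]
        # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        grad_w = sum(2.0 * e * x for e, x in zip(errors, xs)) / n
        loss = sum(e * e for e in errors) / n
        grads = [grad_w]            # one gradient per trainable variable
        grad_info = {"loss": loss}  # extra, policy-specific stats
        return grads, grad_info

    def apply_gradients(self, gradients: List[float]) -> None:
        # Plain SGD step on the single weight.
        (grad_w,) = gradients
        self.w -= self.lr * grad_w

    def learn_on_batch(self, batch: Batch) -> Dict[str, float]:
        # Alternative path: compute and apply in one call.
        grads, grad_info = self.compute_gradients(batch)
        self.apply_gradients(grads)
        return grad_info


if __name__ == "__main__":
    policy = ToyLinearPolicy(w=0.0, lr=0.1)
    batch = {"obs": [1.0, 2.0], "targets": [2.0, 4.0]}
    grads, grad_info = policy.compute_gradients(batch)
    print(grads, grad_info)  # prints [-10.0] {'loss': 10.0}
    policy.apply_gradients(grads)
    print(policy.w)  # prints 1.0
```

Keeping the two steps separate means a caller can compute gradients in one place and apply them elsewhere, for example after combining them; the toy learn_on_batch() simply chains the two calls.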