ray.rllib.policy.Policy.compute_gradients
- Policy.compute_gradients(postprocessed_batch: ray.rllib.policy.sample_batch.SampleBatch) → Tuple[Union[List[Tuple[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor], Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]], List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]], Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]]
Computes gradients given a batch of experiences.
Either this method in combination with apply_gradients(), or learn_on_batch(), must be implemented by subclasses.
- Parameters
postprocessed_batch – The SampleBatch object to use for calculating gradients.
- Returns
grads: List of gradient output values.
grad_info: Extra policy-specific info values.
- Return type
Tuple of (grads, grad_info)
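The contract described above (compute gradients from a batch, return them together with an info dict, and apply them in a separate step) can be illustrated with a minimal sketch. This is not RLlib code: ToyLinearPolicy, the plain-dict batch, the "targets" key, and the squared-error loss are all assumptions made for illustration, and a real subclass would take a SampleBatch and return framework tensors.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for SampleBatch: a plain dict of columns.
Batch = Dict[str, List[float]]


class ToyLinearPolicy:
    """Toy class (not part of RLlib) mirroring the compute/apply split."""

    def __init__(self, weight: float = 0.0, lr: float = 0.1) -> None:
        self.weight = weight
        self.lr = lr

    def compute_gradients(
        self, postprocessed_batch: Batch
    ) -> Tuple[List[float], Dict[str, float]]:
        # Mean squared error of a one-parameter linear model (assumed loss).
        obs = postprocessed_batch["obs"]
        targets = postprocessed_batch["targets"]
        n = len(obs)
        errors = [self.weight * x - y for x, y in zip(obs, targets)]
        loss = sum(e * e for e in errors) / n
        grad = sum(2.0 * e * x for e, x in zip(errors, obs)) / n
        # Return (grads, grad_info): gradients only, weights are untouched.
        return [grad], {"loss": loss}

    def apply_gradients(self, gradients: List[float]) -> None:
        # Plain gradient-descent step using the precomputed gradients.
        self.weight -= self.lr * gradients[0]

    def learn_on_batch(self, batch: Batch) -> Dict[str, float]:
        # In this sketch, learning is just compute followed by apply.
        grads, grad_info = self.compute_gradients(batch)
        self.apply_gradients(grads)
        return grad_info


batch = {"obs": [1.0, 2.0], "targets": [2.0, 4.0]}
policy = ToyLinearPolicy()
grads, grad_info = policy.compute_gradients(batch)
policy.apply_gradients(grads)
```

Keeping the two steps separate means the gradients can be inspected, aggregated, or sent elsewhere before any weights change, which is why the pair can stand in for a single learn_on_batch() call.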