ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2.grad_stats_fn
- EagerTFPolicyV2.grad_stats_fn(train_batch: ray.rllib.policy.sample_batch.SampleBatch, grads: Union[List[Tuple[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor], Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]], List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]]) → Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]
Gradient stats function. Returns a dict of statistics computed from the training batch and the gradients.
- Parameters
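train_batch – The SampleBatch already used for training.
grads – The computed gradients, given either as a list of (gradient, variable) tuples or as a plain list of gradient tensors.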
- Returns
The stats dict.
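The sketch below illustrates the contract described above: take the training batch and the gradients in either of the two accepted forms, and return a flat dict of statistics. It is a NumPy-only stand-in, not RLlib code. The function name `example_grad_stats_fn`, the stats keys (`grad_gnorm`, `num_grad_tensors`, `batch_size`), and the use of a plain dict in place of a `SampleBatch` are all assumptions made for illustration.

```python
from typing import Dict, List, Tuple, Union

import numpy as np

# Either (gradient, variable) pairs or bare gradients, mirroring the two
# forms accepted by the documented `grads` parameter (NumPy arrays here).
GradsArg = Union[List[Tuple[np.ndarray, np.ndarray]], List[np.ndarray]]


def example_grad_stats_fn(
    train_batch: Dict[str, np.ndarray], grads: GradsArg
) -> Dict[str, float]:
    """Hypothetical stand-in for a grad_stats_fn override (NumPy only)."""
    # Normalize both input forms into a list of gradient arrays.
    if grads and isinstance(grads[0], tuple):
        grad_list = [g for g, _ in grads]
    else:
        grad_list = list(grads)
    # Drop None entries (variables that received no gradient).
    grad_list = [np.asarray(g, dtype=np.float64) for g in grad_list if g is not None]
    # Global L2 norm over all gradient tensors combined.
    global_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grad_list)))
    return {
        "grad_gnorm": global_norm,
        "num_grad_tensors": len(grad_list),
        # Hypothetical use of the batch: number of rows under the "obs" key.
        "batch_size": len(train_batch.get("obs", [])),
    }


batch = {"obs": np.zeros((8, 4))}
print(example_grad_stats_fn(batch, [np.array([3.0, 0.0]), np.array([4.0])]))
```

In an actual `EagerTFPolicyV2` subclass, the equivalent override would be a method `grad_stats_fn(self, train_batch, grads)` returning TensorFlow values, for example a global norm computed with `tf.linalg.global_norm`; the NumPy version is used here only so the sketch runs without Ray or TensorFlow installed.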