ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors
- ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors(policy: TorchPolicy | TorchPolicyV2) → Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]
Concatenates the multi-GPU (per-tower) TD-error tensors of a given TorchPolicy.
TD-errors are extracted from the TorchPolicy via its tower_stats property.
- Parameters:
policy – The TorchPolicy to extract the TD-error values from.
- Returns:
A dict mapping the keys “td_error” and “mean_td_error” to the concatenated per-tower TD errors and to their mean, respectively.
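
The reduction described above can be illustrated with a minimal sketch in plain Python. It is not RLlib's implementation: lists stand in for the per-tower torch tensors, and `concat_td_errors_sketch`, the `tower_stats` argument layout, and the `0.0` placeholder for a tower without recorded TD errors are all hypothetical, chosen only to show the shape of the returned dict.

```python
from statistics import fmean
from typing import Dict, List, Union


def concat_td_errors_sketch(
    tower_stats: List[Dict[str, List[float]]],
) -> Dict[str, Union[List[float], float]]:
    """Hypothetical stand-in: concatenate per-tower TD errors and average them."""
    td_error: List[float] = []
    for stats in tower_stats:
        # Assumption of this sketch: a tower with no "td_error" entry
        # contributes a single 0.0 so the result is never empty.
        td_error.extend(stats.get("td_error", [0.0]))
    return {
        "td_error": td_error,              # concatenated across all towers
        "mean_td_error": fmean(td_error),  # mean over the concatenated values
    }


# Two towers (e.g. two GPUs), each holding the TD errors of its sub-batch.
towers = [{"td_error": [0.5, -1.0]}, {"td_error": [2.0, 0.5]}]
result = concat_td_errors_sketch(towers)
print(result["td_error"])       # [0.5, -1.0, 2.0, 0.5]
print(result["mean_td_error"])  # 0.5
```

With a real policy, the function is called as `concat_multi_gpu_td_errors(policy)` and reads the per-tower values from the policy itself rather than from an explicit list.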