ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors

ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors(policy: TorchPolicy | TorchPolicyV2) -> Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Concatenates the multi-GPU (per-tower) TD-error tensors of the given TorchPolicy.

TD-errors are extracted from the TorchPolicy via the tower_stats property of each of its per-GPU tower models.
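The concatenation step can be sketched roughly as follows. This is a minimal, self-contained illustration, not RLlib's implementation. The helper `concat_tower_td_errors`, its `device` parameter, and the `SimpleNamespace` mock towers are hypothetical. The sketch assumes that each tower exposes a `tower_stats` dict whose "td_error" entry is a 1-D tensor with one value per sample that tower processed.

```python
from types import SimpleNamespace
from typing import Dict, List

import torch


def concat_tower_td_errors(
    towers: List[SimpleNamespace], device: str = "cpu"
) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for concat_multi_gpu_td_errors (not RLlib code).

    Assumes every tower has a `tower_stats` dict whose "td_error" entry is a
    1-D tensor with one TD-error per sample that tower processed.
    """
    # Move each tower's TD-errors to a common device, then join them
    # end-to-end along the batch dimension (dim 0).
    td_error = torch.cat(
        [t.tower_stats["td_error"].to(device) for t in towers], dim=0
    )
    return {"td_error": td_error, "mean_td_error": torch.mean(td_error)}


# Two mock "GPU towers", each holding the TD-errors of its own batch shard.
towers = [
    SimpleNamespace(tower_stats={"td_error": torch.tensor([1.0, -2.0])}),
    SimpleNamespace(tower_stats={"td_error": torch.tensor([3.0])}),
]
stats = concat_tower_td_errors(towers)
print(stats["td_error"].tolist())               # [1.0, -2.0, 3.0]
print(round(stats["mean_td_error"].item(), 4))  # 0.6667
```

Concatenating along dimension 0 produces one TD-error per sample across all towers, in tower order. The per-tower tensors are therefore not summed or averaged element-wise.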

Parameters:

policy – The TorchPolicy (or TorchPolicyV2) to extract the TD-error values from.

Returns:

A dict mapping the keys “td_error” and “mean_td_error” to the concatenated TD-error tensor and its mean, respectively.
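The description above suggests that “mean_td_error” is the mean over the concatenated tensor. If so, every sample is weighted equally. When towers hold shards of different sizes, this result differs from the average of the per-tower means. The following small illustration uses made-up values and plain torch calls only.

```python
import torch

# Made-up per-tower TD-errors; the towers hold shards of unequal size.
per_tower = [torch.tensor([1.0, -2.0]), torch.tensor([3.0])]

# Mean over the concatenated tensor: every sample is weighted equally.
pooled_mean = torch.mean(torch.cat(per_tower, dim=0))

# Average of per-tower means: every tower is weighted equally instead.
mean_of_means = torch.mean(torch.stack([t.mean() for t in per_tower]))

print(round(pooled_mean.item(), 4))  # 0.6667
print(mean_of_means.item())          # 1.25
```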