Source code for ray.rllib.utils.torch_utils

import logging
import os
import warnings
from typing import Dict, List, Optional, TYPE_CHECKING, Union

import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
from packaging import version
import tree  # pip install dm_tree

from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI, OldAPIStack
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER
from ray.rllib.utils.typing import (
    LocalOptimizer,
    NetworkType,
    SpaceStruct,
    TensorStructType,
    TensorType,
)

if TYPE_CHECKING:
    from ray.rllib.core.learner.learner import ParamDict, ParamList
    from ray.rllib.policy.torch_policy import TorchPolicy
    from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

logger = logging.getLogger(__name__)
torch, nn = try_import_torch()

# Finite stand-ins for -inf / +inf logits. Using real infinities would produce
# NaNs during backprop, so these near-float32-limit values are used instead.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38

# Torch version required for `torch.compile` (per the constant's name). When
# torch is unavailable, the constant holds an (un-raised) ValueError instance
# explaining why it is undefined.
TORCH_COMPILE_REQUIRED_VERSION = (
    version.parse("2.0.0")
    if torch
    else ValueError(
        "torch is not installed. TORCH_COMPILE_REQUIRED_VERSION is not defined."
    )
)


@OldAPIStack
def apply_grad_clipping(
    policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType
) -> Dict[str, TensorType]:
    """Applies gradient clipping to already computed grads inside `optimizer`.

    Note: This is NOT the analog of `tf.clip_by_global_norm`. For every param
    group of `optimizer`, the params that have a gradient are clipped jointly
    via `nn.utils.clip_grad_norm_` (using `policy.config["grad_clip"]`, or no
    effective limit if that is None). The per-group norms, each capped at the
    clip value, are then summed into the reported "grad_gnorm". That sum is
    only reported; it is not used for any further clipping.

    Args:
        policy: The TorchPolicy, which calculated `loss`. Only its
            `config["grad_clip"]` entry is read.
        optimizer: A local torch optimizer object whose param groups hold the
            already computed gradients.
        loss: The torch loss tensor (not used by this function).

    Returns:
        An info dict with the single key "grad_gnorm", or an empty dict if no
        param group contained any parameter with a gradient.
    """
    clip_value = policy.config["grad_clip"]
    if clip_value is None:
        # No clipping requested: an infinite threshold makes clipping a no-op.
        clip_value = np.inf

    grad_gnorm = 0
    groups_without_grads = 0
    for group in optimizer.param_groups:
        # `clip_grad_norm_` would fail on params whose `.grad` is None.
        params_with_grads = [p for p in group["params"] if p.grad is not None]
        if not params_with_grads:
            groups_without_grads += 1
            continue

        # PyTorch clips in place and returns the norm BEFORE clipping, hence
        # the explicit cap below (fixes #4965).
        norm_before_clip = nn.utils.clip_grad_norm_(params_with_grads, clip_value)
        if isinstance(norm_before_clip, torch.Tensor):
            norm_before_clip = norm_before_clip.cpu().numpy()
        grad_gnorm += min(norm_before_clip, clip_value)

    # Note (Kourosh): grads could indeed be zero. We still report grad_gnorm
    # then; only a complete absence of grads yields an empty dict.
    if groups_without_grads == len(optimizer.param_groups):
        return {}
    return {"grad_gnorm": grad_gnorm}


@PublicAPI
def clip_gradients(
    gradients_dict: "ParamDict",
    *,
    grad_clip: Optional[float] = None,
    grad_clip_by: str = "value",
) -> TensorType:
    """Performs gradient clipping on a grad-dict based on a clip value and clip mode.

    Changes the provided gradient dict in place.

    Args:
        gradients_dict: The gradients dict, mapping str to gradient tensors.
            Values may be None (parameters without a gradient); those entries
            are left as None and ignored for norm computations.
        grad_clip: The value to clip with. The way gradients are clipped is
            defined by the `grad_clip_by` arg (see below). If None, nothing is
            clipped.
        grad_clip_by: One of 'value', 'norm', or 'global_norm'.

    Returns:
        If `grad_clip_by`="global_norm" and `grad_clip` is not None, returns the
        global norm of all non-None tensors, otherwise returns None.
    """
    # No clipping, return.
    if grad_clip is None:
        return

    # Clip by value (each gradient individually).
    if grad_clip_by == "value":
        for k, v in gradients_dict.copy().items():
            gradients_dict[k] = (
                None if v is None else torch.clip(v, -grad_clip, grad_clip)
            )

    # Clip by L2-norm (per gradient tensor).
    elif grad_clip_by == "norm":
        for v in gradients_dict.values():
            if v is not None:
                # Compute the L2-norm of the gradient tensor (overflow-safe).
                norm = v.norm(2).nan_to_num(neginf=-10e8, posinf=10e8)
                # Clip all the gradients (in place).
                if norm > grad_clip:
                    v.mul_(grad_clip / norm)

    # Clip by global L2-norm (across all gradient tensors).
    else:
        assert (
            grad_clip_by == "global_norm"
        ), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
        # Drop missing gradients up-front: a None entry has no `.device` and
        # must not take part in the norm computation.
        gradients_list = [g for g in gradients_dict.values() if g is not None]
        total_norm = compute_global_norm(gradients_list)
        if not gradients_list:
            return total_norm

        # We do want the coefficient to be in between 0.0 and 1.0, therefore
        # if the global_norm is smaller than the clip value, we use the clip
        # value as normalization constant.
        device = gradients_list[0].device
        clip_coef = grad_clip / torch.maximum(
            torch.tensor(grad_clip).to(device), total_norm + 1e-6
        )
        # Note: multiplying by the clamped coef is redundant when the coef is
        # clamped to 1, but doing so avoids a `if clip_coef < 1:` conditional
        # which can require a CPU <=> device synchronization when the gradients
        # do not reside in CPU memory.
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        for g in gradients_list:
            g.detach().mul_(clip_coef_clamped.to(g.device))
        return total_norm
@PublicAPI
def compute_global_norm(gradients_list: "ParamList") -> TensorType:
    """Computes the global L2 norm over a list of gradient tensors.

    None entries (parameters without a gradient) are ignored.

    Args:
        gradients_list: The gradients list containing parameters.

    Returns:
        Returns the global norm of all non-None tensors in `gradients_list`,
        placed on the device of the first such tensor. If there is no
        non-None tensor, returns `torch.tensor(0.0)`.
    """
    # Define the norm type to be L2.
    norm_type = 2.0
    # Skip missing gradients: they have no `.device` and cannot be normed.
    grads = [g for g in gradients_list if g is not None]
    # If we have no grads, return zero.
    if not grads:
        return torch.tensor(0.0)
    device = grads[0].device

    # Compute the global norm.
    total_norm = torch.norm(
        torch.stack(
            [
                # Note, we want to avoid overflow in the norm computation, this
                # does not affect the gradients themselves as we clamp by
                # multiplying and not by overriding tensor values.
                torch.norm(g.detach(), norm_type)
                .nan_to_num(neginf=-10e8, posinf=10e8)
                .to(device)
                for g in grads
            ]
        ),
        norm_type,
    ).nan_to_num(neginf=-10e8, posinf=10e8)
    # NOTE(review): `nan_to_num` above maps NaN to 0 and +/-inf to +/-1e9, so
    # this check cannot fire as written; kept so the original guard is visible.
    if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. "
        )
    # Return the global norm.
    return total_norm
@OldAPIStack
def concat_multi_gpu_td_errors(
    policy: Union["TorchPolicy", "TorchPolicyV2"]
) -> Dict[str, TensorType]:
    """Concatenates the per-tower (multi-GPU) TD-error tensors of a policy.

    Each entry of `policy.model_gpu_towers` contributes its
    `tower_stats["td_error"]` (or `torch.tensor([0.0])` if missing), moved to
    `policy.device`. The concatenation is also stored as `policy.td_error`.

    Args:
        policy: The TorchPolicy to extract the TD-error values from.

    Returns:
        A dict mapping "td_error" to the concatenated TD-errors and
        "mean_td_error" to their mean.
    """
    per_tower_td_errors = []
    for tower in policy.model_gpu_towers:
        tower_td_error = tower.tower_stats.get("td_error", torch.tensor([0.0]))
        per_tower_td_errors.append(tower_td_error.to(policy.device))
    td_error = torch.cat(per_tower_td_errors, dim=0)

    # Side effect: keep the concatenated TD-errors on the policy itself.
    policy.td_error = td_error

    mean_td_error = torch.mean(td_error)
    return {"td_error": td_error, "mean_td_error": mean_td_error}
@PublicAPI
def convert_to_torch_tensor(
    x: TensorStructType,
    device: Optional[str] = None,
    pin_memory: bool = False,
):
    """Converts any struct to torch.Tensors.

    Args:
        x: Any (possibly nested) struct, the values in which will be
            converted and returned as a new struct with all leaves converted
            to torch tensors.
        device: The device to create the tensor on.
        pin_memory: If True (and CUDA is available), the created CPU tensors
            are replaced by page-locked (pinned) copies for faster transfer to
            GPU later.

    Returns:
        Any: A new struct with the same structure as `x`, but with all
        values converted to torch Tensor types. This does not convert possibly
        nested elements that are None because torch has no representation for
        that. numpy arrays of dtype object or str are returned unchanged.
        Floating-point tensors other than float16 are cast to float32.
    """

    def mapping(item):
        if item is None:
            # Torch has no representation for `None`, so we return None.
            return item

        # Special handling of "Repeated" values.
        if isinstance(item, RepeatedValues):
            return RepeatedValues(
                tree.map_structure(mapping, item.values), item.lengths, item.max_len
            )

        # Already torch tensor -> make sure it's on right device.
        if torch.is_tensor(item):
            tensor = item
        # Numpy arrays.
        elif isinstance(item, np.ndarray):
            # Object type (e.g. info dicts in train batch): leave as-is.
            # str type (e.g. agent_id in train batch): leave as-is.
            if item.dtype == object or item.dtype.type is np.str_:
                return item
            # Non-writable numpy-arrays will cause PyTorch warning.
            elif item.flags.writeable is False:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    tensor = torch.from_numpy(item)
            # Already numpy: Wrap as torch tensor.
            else:
                tensor = torch.from_numpy(item)
        # Everything else: Convert to numpy, then wrap as torch tensor.
        else:
            tensor = torch.from_numpy(np.asarray(item))

        # Floatify all float64 tensors (but leave float16 as-is).
        if tensor.is_floating_point() and str(tensor.dtype) != "torch.float16":
            tensor = tensor.float()

        # Pin the tensor's memory (for faster transfer to GPU later).
        # `Tensor.pin_memory()` is NOT in-place; it returns a pinned copy, so
        # the result must be re-assigned. Only CPU tensors can be pinned.
        if pin_memory and torch.cuda.is_available() and tensor.device.type == "cpu":
            tensor = tensor.pin_memory()

        return tensor if device is None else tensor.to(device)

    return tree.map_structure(mapping, x)
@PublicAPI
def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
    """Creates a copy of `x` with deep copies of all torch.Tensors in it.

    The copied tensors are detached from the autograd graph and moved to the
    specified device (if not None). Leaves that are not torch.Tensors are not
    copied: the returned struct references the very same objects.

    Args:
        x: Any (possibly nested) struct possibly containing torch.Tensors.
        device: The device to move the copied tensors to.

    Returns:
        Any: A new struct with the same structure as `x`, but with all
        torch.Tensors deep-copied (and moved to `device`, if given).
    """

    def mapping(item):
        if not isinstance(item, torch.Tensor):
            return item
        detached = item.detach()
        if device is None:
            return torch.clone(detached)
        # `copy=True` guarantees fresh storage even if `item` already lives on
        # `device`; a plain `.to(device)` would then return a tensor sharing
        # `item`'s storage instead of a copy.
        return detached.to(device, copy=True)

    return tree.map_structure(mapping, x)
@PublicAPI
def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
    """Computes the explained variance of predictions w.r.t. labels.

    Formula (variances taken along axis 0, floored at -1.0):
        max(-1.0, 1.0 - (var(y - pred) / var(y)))

    A small constant is added to `var(y)` to avoid division by zero.

    Args:
        y: The labels.
        pred: The predictions.

    Returns:
        The explained variance given a pair of labels and predictions.
    """
    residual_var = torch.var(y - pred, dim=[0])
    label_var = torch.var(y, dim=[0])
    floor = torch.tensor([-1.0]).to(pred.device)
    ratio = residual_var / (label_var + SMALL_NUMBER)
    return torch.max(floor, 1 - ratio)[0]
@PublicAPI
def flatten_inputs_to_1d_tensor(
    inputs: TensorStructType,
    spaces_struct: Optional[SpaceStruct] = None,
    time_axis: bool = False,
) -> TensorType:
    """Flattens arbitrary input structs according to the given spaces struct.

    All components of `inputs` are flattened/one-hot'd and concatenated along
    the last axis, keeping the batch axis (and, optionally, the time axis):

    - Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes
      are not treated differently from other types of Boxes and get flattened
      as well.
    - Discrete (int) values are one-hot'd, e.g. a batch of [1, 0, 3] (B=3 with
      Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].
    - MultiDiscrete values are multi-one-hot'd, e.g. a batch of
      [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in
      [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].

    Args:
        inputs: The inputs to be flattened.
        spaces_struct: The structure of the spaces behind the inputs. If None,
            every component is treated like a Box (i.e. just flattened).
        time_axis: Whether all inputs have a time-axis (after the batch axis).
            If True, will keep not only the batch axis (0th), but the time axis
            (1st) as-is and flatten everything from the 2nd axis up.

    Returns:
        A single float tensor resulting from concatenating all
        flattened/one-hot'd input components. Depending on the time_axis flag,
        the shape is (B, n) or (B, T, n).

    .. testcode::

        import gymnasium as gym
        import numpy as np
        import torch
        import tree

        from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor

        struct = {
            "a": np.array([1, 3]),
            "b": (
                np.array([[1.0, 2.0], [4.0, 5.0]]),
                np.array(
                    [[[8.0], [7.0]], [[5.0], [4.0]]]
                ),
            ),
            "c": {
                "cb": np.array([1.0, 2.0]),
            },
        }
        struct_torch = tree.map_structure(lambda s: torch.from_numpy(s), struct)
        spaces = {
            "a": gym.spaces.Discrete(4),
            "b": (gym.spaces.Box(-1.0, 10.0, (2,)), gym.spaces.Box(-1.0, 1.0, (2, 1))),
            "c": {"cb": gym.spaces.Box(-1.0, 1.0, ())},
        }
        print(flatten_inputs_to_1d_tensor(struct_torch, spaces_struct=spaces))

    .. testoutput::

        tensor([[0., 1., 0., 0., 1., 2., 8., 7., 1.],
                [0., 0., 0., 1., 4., 5., 5., 4., 2.]])
    """
    flat_inputs = tree.flatten(inputs)
    if spaces_struct is None:
        flat_spaces = [None] * len(flat_inputs)
    else:
        flat_spaces = tree.flatten(spaces_struct)

    batch_size = None
    seq_len = None
    pieces = []
    for component, space in zip(flat_inputs, flat_spaces):
        # The first component determines B (and T, if there is a time axis).
        if batch_size is None:
            batch_size = component.shape[0]
            if time_axis:
                seq_len = component.shape[1]

        if isinstance(space, Discrete):
            # One-hot encoding (time axis folded into the batch axis first).
            if time_axis:
                component = torch.reshape(component, [batch_size * seq_len])
            pieces.append(one_hot(component, space).float())
        elif isinstance(space, MultiDiscrete):
            # Multi one-hot encoding.
            if time_axis:
                component = torch.reshape(component, [batch_size * seq_len, -1])
            pieces.append(one_hot(component, space).float())
        else:
            # Box (or unknown space): plain flattening.
            leading = [batch_size * seq_len] if time_axis else [batch_size]
            pieces.append(torch.reshape(component, leading + [-1]).float())

    merged = torch.cat(pieces, dim=-1)
    # Restore the time-dimension, if applicable.
    if time_axis:
        merged = torch.reshape(merged, [batch_size, seq_len, -1])
    return merged
@PublicAPI
def global_norm(tensors: List[TensorType]) -> TensorType:
    """Returns the global L2 norm over a list of tensors.

    output = sqrt(SUM(t ** 2 for t in tensors)), where SUM reduces over all
    tensors and over all elements in tensors.

    Args:
        tensors: The list of tensors to calculate the global norm over.

    Returns:
        The global L2 norm over the given tensor list (`torch.tensor(0.0)` if
        the list is empty).
    """
    tensors = list(tensors)
    if not tensors:
        # `sum()` of nothing is the int 0, which `torch.pow` cannot handle.
        return torch.tensor(0.0)
    # Sum the squared elements of all tensors directly. Taking a per-tensor
    # sqrt first and squaring it again is redundant and produces NaN gradients
    # whenever a single tensor is all zeros (d sqrt(x)/dx is infinite at 0).
    squared_sum = sum(torch.sum(torch.pow(t, 2.0)) for t in tensors)
    return torch.pow(squared_sum, 0.5)
@OldAPIStack
def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
    """Computes the element-wise Huber loss of `x` for threshold `delta`.

    Reference: https://en.wikipedia.org/wiki/Huber_loss
    The factor of 0.5 is included in the calculation:

        L = 0.5 * x^2                      if abs(x) < delta
        L = delta * (abs(x) - 0.5*delta)   otherwise

    Args:
        x: The input term, e.g. a TD error.
        delta: The delta parameter in the above formula.

    Returns:
        The Huber loss resulting from `x` and `delta`.
    """
    abs_x = torch.abs(x)
    quadratic = torch.pow(x, 2.0) * 0.5
    linear = delta * (abs_x - 0.5 * delta)
    return torch.where(abs_x < delta, quadratic, linear)


@OldAPIStack
def l2_loss(x: TensorType) -> TensorType:
    """Computes half the squared L2 norm of a tensor (no sqrt).

    output = 0.5 * sum(x ** 2)

    Args:
        x: The input tensor.

    Returns:
        0.5 times the sum of squares over all of the tensor's values.
    """
    squared = torch.pow(x, 2.0)
    return 0.5 * torch.sum(squared)
@PublicAPI
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
    """Returns a one-hot tensor, given an int tensor and a space.

    Handles the MultiDiscrete case as well. Spaces whose values do not start
    at 0 (gymnasium's `start` attribute) are supported: values are shifted by
    `start` before encoding, so `Discrete(3, start=1)` maps 1, 2, 3 to slots
    0, 1, 2.

    Args:
        x: The input tensor.
        space: The space to use for generating the one-hot tensor.

    Returns:
        The resulting one-hot tensor.

    Raises:
        ValueError: If the given space is not a discrete one.

    .. testcode::

        import torch
        import gymnasium as gym
        from ray.rllib.utils.torch_utils import one_hot
        x = torch.IntTensor([0, 3])  # batch-dim=2
        # Discrete space with 4 (one-hot) slots per batch item.
        s = gym.spaces.Discrete(4)
        print(one_hot(x, s))
        x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
        # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
        # per batch item.
        s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
        print(one_hot(x, s))

    .. testoutput::

        tensor([[1, 0, 0, 0],
                [0, 0, 0, 1]])
        tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
    """
    if isinstance(space, Discrete):
        # Values of `Discrete(n, start=s)` lie in [s, s + n); shift to [0, n).
        start = int(getattr(space, "start", 0))
        return nn.functional.one_hot(x.long() - start, space.n)
    elif isinstance(space, MultiDiscrete):
        if isinstance(space.nvec[0], np.ndarray):
            # Multi-dimensional nvec: flatten both the space and the inputs.
            nvec = np.ravel(space.nvec)
            x = x.reshape(x.shape[0], -1)
        else:
            nvec = space.nvec
        # Newer gymnasium versions give MultiDiscrete a per-slot `start` with
        # the same shape as `nvec`; without it, every slot starts at 0.
        start = getattr(space, "start", None)
        starts = (
            np.zeros(len(nvec), dtype=np.int64) if start is None else np.ravel(start)
        )
        return torch.cat(
            [
                nn.functional.one_hot(x[:, i].long() - int(s), int(n))
                for i, (n, s) in enumerate(zip(nvec, starts))
            ],
            dim=-1,
        )
    else:
        raise ValueError("Unsupported space for `one_hot`: {}".format(space))
@PublicAPI
def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType:
    """Same as torch.mean() but ignores -inf values.

    Entries equal to -inf are excluded from both the sum and the element
    count (+inf entries are NOT excluded).

    Args:
        x: The input tensor to reduce mean over.
        axis: The axis over which to reduce. None for all axes.

    Returns:
        The mean reduced inputs, ignoring -inf values.
    """
    not_neg_inf = torch.ne(x, float("-inf"))
    numerator = torch.sum(torch.where(not_neg_inf, x, torch.zeros_like(x)), axis)
    denominator = torch.sum(not_neg_inf.float(), axis)
    return numerator / denominator
@PublicAPI
def sequence_mask(
    lengths: TensorType,
    maxlen: Optional[int] = None,
    dtype=None,
    time_major: bool = False,
) -> TensorType:
    """Offers same behavior as tf.sequence_mask for torch.

    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
    39036).

    Args:
        lengths: The tensor of individual lengths to mask by. The `.t()` calls
            below require it to be 1D of shape [B].
        maxlen: The maximum length to use for the time axis. If None, use
            the max of `lengths` (0 if `lengths` is empty).
        dtype: The torch dtype to use for the resulting mask. Defaults to
            torch.bool.
        time_major: Whether to return the mask as [B, T] (False; default) or
            as [T, B] (True).

    Returns:
         The sequence mask resulting from the given input and parameters.
    """
    # If maxlen not given, use the longest lengths in the `lengths` tensor
    # (`.max()` raises on an empty tensor, hence the explicit fallback).
    if maxlen is None:
        maxlen = lengths.max() if lengths.numel() > 0 else 0

    # cumsum over a row of ones yields positions 1..maxlen; a position is valid
    # iff it does not exceed the sequence's length. Result is [T, B] here.
    mask = torch.ones(tuple(lengths.shape) + (int(maxlen),))
    mask = ~(mask.to(lengths.device).cumsum(dim=1).t() > lengths)
    # Time major transformation.
    if not time_major:
        mask = mask.t()

    # By default, set the mask to be boolean. `Tensor.type()` returns a new
    # tensor (it is not in-place), so the result must be re-assigned for the
    # requested `dtype` to take effect.
    mask = mask.type(dtype or torch.bool)

    return mask
@PublicAPI
def update_target_network(
    main_net: NetworkType,
    target_net: NetworkType,
    tau: float,
) -> None:
    """Soft-updates a torch.nn.Module target network via Polyak averaging.

    .. code-block:: text

        new_target_net_weight = (
            tau * main_net_weight + (1.0 - tau) * current_target_net_weight
        )

    Every entry of the target's state dict is blended with the entry of the
    same name in the main network's state dict, then loaded back into
    `target_net`.

    Args:
        main_net: The nn.Module to update from.
        target_net: The target network to update.
        tau: The tau value to use in the Polyak averaging formula.
    """
    main_weights = main_net.state_dict()
    blended = {}
    for name, target_weight in target_net.state_dict().items():
        blended[name] = tau * main_weights[name] + (1 - tau) * target_weight
    target_net.load_state_dict(blended)
@DeveloperAPI
def warn_if_infinite_kl_divergence(
    policy: "TorchPolicy",
    kl_divergence: TensorType,
) -> None:
    """Logs a warning if `kl_divergence` contains an infinite value.

    The check only runs once `policy.loss_initialized()` returns True
    (presumably to stay quiet during the policy's initial loss setup; confirm
    against the caller).

    Args:
        policy: The TorchPolicy whose `loss_initialized()` gates the check.
        kl_divergence: The KL divergence tensor to check. May be a scalar or a
            tensor of any shape; the warning fires if any element is +/-inf.
    """
    # `.any()` reduces to a single boolean, so non-scalar tensors do not raise
    # "Boolean value of Tensor with more than one value is ambiguous".
    if policy.loss_initialized() and kl_divergence.isinf().any():
        logger.warning(
            "KL divergence is non-finite, this will likely destabilize your model and"
            " the training process. Action(s) in a specific state have near-zero"
            " probability. This can happen naturally in deterministic environments"
            " where the optimal policy has zero mass for a specific action. To fix this"
            " issue, consider setting the coefficient for the KL loss term to zero or"
            " increasing policy entropy."
        )
@PublicAPI
def set_torch_seed(seed: Optional[int] = None) -> None:
    """Sets the torch random seed to the given value.

    Besides `torch.manual_seed`, this also configures determinism settings:
    on CUDA builds >= 10.2 it sets the `CUBLAS_WORKSPACE_CONFIG` env var,
    otherwise it calls `torch.use_deterministic_algorithms(True)`; in both
    cases it sets `torch.backends.cudnn.deterministic = True`. Nothing happens
    if `seed` is None or torch is not installed.

    Args:
        seed: The seed to use or None for no seeding.
    """
    if seed is not None and torch:
        torch.manual_seed(seed)
        # See https://github.com/pytorch/pytorch/issues/47672.
        cuda_version = torch.version.cuda
        # Compare as versions, not floats: float("10.10") == 10.1 < 10.2.
        if cuda_version is not None and version.parse(cuda_version) >= version.parse(
            "10.2"
        ):
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
        else:
            # Not all Operations support this.
            torch.use_deterministic_algorithms(True)
        # Only affects cuDNN (convolution) algorithm selection.
        torch.backends.cudnn.deterministic = True
@PublicAPI
def softmax_cross_entropy_with_logits(
    logits: TensorType,
    labels: TensorType,
) -> TensorType:
    """Same behavior as tf.nn.softmax_cross_entropy_with_logits.

    Computes -sum(labels * log_softmax(logits)) over the last axis.

    Args:
        logits: The unnormalized input predictions.
        labels: The labels corresponding to `logits`.

    Returns:
        The resulting softmax cross-entropy given predictions and labels.
    """
    log_probs = nn.functional.log_softmax(logits, -1)
    weighted = -labels * log_probs
    return weighted.sum(-1)
def _dynamo_is_available(): # This only works if torch._dynamo is available try: # TODO(Artur): Remove this once torch._dynamo is available on CI import torch._dynamo as dynamo # noqa: F401 return True except ImportError: return False