import logging
import os
import warnings
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
from packaging import version
import tree # pip install dm_tree
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER
from ray.rllib.utils.typing import (
LocalOptimizer,
NetworkType,
SpaceStruct,
TensorStructType,
TensorType,
)
if TYPE_CHECKING:
from ray.rllib.core.learner.learner import ParamDict, ParamList
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
logger = logging.getLogger(__name__)
torch, nn = try_import_torch()
# Limit values suitable for use in place of -inf/+inf logits. These are
# useful since actual -inf/inf values cause NaNs during backprop.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38
if torch:
TORCH_COMPILE_REQUIRED_VERSION = version.parse("2.0.0")
else:
TORCH_COMPILE_REQUIRED_VERSION = ValueError(
"torch is not installed. " "TORCH_COMPILE_REQUIRED_VERSION is " "not defined."
)
# TODO (sven): Deprecate this function once we have moved completely to the Learner API.
# Replaced with `clip_gradients()`.
@PublicAPI
def apply_grad_clipping(
policy: "TorchPolicy", optimizer: LocalOptimizer, loss: TensorType
) -> Dict[str, TensorType]:
"""Applies gradient clipping to already computed grads inside `optimizer`.
Note: This function does NOT perform an operation analogous to
tf.clip_by_global_norm. It merely clips by norm (per optimizer param
group) and then sums up the resulting (clipped) norms across all groups
(without clipping by that summed norm again).
Args:
policy: The TorchPolicy, which calculated `loss`.
optimizer: A local torch optimizer object.
loss: The torch loss tensor.
Returns:
An info dict containing the "grad_norm" key and the resulting clipped
gradients.
"""
grad_gnorm = 0
if policy.config["grad_clip"] is not None:
clip_value = policy.config["grad_clip"]
else:
clip_value = np.inf
num_none_grads = 0
for param_group in optimizer.param_groups:
# Make sure we only pass params with grad != None into torch
# clip_grad_norm_. Would fail otherwise.
params = list(filter(lambda p: p.grad is not None, param_group["params"]))
if params:
# PyTorch clips gradients in-place and returns the norm before clipping.
# We therefore need to compute `grad_gnorm` further down (fixes #4965).
global_norm = nn.utils.clip_grad_norm_(params, clip_value)
if isinstance(global_norm, torch.Tensor):
global_norm = global_norm.cpu().numpy()
grad_gnorm += min(global_norm, clip_value)
else:
num_none_grads += 1
# Note (Kourosh): grads could indeed be zero. This method should still return
# grad_gnorm in that case.
if num_none_grads == len(optimizer.param_groups):
# No grads available
return {}
return {"grad_gnorm": grad_gnorm}
@PublicAPI
def clip_gradients(
gradients_dict: "ParamDict",
*,
grad_clip: Optional[float] = None,
grad_clip_by: str = "value",
) -> TensorType:
"""Performs gradient clipping on a grad-dict based on a clip value and clip mode.
Changes the provided gradient dict in place.
Args:
gradients_dict: The gradients dict, mapping str to gradient tensors.
grad_clip: The value to clip with. The way gradients are clipped is defined
by the `grad_clip_by` arg (see below).
grad_clip_by: One of 'value', 'norm', or 'global_norm'.
Returns:
If `grad_clip_by`="global_norm" and `grad_clip` is not None, returns the global
norm of all tensors, otherwise returns None.
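Example (an illustrative sketch with hand-picked values):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import clip_gradients
# One gradient tensor with L2-norm 5.0.
grads = {"w": torch.tensor([3.0, 4.0])}
clip_gradients(grads, grad_clip=1.0, grad_clip_by="norm")
print(grads["w"])  # rescaled in place to norm 1.0 -> [0.6, 0.8]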
"""
# No clipping, return.
if grad_clip is None:
return
# Clip by value (each gradient individually).
if grad_clip_by == "value":
for k, v in gradients_dict.copy().items():
gradients_dict[k] = (
None if v is None else torch.clip(v, -grad_clip, grad_clip)
)
# Clip by L2-norm (per gradient tensor).
elif grad_clip_by == "norm":
for k, v in gradients_dict.copy().items():
if v is not None:
# Compute the L2-norm of the gradient tensor.
norm = v.norm(2).nan_to_num(neginf=-10e8, posinf=10e8)
# Clip all the gradients.
if norm > grad_clip:
v.mul_(grad_clip / norm)
# Clip by global L2-norm (across all gradient tensors).
else:
assert (
grad_clip_by == "global_norm"
), f"`grad_clip_by` ({grad_clip_by}) must be one of [value|norm|global_norm]!"
gradients_list = list(gradients_dict.values())
total_norm = compute_global_norm(gradients_list)
# We do want the coefficient to be in between 0.0 and 1.0, therefore
# if the global_norm is smaller than the clip value, we use the clip value
# as normalization constant.
device = gradients_list[0].device
clip_coef = grad_clip / torch.maximum(
torch.tensor(grad_clip).to(device), total_norm + 1e-6
)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to
# 1, but doing so avoids a `if clip_coef < 1:` conditional which can require a
# CPU <=> device synchronization when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in gradients_list:
if g is not None:
g.detach().mul_(clip_coef_clamped.to(g.device))
return total_norm
@PublicAPI
def compute_global_norm(gradients_list: "ParamList") -> TensorType:
"""Computes the global norm for a gradients dict.
Args:
gradients_list: The gradients list containing parameters.
Returns:
Returns the global norm of all tensors in `gradients_list`.
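Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import compute_global_norm
grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
print(compute_global_norm(grads))  # sqrt(5^2 + 12^2) -> 13.0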
"""
# Define the norm type to be L2.
norm_type = 2.0
# If we have no grads, return zero.
if len(gradients_list) == 0:
return torch.tensor(0.0)
device = gradients_list[0].device
# Compute the global norm.
total_norm = torch.norm(
torch.stack(
[
torch.norm(g.detach(), norm_type)
# Note: Replace non-finite values to avoid overflow in the norm
# computation. This does not alter the gradients themselves, only
# the per-tensor norms that enter the global norm.
.nan_to_num(neginf=-10e8, posinf=10e8).to(device)
for g in gradients_list
if g is not None
]
),
norm_type,
).nan_to_num(neginf=-10e8, posinf=10e8)
if torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. "
)
# Return the global norm.
return total_norm
@PublicAPI
def concat_multi_gpu_td_errors(
policy: Union["TorchPolicy", "TorchPolicyV2"]
) -> Dict[str, TensorType]:
"""Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy.
TD-errors are extracted from the TorchPolicy via its tower_stats property.
Args:
policy: The TorchPolicy to extract the TD-error values from.
Returns:
A dict mapping strings "td_error" and "mean_td_error" to the
corresponding concatenated and mean-reduced values.
"""
td_error = torch.cat(
[
t.tower_stats.get("td_error", torch.tensor([0.0])).to(policy.device)
for t in policy.model_gpu_towers
],
dim=0,
)
policy.td_error = td_error
return {
"td_error": td_error,
"mean_td_error": torch.mean(td_error),
}
@PublicAPI
def convert_to_torch_tensor(
x: TensorStructType,
device: Optional[str] = None,
pin_memory: bool = False,
):
"""Converts any struct to torch.Tensors.
Args:
x: Any (possibly nested) struct, the values in which will be
converted and returned as a new struct with all leaves converted
to torch tensors.
device: The device to create the tensor on.
pin_memory: If True, will call the `pin_memory()` method on the created tensors.
Returns:
Any: A new struct with the same structure as `x`, but with all
values converted to torch Tensor types. This does not convert possibly
nested elements that are None because torch has no representation for that.
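Example (an illustrative sketch; the batch dict is made up):
.. testcode::
import numpy as np
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
batch = {"obs": np.zeros((2, 4), dtype=np.float64), "infos": None}
out = convert_to_torch_tensor(batch)
print(out["obs"].dtype)  # float64 arrays are "floatified" -> torch.float32
print(out["infos"])  # None leaves are returned as-is -> None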
"""
def mapping(item):
if item is None:
# Torch has no representation for `None`, so we return None
return item
# Special handling of "Repeated" values.
if isinstance(item, RepeatedValues):
return RepeatedValues(
tree.map_structure(mapping, item.values), item.lengths, item.max_len
)
# Already torch tensor -> make sure it's on right device.
if torch.is_tensor(item):
tensor = item
# Numpy arrays.
elif isinstance(item, np.ndarray):
# Object type (e.g. info dicts in train batch): leave as-is.
# str type (e.g. agent_id in train batch): leave as-is.
if item.dtype == object or item.dtype.type is np.str_:
return item
# Non-writable numpy-arrays will cause PyTorch warning.
elif item.flags.writeable is False:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
tensor = torch.from_numpy(item)
# Already numpy: Wrap as torch tensor.
else:
tensor = torch.from_numpy(item)
# Everything else: Convert to numpy, then wrap as torch tensor.
else:
tensor = torch.from_numpy(np.asarray(item))
# Floatify all float64 tensors (but leave float16 as-is).
if tensor.is_floating_point() and str(tensor.dtype) != "torch.float16":
tensor = tensor.float()
# Pin the tensor's memory (for faster transfer to GPU later).
if pin_memory and torch.cuda.is_available():
tensor = tensor.pin_memory()
return tensor if device is None else tensor.to(device)
return tree.map_structure(mapping, x)
@PublicAPI
def copy_torch_tensors(x: TensorStructType, device: Optional[str] = None):
"""Creates a copy of `x` and makes deep copies torch.Tensors in x.
Also moves the copied tensors to the specified device (if not None).
Note if an object in x is not a torch.Tensor, it will be shallow-copied.
Args:
x : Any (possibly nested) struct possibly containing torch.Tensors.
device : The device to move the tensors to.
Returns:
Any: A new struct with the same structure as `x`, but with all
torch.Tensors deep-copied and moved to the specified device.
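Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import copy_torch_tensors
x = {"a": torch.ones(2), "b": "a_string"}
y = copy_torch_tensors(x)
print(y["a"] is x["a"])  # tensors are deep-copied -> False
print(y["b"] is x["b"])  # non-tensors are shallow-copied -> True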
"""
def mapping(item):
if isinstance(item, torch.Tensor):
return (
torch.clone(item.detach())
if device is None
else item.detach().to(device)
)
else:
return item
return tree.map_structure(mapping, x)
@PublicAPI
def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
"""Computes the explained variance for a pair of labels and predictions.
The formula used is:
max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
Args:
y: The labels.
pred: The predictions.
Returns:
The explained variance given a pair of labels and predictions.
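Example (an illustrative sketch with hand-picked values):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import explained_variance
y = torch.tensor([1.0, 2.0, 3.0])
print(explained_variance(y, y))  # perfect predictions -> ~1.0
print(explained_variance(y, torch.full_like(y, 2.0)))  # mean predictor -> ~0.0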
"""
y_var = torch.var(y, dim=[0])
diff_var = torch.var(y - pred, dim=[0])
min_ = torch.tensor([-1.0]).to(pred.device)
return torch.max(min_, 1 - (diff_var / (y_var + SMALL_NUMBER)))[0]
@PublicAPI
def global_norm(tensors: List[TensorType]) -> TensorType:
"""Returns the global L2 norm over a list of tensors.
output = sqrt(SUM(t ** 2 for t in tensors)),
where SUM reduces over all tensors and over all elements in tensors.
Args:
tensors: The list of tensors to calculate the global norm over.
Returns:
The global L2 norm over the given tensor list.
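Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import global_norm
# sqrt(3^2 + 4^2 + 12^2) = 13.0
print(global_norm([torch.tensor([3.0, 4.0]), torch.tensor([12.0])]))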
"""
# List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor.
single_l2s = [torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors]
# Compute global norm from all single tensors' L2 norms.
return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5)
@PublicAPI
def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
"""Computes the huber loss for a given term and delta parameter.
Reference: https://en.wikipedia.org/wiki/Huber_loss
Note that the factor of 0.5 is implicitly included in the calculation.
Formula:
L = 0.5 * x^2 for small abs x (delta threshold)
L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold)
Args:
x: The input term, e.g. a TD error.
delta: The delta parameter in the above formula.
Returns:
The Huber loss resulting from `x` and `delta`.
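Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import huber_loss
x = torch.tensor([0.5, 2.0])
# Quadratic below delta, linear above: [0.5 * 0.5^2, 1.0 * (2.0 - 0.5)]
print(huber_loss(x))  # -> [0.125, 1.5]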
"""
return torch.where(
torch.abs(x) < delta,
torch.pow(x, 2.0) * 0.5,
delta * (torch.abs(x) - 0.5 * delta),
)
@PublicAPI
def l2_loss(x: TensorType) -> TensorType:
"""Computes half the L2 norm over a tensor's values without the sqrt.
output = 0.5 * sum(x ** 2)
Args:
x: The input tensor.
Returns:
0.5 times the L2 norm over the given tensor's values (w/o sqrt).
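Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import l2_loss
print(l2_loss(torch.tensor([1.0, 2.0])))  # 0.5 * (1 + 4) -> 2.5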
"""
return 0.5 * torch.sum(torch.pow(x, 2.0))
@PublicAPI
def minimize_and_clip(
optimizer: "torch.optim.Optimizer", clip_val: float = 10.0
) -> None:
"""Clips grads found in `optimizer.param_groups` to given value in place.
Ensures the norm of the gradients for each variable is clipped to
`clip_val`.
Args:
optimizer: The torch.optim.Optimizer to get the variables from.
clip_val: The clip value to clip each individual gradient tensor's
L2-norm to (per-tensor norm clipping, not global-norm clipping).
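Example (a minimal sketch; the tiny network is made up for illustration):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import minimize_and_clip
net = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
net(torch.ones(1, 2)).sum().backward()
minimize_and_clip(opt, clip_val=0.5)
# Each parameter's grad tensor now has an L2-norm of at most 0.5.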
"""
# Loop through optimizer's variables and norm per variable.
for param_group in optimizer.param_groups:
for p in param_group["params"]:
if p.grad is not None:
# Note: Pass the parameter itself (not `p.grad`), so that
# `clip_grad_norm_` finds and clips its `.grad` attribute.
torch.nn.utils.clip_grad_norm_(p, clip_val)
@PublicAPI
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
"""Returns a one-hot tensor, given and int tensor and a space.
Handles the MultiDiscrete case as well.
Args:
x: The input tensor.
space: The space to use for generating the one-hot tensor.
Returns:
The resulting one-hot tensor.
Raises:
ValueError: If the given space is not a discrete one.
.. testcode::
import torch
import gymnasium as gym
from ray.rllib.utils.torch_utils import one_hot
x = torch.IntTensor([0, 3]) # batch-dim=2
# Discrete space with 4 (one-hot) slots per batch item.
s = gym.spaces.Discrete(4)
print(one_hot(x, s))
x = torch.IntTensor([[0, 1, 2, 3]]) # batch-dim=1
# MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
# per batch item.
s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
print(one_hot(x, s))
.. testoutput::
tensor([[1, 0, 0, 0],
[0, 0, 0, 1]])
tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
"""
if isinstance(space, Discrete):
return nn.functional.one_hot(x.long(), space.n)
elif isinstance(space, MultiDiscrete):
if isinstance(space.nvec[0], np.ndarray):
nvec = np.ravel(space.nvec)
x = x.reshape(x.shape[0], -1)
else:
nvec = space.nvec
return torch.cat(
[nn.functional.one_hot(x[:, i].long(), n) for i, n in enumerate(nvec)],
dim=-1,
)
else:
raise ValueError("Unsupported space for `one_hot`: {}".format(space))
@PublicAPI
def reduce_mean_ignore_inf(x: TensorType, axis: Optional[int] = None) -> TensorType:
"""Same as torch.mean() but ignores -inf values.
Args:
x: The input tensor to reduce mean over.
axis: The axis over which to reduce. None for all axes.
Returns:
The mean reduced inputs, ignoring inf values.
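Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import reduce_mean_ignore_inf
x = torch.tensor([float("-inf"), 2.0, 4.0])
print(reduce_mean_ignore_inf(x))  # (2 + 4) / 2 -> 3.0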
"""
mask = torch.ne(x, float("-inf"))
x_zeroed = torch.where(mask, x, torch.zeros_like(x))
return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
@PublicAPI
def sequence_mask(
lengths: TensorType,
maxlen: Optional[int] = None,
dtype=None,
time_major: bool = False,
) -> TensorType:
"""Offers same behavior as tf.sequence_mask for torch.
Thanks to Dimitris Papatheodorou
(https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036).
Args:
lengths: The tensor of individual lengths to mask by.
maxlen: The maximum length to use for the time axis. If None, use
the max of `lengths`.
dtype: The torch dtype to use for the resulting mask.
time_major: Whether to return the mask as [B, T] (False; default) or
as [T, B] (True).
Returns:
The sequence mask resulting from the given input and parameters.
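Example (an illustrative sketch):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import sequence_mask
print(sequence_mask(torch.tensor([1, 3]), maxlen=4))
# -> tensor([[ True, False, False, False],
#            [ True,  True,  True, False]])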
"""
# If maxlen not given, use the longest lengths in the `lengths` tensor.
if maxlen is None:
maxlen = lengths.max()
mask = torch.ones(tuple(lengths.shape) + (int(maxlen),))
mask = ~(mask.to(lengths.device).cumsum(dim=1).t() > lengths)
# Time major transformation.
if not time_major:
mask = mask.t()
# By default, set the mask to be boolean.
mask = mask.type(dtype or torch.bool)
return mask
@PublicAPI
def update_target_network(
main_net: NetworkType,
target_net: NetworkType,
tau: float,
) -> None:
"""Updates a torch.nn.Module target network using Polyak averaging.
new_target_net_weight = (
tau * main_net_weight + (1.0 - tau) * current_target_net_weight
)
Args:
main_net: The nn.Module to update from.
target_net: The target network to update.
tau: The tau value to use in the Polyak averaging formula.
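Example (a minimal sketch; layer sizes are arbitrary):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import update_target_network
main_net = torch.nn.Linear(4, 2)
target_net = torch.nn.Linear(4, 2)
# Move the target net 0.5% of the way toward the main net's weights.
update_target_network(main_net, target_net, tau=0.005)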
"""
# Get the current parameters from the main network.
state_dict = main_net.state_dict()
# Polyak-average them into the target network's current parameters.
new_state_dict = {
k: tau * state_dict[k] + (1 - tau) * v
for k, v in target_net.state_dict().items()
}
# Load the averaged parameters into the target network.
target_net.load_state_dict(new_state_dict)
@DeveloperAPI
def warn_if_infinite_kl_divergence(
policy: "TorchPolicy",
kl_divergence: TensorType,
) -> None:
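"""Logs a warning when the given KL divergence is non-finite."""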
if policy.loss_initialized() and kl_divergence.isinf():
logger.warning(
"KL divergence is non-finite, this will likely destabilize your model and"
" the training process. Action(s) in a specific state have near-zero"
" probability. This can happen naturally in deterministic environments"
" where the optimal policy has zero mass for a specific action. To fix this"
" issue, consider setting the coefficient for the KL loss term to zero or"
" increasing policy entropy."
)
@PublicAPI
def set_torch_seed(seed: Optional[int] = None) -> None:
"""Sets the torch random seed to the given value.
Args:
seed: The seed to use or None for no seeding.
"""
if seed is not None and torch:
torch.manual_seed(seed)
# See https://github.com/pytorch/pytorch/issues/47672.
cuda_version = torch.version.cuda
if cuda_version is not None and version.parse(cuda_version) >= version.parse("10.2"):
os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
else:
# Not all operations support deterministic algorithms.
torch.use_deterministic_algorithms(True)
# This only affects convolution ops, which support determinism without issues.
torch.backends.cudnn.deterministic = True
@PublicAPI
def softmax_cross_entropy_with_logits(
logits: TensorType,
labels: TensorType,
) -> TensorType:
"""Same behavior as tf.nn.softmax_cross_entropy_with_logits.
Args:
logits: The input predictions (logits).
labels: The labels corresponding to `logits`.
Returns:
The resulting softmax cross-entropy given predictions and labels.
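Example (an illustrative sketch with hand-picked values):
.. testcode::
import torch
from ray.rllib.utils.torch_utils import softmax_cross_entropy_with_logits
logits = torch.tensor([[10.0, 0.0]])
labels = torch.tensor([[1.0, 0.0]])
print(softmax_cross_entropy_with_logits(logits, labels))  # -> ~0.0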
"""
return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1)
def _dynamo_is_available():
# This only works if torch._dynamo is available
try:
# TODO(Artur): Remove this once torch._dynamo is available on CI
import torch._dynamo as dynamo # noqa: F401
return True
except ImportError:
return False