PyTorch Utility Functions

ray.rllib.utils.torch_utils.apply_grad_clipping(policy: TorchPolicy, optimizer: Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer], loss: Union[numpy.array, tf.Tensor, torch.Tensor]) → Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]

Applies gradient clipping to already computed grads inside optimizer.

Parameters
  • policy – The TorchPolicy that calculated loss.

  • optimizer – A local torch optimizer object.

  • loss – The torch loss tensor.

Returns

An info dict containing the “grad_norm” key and the resulting clipped gradients.
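A minimal sketch of the same idea in plain PyTorch, not RLlib's code: it takes the clip value directly instead of a policy, and `clip_param_groups` is a hypothetical helper name. It clips each param group with `torch.nn.utils.clip_grad_norm_`, which rescales the `.grad` tensors in place and returns the norm measured before clipping, so the sketch caps that value at the threshold before reporting it. It returns only the norm, under an assumed "grad_norm" key.

```python
import math
from typing import Dict, Optional

import torch


def clip_param_groups(
    optimizer: torch.optim.Optimizer, clip_value: Optional[float]
) -> Dict[str, float]:
    """Hypothetical helper: clips already computed grads, per param group."""
    if clip_value is None:
        clip_value = math.inf  # No clipping, but still report the norm.
    total_norm = 0.0
    found_grads = False
    for param_group in optimizer.param_groups:
        # clip_grad_norm_ should only see params that actually have a grad.
        params = [p for p in param_group["params"] if p.grad is not None]
        if not params:
            continue
        found_grads = True
        # Rescales grads in place; returns the norm *before* clipping.
        norm = torch.nn.utils.clip_grad_norm_(params, clip_value)
        total_norm += min(float(norm), clip_value)
    # No grads at all: return an empty info dict.
    return {"grad_norm": total_norm} if found_grads else {}


# Usage: one backward pass with deliberately large gradients, then clip.
torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = (model(torch.randn(8, 4)) * 100.0).pow(2).mean()
opt.zero_grad()
loss.backward()
info = clip_param_groups(opt, clip_value=1.0)
```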

ray.rllib.utils.torch_utils.concat_multi_gpu_td_errors(policy: Union[TorchPolicy, TorchPolicyV2]) → Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]

Concatenates multi-GPU (per-tower) TD error tensors, given a TorchPolicy.

TD errors are extracted from the TorchPolicy via its tower_stats property.

Parameters

policy – The TorchPolicy to extract the TD-error values from.

Returns

A dict mapping strings “td_error” and “mean_td_error” to the corresponding concatenated and mean-reduced values.
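The concatenation and mean reduction can be sketched without a policy object, assuming the per-tower TD errors have already been collected into a list of 1D tensors (reading them from tower_stats is not shown). `concat_td_errors` is a hypothetical helper, and moving every tensor to the CPU first is an assumption of this sketch.

```python
from typing import Dict, List

import torch


def concat_td_errors(tower_td_errors: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: merges per-tower (per-GPU) TD errors."""
    # Bring all towers onto one device, then concatenate along the batch axis.
    td_error = torch.cat([t.to("cpu") for t in tower_td_errors], dim=0)
    return {
        "td_error": td_error,
        "mean_td_error": torch.mean(td_error),
    }


# Usage: two "towers" holding 2 and 3 samples respectively.
stats = concat_td_errors([torch.tensor([1.0, -1.0]), torch.tensor([0.5, 0.5, 2.0])])
```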

ray.rllib.utils.torch_utils.convert_to_torch_tensor(x: Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple], device: Optional[str] = None)

Converts any struct to torch.Tensors.

Parameters

x – Any (possibly nested) struct. Its values are converted to torch tensors and returned in a new struct with the same structure.

Returns

A new struct with the same structure as x, but with all values converted to torch Tensor types. Nested elements that are None are not converted, because torch has no representation for None.

Return type

Any
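A simplified recursive sketch of this behavior, assuming containers are only dicts and tuples (lists are treated as leaves) and leaves are numpy arrays, Python scalars, or torch tensors. `to_torch_struct` is a hypothetical name; the real function may support more container and leaf types.

```python
from typing import Any, Optional

import numpy as np
import torch


def to_torch_struct(x: Any, device: Optional[str] = None) -> Any:
    """Hypothetical helper: converts all leaves of a nested struct to tensors."""
    if x is None:
        # torch has no tensor representation for None; keep it as-is.
        return None
    if isinstance(x, dict):
        return {k: to_torch_struct(v, device) for k, v in x.items()}
    if isinstance(x, tuple):
        return tuple(to_torch_struct(v, device) for v in x)
    # Leaf: reuse existing tensors, convert everything else via numpy.
    tensor = x if isinstance(x, torch.Tensor) else torch.as_tensor(np.asarray(x))
    return tensor.to(device) if device is not None else tensor


out = to_torch_struct({"obs": np.zeros((2, 3)), "extra": (1.5, None)})
```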

ray.rllib.utils.torch_utils.explained_variance(y: Union[numpy.array, tf.Tensor, torch.Tensor], pred: Union[numpy.array, tf.Tensor, torch.Tensor]) → Union[numpy.array, tf.Tensor, torch.Tensor]

Computes the explained variance for a pair of labels and predictions.

The formula used is: max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))

Parameters
  • y – The labels.

  • pred – The predictions.

Returns

The explained variance given a pair of labels and predictions.
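The formula above maps directly onto torch operations. This is a sketch of the formula rather than RLlib's source, and `explained_variance_sketch` is a hypothetical name. It uses `torch.var` for std(.)^2; the sample-size correction cancels in the ratio, so the choice of variance estimator does not change the result.

```python
import torch


def explained_variance_sketch(y: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Sketch of max(-1, 1 - Var[y - pred] / Var[y])."""
    y_var = torch.var(y)
    diff_var = torch.var(y - pred)
    # Clamp from below at -1.0, matching max(-1.0, ...) in the formula.
    return torch.clamp(1.0 - diff_var / y_var, min=-1.0)


y = torch.tensor([1.0, 2.0, 3.0, 4.0])
perfect = explained_variance_sketch(y, y)                          # pred == y
constant = explained_variance_sketch(y, torch.full_like(y, 2.5))   # predicts the mean
```

A perfect prediction yields 1.0, predicting the mean yields 0.0, and very poor predictions are clamped at -1.0.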

ray.rllib.utils.torch_utils.flatten_inputs_to_1d_tensor(inputs: Union[numpy.array, tf.Tensor, torch.Tensor, dict, tuple], spaces_struct: Optional[Union[gym.spaces.Space, dict, tuple]] = None, time_axis: bool = False) → Union[numpy.array, tf.Tensor, torch.Tensor]

Flattens arbitrary input structs according to the given spaces struct.

Returns a single 1D tensor resulting from the different input components’ values.

Thereby:
  • Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes are not treated differently from other types of Boxes and get flattened as well.

  • Discrete (int) values are one-hot’d, e.g. a batch of [1, 0, 3] (B=3 with Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].

  • MultiDiscrete values are multi-one-hot’d, e.g. a batch of [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].

Parameters
  • inputs – The inputs to be flattened.

  • spaces_struct – The structure of the spaces behind the inputs.

  • time_axis – Whether all inputs have a time axis (after the batch axis). If True, keeps both the batch axis (0th) and the time axis (1st) as-is and flattens everything from the 2nd axis up.

Returns

A single 1D tensor resulting from concatenating all flattened/one-hot’d input components. Depending on the time_axis flag, the shape is (B, n) or (B, T, n).

Examples

>>> # B=2
>>> import torch
>>> from gymnasium.spaces import Discrete, Box
>>> from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor
>>> out = flatten_inputs_to_1d_tensor(
...     {"a": torch.tensor([1, 0]),
...      "b": torch.tensor([[[0.0], [0.1]], [[1.0], [1.1]]])},
...     spaces_struct=dict(a=Discrete(2), b=Box(-10.0, 10.0, shape=(2, 1))),
... )
>>> print(out)
[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]]  # B=2 n=4
>>> # B=2; T=2
>>> out = flatten_inputs_to_1d_tensor(
...     (torch.tensor([[1, 0], [0, 1]]),
...      torch.tensor([[[0.0, 0.1], [1.0, 1.1]], [[2.0, 2.1], [3.0, 3.1]]])),
...     spaces_struct=tuple([Discrete(2), Box(-10.0, 10.0, shape=(2,))]),
...     time_axis=True,
... )
>>> print(out)
[[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]], [[1.0, 0.0, 2.0, 2.1], [0.0, 1.0, 3.0, 3.1]]]  # B=2 T=2 n=4

ray.rllib.utils.torch_utils.get_device(config)

Returns a torch device, depending on the given config and the current worker index.

ray.rllib.utils.torch_utils.global_norm(tensors: List[Union[numpy.array, tf.Tensor, torch.Tensor]]) → Union[numpy.array, tf.Tensor, torch.Tensor]

Returns the global L2 norm over a list of tensors.

output = sqrt(SUM(t ** 2 for t in tensors)),

where SUM reduces over all tensors and over all elements in tensors.

Parameters

tensors – The list of tensors to calculate the global norm over.

Returns

The global L2 norm over the given tensor list.
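The formula sqrt(SUM(t ** 2 for t in tensors)) can be sketched in a few lines; `global_norm_sketch` is a hypothetical name, not RLlib's implementation. The result equals the L2 norm of all elements concatenated into one flat vector.

```python
from typing import List

import torch


def global_norm_sketch(tensors: List[torch.Tensor]) -> torch.Tensor:
    """sqrt of the sum of squares over every element of every tensor."""
    # Reduce each tensor to its sum of squares, then add those up.
    sum_of_squares = sum(torch.sum(t ** 2) for t in tensors)
    return torch.sqrt(sum_of_squares)


# Tensors of different shapes: 3^2 + 4^2 + 0^2 = 25, so the norm is 5.
norm = global_norm_sketch([torch.tensor([3.0]), torch.tensor([[4.0, 0.0]])])
```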

ray.rllib.utils.torch_utils.huber_loss(x: Union[numpy.array, tf.Tensor, torch.Tensor], delta: float = 1.0) → Union[numpy.array, tf.Tensor, torch.Tensor]

Computes the Huber loss for a given term and delta parameter.

Reference: https://en.wikipedia.org/wiki/Huber_loss

Note that the factor of 0.5 is implicitly included in the calculation.

Formula:

L = 0.5 * x^2                        if abs(x) < delta
L = delta * (abs(x) - 0.5 * delta)   otherwise

Parameters
  • x – The input term, e.g. a TD error.

  • delta – The delta parameter in the above formula.

Returns

The Huber loss resulting from x and delta.
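The piecewise formula translates naturally into `torch.where`. This is a sketch under the assumption that the loss is returned elementwise (no reduction); `huber_loss_sketch` is a hypothetical name. Both branches equal 0.5 * delta^2 at abs(x) == delta, so the choice of strict inequality does not matter.

```python
import torch


def huber_loss_sketch(x: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Elementwise Huber loss following the piecewise formula."""
    abs_x = torch.abs(x)
    quadratic = 0.5 * x ** 2                 # branch for abs(x) < delta
    linear = delta * (abs_x - 0.5 * delta)   # branch for abs(x) >= delta
    return torch.where(abs_x < delta, quadratic, linear)


# 0.5 * 0.5^2 = 0.125 (quadratic branch); 1.0 * (2.0 - 0.5) = 1.5 (linear branch).
losses = huber_loss_sketch(torch.tensor([0.5, -2.0]), delta=1.0)
```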

ray.rllib.utils.torch_utils.l2_loss(x: Union[numpy.array, tf.Tensor, torch.Tensor]) → Union[numpy.array, tf.Tensor, torch.Tensor]

Computes half the squared L2 norm of a tensor’s values (i.e. the L2 norm without the sqrt).

output = 0.5 * sum(x ** 2)

Parameters

x – The input tensor.

Returns

0.5 times the squared L2 norm of the given tensor’s values.
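The formula output = 0.5 * sum(x ** 2) as a one-line sketch (`l2_loss_sketch` is a hypothetical name). For x = [1, 2, 2] the L2 norm is 3, so the result is 0.5 * 3^2 = 4.5.

```python
import torch


def l2_loss_sketch(x: torch.Tensor) -> torch.Tensor:
    """0.5 * sum(x ** 2): half the squared L2 norm, reduced to a scalar."""
    return 0.5 * torch.sum(x ** 2)


val = l2_loss_sketch(torch.tensor([1.0, 2.0, 2.0]))
```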

ray.rllib.utils.torch_utils.minimize_and_clip(optimizer: torch.optim.optimizer.Optimizer, clip_val: float = 10.0) → None

Clips the grads found in optimizer.param_groups to the given value, in place.

Ensures the norm of the gradients for each variable is clipped to clip_val.

Parameters
  • optimizer – The torch.optim.Optimizer to get the variables from.

  • clip_val – The norm clip value, applied to each variable’s gradient separately.
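A sketch of per-variable norm clipping as described in the first sentence of this entry, assuming each parameter's gradient is clipped independently with `torch.nn.utils.clip_grad_norm_`. Each gradient is rescaled so that its own L2 norm is at most clip_val; individual elements are not clamped to [-clip_val, clip_val]. `clip_each_grad` is a hypothetical helper, not RLlib's code.

```python
import torch


def clip_each_grad(optimizer: torch.optim.Optimizer, clip_val: float = 10.0) -> None:
    """Hypothetical helper: clips each parameter's grad norm in place."""
    for param_group in optimizer.param_groups:
        for p in param_group["params"]:
            if p.grad is not None:
                # clip_grad_norm_ takes parameters and rescales their .grad in place.
                torch.nn.utils.clip_grad_norm_([p], max_norm=clip_val)


# Usage: produce very large gradients, then clip each one to norm <= 1.
torch.manual_seed(0)
model = torch.nn.Linear(3, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
(model(torch.randn(16, 3)) * 1000.0).sum().backward()
clip_each_grad(opt, clip_val=1.0)
```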

ray.rllib.utils.torch_utils.one_hot(x: Union[numpy.array, tf.Tensor, torch.Tensor], space: gym.Space) → Union[numpy.array, tf.Tensor, torch.Tensor]

Returns a one-hot tensor, given an int tensor and a space.

Handles the MultiDiscrete case as well.

Parameters
  • x – The input tensor.

  • space – The space to use for generating the one-hot tensor.

Returns

The resulting one-hot tensor.

Raises

ValueError – If the given space is neither Discrete nor MultiDiscrete.

Examples

>>> import torch
>>> import gymnasium as gym
>>> from ray.rllib.utils.torch_utils import one_hot
>>> x = torch.IntTensor([0, 3])  # batch-dim=2
>>> # Discrete space with 4 (one-hot) slots per batch item.
>>> s = gym.spaces.Discrete(4)
>>> one_hot(x, s) 
tensor([[1, 0, 0, 0], [0, 0, 0, 1]])
>>> x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
>>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
>>> # per batch item.
>>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
>>> one_hot(x, s) 
tensor([[1, 0, 0, 0, 0,
         0, 1, 0, 0,
         0, 0, 1, 0,
         0, 0, 0, 1, 0, 0, 0]])

ray.rllib.utils.torch_utils.reduce_mean_ignore_inf(x: Union[numpy.array, tf.Tensor, torch.Tensor], axis: Optional[int] = None) → Union[numpy.array, tf.Tensor, torch.Tensor]

Same as torch.mean() but ignores -inf values.

Parameters
  • x – The input tensor to reduce mean over.

  • axis – The axis over which to reduce. None for all axes.

Returns

The mean-reduced inputs, ignoring -inf values.
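One way to implement a mean that skips -inf entries is to zero them out and divide by the count of remaining entries. This is a sketch of that masking approach, not necessarily RLlib's exact code; `mean_ignore_neg_inf` is a hypothetical name.

```python
from typing import Optional

import torch


def mean_ignore_neg_inf(x: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Mean over `axis` (all axes if None) that skips -inf entries."""
    mask = x != float("-inf")
    # Replace -inf by 0 so it does not contribute to the sum ...
    x_zeroed = torch.where(mask, x, torch.zeros_like(x))
    # ... and divide by the number of non -inf entries only.
    if axis is None:
        return x_zeroed.sum() / mask.sum()
    return x_zeroed.sum(dim=axis) / mask.sum(dim=axis)


x = torch.tensor([[1.0, float("-inf"), 3.0], [2.0, 4.0, float("-inf")]])
```

A slice consisting only of -inf values has no entries left to average and yields 0 / 0 = nan.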

ray.rllib.utils.torch_utils.sequence_mask(lengths: Union[numpy.array, tf.Tensor, torch.Tensor], maxlen: Optional[int] = None, dtype=None, time_major: bool = False) → Union[numpy.array, tf.Tensor, torch.Tensor]

Offers the same behavior as tf.sequence_mask, but for torch.

Thanks to Dimitris Papatheodorou (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036).

Parameters
  • lengths – The tensor of individual lengths to mask by.

  • maxlen – The maximum length to use for the time axis. If None, use the max of lengths.

  • dtype – The torch dtype to use for the resulting mask.

  • time_major – Whether to return the mask as [B, T] (False; default) or as [T, B] (True).

Returns

The sequence mask resulting from the given input and parameters.
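tf.sequence_mask marks entry [b, t] as True when t < lengths[b]. A broadcasting sketch of that rule in torch, assuming a 1D lengths tensor and a bool default dtype; `sequence_mask_sketch` is a hypothetical name, not RLlib's implementation.

```python
from typing import Optional

import torch


def sequence_mask_sketch(
    lengths: torch.Tensor,
    maxlen: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    time_major: bool = False,
) -> torch.Tensor:
    """mask[b, t] is True iff t < lengths[b] ([t, b] if time_major)."""
    if maxlen is None:
        maxlen = int(lengths.max())
    # Compare a [1, T] row of time indices against a [B, 1] column of lengths.
    time_idx = torch.arange(maxlen, device=lengths.device).unsqueeze(0)
    mask = time_idx < lengths.unsqueeze(1)
    if time_major:
        mask = mask.t()  # [B, T] -> [T, B]
    return mask.to(dtype) if dtype is not None else mask


mask = sequence_mask_sketch(torch.tensor([1, 3, 2]))
```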

ray.rllib.utils.torch_utils.set_torch_seed(seed: Optional[int] = None) → None

Sets the torch random seed to the given value.

Parameters

seed – The seed to use or None for no seeding.

ray.rllib.utils.torch_utils.softmax_cross_entropy_with_logits(logits: Union[numpy.array, tf.Tensor, torch.Tensor], labels: Union[numpy.array, tf.Tensor, torch.Tensor]) → Union[numpy.array, tf.Tensor, torch.Tensor]

Same behavior as tf.nn.softmax_cross_entropy_with_logits.

Parameters
  • logits – The input predictions (logits).

  • labels – The labels corresponding to logits.

Returns

The resulting softmax cross-entropy given predictions and labels.
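tf.nn.softmax_cross_entropy_with_logits computes, per batch item, -sum(labels * log(softmax(logits))) over the last axis. A torch sketch of that computation (not necessarily RLlib's exact code; `softmax_xent_sketch` is a hypothetical name), using `log_softmax` for numerical stability:

```python
import torch
import torch.nn.functional as F


def softmax_xent_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-row cross-entropy between (soft) labels and softmax(logits)."""
    # log_softmax avoids the unstable two-step log(softmax(.)).
    return torch.sum(-labels * F.log_softmax(logits, dim=-1), dim=-1)


logits = torch.tensor([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
labels = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
xent = softmax_xent_sketch(logits, labels)  # one loss value per row
```

With one-hot labels this matches `torch.nn.functional.cross_entropy` on the corresponding class indices; unlike that function, the sketch also accepts soft label distributions.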

class ray.rllib.utils.torch_utils.Swish
forward(input_tensor)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
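The entry above only repeats the inherited nn.Module.forward contract. For reference, swish is commonly defined as x * sigmoid(beta * x). The module below is a sketch of that definition with a learnable beta initialized to 1.0; both the learnable beta and its initial value are assumptions, not a description of RLlib's class, and `SwishSketch` is a hypothetical name. With beta = 1 it coincides with `torch.nn.SiLU`.

```python
import torch
from torch import nn


class SwishSketch(nn.Module):
    """x * sigmoid(beta * x) with a learnable scalar beta (assumed init 1.0)."""

    def __init__(self) -> None:
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return input_tensor * torch.sigmoid(self.beta * input_tensor)


act = SwishSketch()
x = torch.linspace(-3.0, 3.0, steps=7)
y = act(x)  # Call the module (not .forward()) so registered hooks run.
```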