ray.rllib.utils.torch_utils.one_hot
- ray.rllib.utils.torch_utils.one_hot(x: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, space: gymnasium.Space) -> numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor
Returns a one-hot tensor, given an int tensor and a space.
Handles the MultiDiscrete case as well.
- Parameters:
x – The input tensor.
space – The space to use for generating the one-hot tensor.
- Returns:
The resulting one-hot tensor.
- Raises:
ValueError – If the given space is neither a Discrete nor a MultiDiscrete space.
import torch
import gymnasium as gym
from ray.rllib.utils.torch_utils import one_hot

x = torch.IntTensor([0, 3])  # batch-dim=2
# Discrete space with 4 (one-hot) slots per batch item.
s = gym.spaces.Discrete(4)
print(one_hot(x, s))

x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
# MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
# per batch item.
s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
print(one_hot(x, s))
tensor([[1, 0, 0, 0],
        [0, 0, 0, 1]])
tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
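The 20-slot layout of the MultiDiscrete example can be reproduced with plain PyTorch, which may help when reading the output above. The following is a minimal sketch, not RLlib's implementation: `multi_discrete_one_hot` is a hypothetical helper, and it assumes that each column of `x` is one-hot encoded separately and the results are concatenated along the last axis.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def multi_discrete_one_hot(x: torch.Tensor, nvec: Sequence[int]) -> torch.Tensor:
    """Hypothetical helper: one-hot each column of `x`, then concatenate."""
    # Column i of `x` holds the index for sub-space i, which has nvec[i] categories.
    # F.one_hot expects int64 indices, hence the .long() cast.
    parts = [
        F.one_hot(x[:, i].long(), num_classes=n)
        for i, n in enumerate(nvec)
    ]
    # Concatenating along the last axis yields sum(nvec) slots per batch item.
    return torch.cat(parts, dim=-1)


x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
encoded = multi_discrete_one_hot(x, [5, 4, 4, 7])
print(encoded.shape)
```

Under this assumption, the first 5 slots encode the value 0, the next 4 encode 1, the next 4 encode 2, and the final 7 encode 3, which matches the second tensor printed in the example.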