ray.rllib.utils.torch_utils.one_hot

ray.rllib.utils.torch_utils.one_hot(x: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, space: gymnasium.Space) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Returns a one-hot tensor, given an int tensor and a space.

Handles the MultiDiscrete case as well.

Parameters:
  • x – The input int tensor to one-hot encode.

  • space – The space to use for generating the one-hot tensor.

Returns:

The resulting one-hot tensor.

Raises:

ValueError – If the given space is neither a Discrete nor a MultiDiscrete space.

import torch
import gymnasium as gym
from ray.rllib.utils.torch_utils import one_hot
x = torch.IntTensor([0, 3])  # batch-dim=2
# Discrete space with 4 (one-hot) slots per batch item.
s = gym.spaces.Discrete(4)
print(one_hot(x, s))
x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
# MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
# per batch item.
s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
print(one_hot(x, s))

Output:

tensor([[1, 0, 0, 0],
        [0, 0, 0, 1]])
tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
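
The MultiDiscrete layout in the second example (5 + 4 + 4 + 7 = 20 slots) can be read as one one-hot block per component, concatenated in order. The following is a minimal, dependency-free sketch of that layout for a single batch item. The helper name `multi_discrete_one_hot` is hypothetical and not part of RLlib, and this is not necessarily how RLlib implements `one_hot` internally.

```python
from typing import List, Sequence


def multi_discrete_one_hot(indices: Sequence[int], nvec: Sequence[int]) -> List[int]:
    """One-hot encode one batch item of a MultiDiscrete-like space.

    Hypothetical helper for illustration only; not part of RLlib.
    """
    if len(indices) != len(nvec):
        raise ValueError("indices and nvec must have the same length")
    encoded: List[int] = []
    for index, n in zip(indices, nvec):
        if not 0 <= index < n:
            raise ValueError(f"index {index} out of range for a component of size {n}")
        # Each component contributes its own block of n slots,
        # with a single 1 at the position given by its index.
        block = [0] * n
        block[index] = 1
        encoded.extend(block)
    return encoded


# Same values as the MultiDiscrete([5, 4, 4, 7]) example above.
row = multi_discrete_one_hot([0, 1, 2, 3], [5, 4, 4, 7])
print(len(row))
print(row)
```

Under these assumptions, the printed row has 20 entries and matches the single row of the MultiDiscrete output shown above.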