ray.rllib.utils.tf_utils.one_hot

ray.rllib.utils.tf_utils.one_hot(x: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, space: gymnasium.Space) -> numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Returns a one-hot tensor, given an int tensor and a space.

Handles the MultiDiscrete case as well: the one-hot vectors of the individual sub-spaces are concatenated along the last axis.

Parameters:
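  • x – The input tensor of integer values, e.g. of shape (B,) for a Discrete space or (B, number of sub-spaces) for a MultiDiscrete space.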

  • space – The space that determines the number of one-hot slots (a Discrete or MultiDiscrete space).

Returns:

The resulting float32 one-hot tensor.

Raises:

ValueError – If the given space is neither a Discrete nor a MultiDiscrete space.
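To make the documented behavior concrete without TensorFlow, here is a minimal pure-Python sketch. It is not RLlib's implementation: `one_hot_reference` is a hypothetical helper, and it stands in for the space with a plain int (mimicking `Discrete(n)`) or a list of ints (mimicking `MultiDiscrete(nvec)`). It assumes, as described above, that the MultiDiscrete result is the per-sub-space one-hot vectors concatenated in order, and that any other space raises ValueError.

```python
from typing import List, Sequence, Union


def one_hot_reference(
    x: Sequence, n: Union[int, Sequence[int]]
) -> List[List[float]]:
    """Hypothetical pure-Python stand-in for `one_hot` (illustration only).

    An int `n` mimics `Discrete(n)`; a list/tuple of ints mimics
    `MultiDiscrete(nvec)`. Anything else raises ValueError.
    """

    def encode(value: int, size: int) -> List[float]:
        # One slot per possible value; only the slot at `value` is 1.0.
        row = [0.0] * size
        row[value] = 1.0
        return row

    if isinstance(n, int):
        # Discrete: each batch item is a single int -> vector of length n.
        return [encode(v, n) for v in x]
    if isinstance(n, (list, tuple)):
        # MultiDiscrete: each batch item holds one int per sub-space; the
        # per-sub-space one-hot vectors are concatenated into one row.
        out = []
        for item in x:
            row: List[float] = []
            for value, size in zip(item, n):
                row.extend(encode(value, size))
            out.append(row)
        return out
    raise ValueError(f"Unsupported space for one-hot encoding: {n!r}")


print(one_hot_reference([0, 3], 4))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
print(len(one_hot_reference([[0, 1, 2, 3]], [5, 4, 4, 7])[0]))
# 20
```

Out-of-range input values are not modeled here; the sketch only illustrates the output layout and the error path.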

import gymnasium as gym
import tensorflow as tf
from ray.rllib.utils.tf_utils import one_hot
x = tf.Variable([0, 3], dtype=tf.int32)  # batch-dim=2
# Discrete space with 4 (one-hot) slots per batch item.
s = gym.spaces.Discrete(4)
one_hot(x, s)
# -> <tf.Tensor 'one_hot:0' shape=(2, 4) dtype=float32>
x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32)  # batch-dim=1
# MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
# per batch item.
s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
one_hot(x, s)
# -> <tf.Tensor 'concat:0' shape=(1, 20) dtype=float32>
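In the MultiDiscrete example, each sub-space occupies its own block of slots in the concatenated output. The sketch below (plain Python; `hot_indices` is a hypothetical helper, not part of RLlib) computes where the 1.0 entries land, assuming the blocks are concatenated in the order of `nvec`, as the 5 + 4 + 4 + 7 layout indicates.

```python
from itertools import accumulate
from typing import List, Sequence


def hot_indices(values: Sequence[int], nvec: Sequence[int]) -> List[int]:
    """Positions of the 1.0 entries in one concatenated MultiDiscrete row."""
    # Each block starts after all previous blocks:
    # for nvec = [5, 4, 4, 7] the offsets are 0, 5, 9, 13.
    offsets = [0] + list(accumulate(nvec))[:-1]
    return [offset + value for offset, value in zip(offsets, values)]


nvec = [5, 4, 4, 7]
print(sum(nvec))                        # 20 slots in total
print(hot_indices([0, 1, 2, 3], nvec))  # [0, 6, 11, 16]
```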