ray.rllib.utils.tf_utils.one_hot
- ray.rllib.utils.tf_utils.one_hot(x: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, space: gymnasium.Space) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor
Returns a one-hot tensor, given an int tensor and a space.
Handles the MultiDiscrete case as well.
- Parameters:
x – The input tensor.
space – The space to use for generating the one-hot tensor.
- Returns:
The resulting one-hot tensor.
- Raises:
ValueError – If the given space is neither Discrete nor MultiDiscrete.
import gymnasium as gym
import tensorflow as tf
from ray.rllib.utils.tf_utils import one_hot

x = tf.Variable([0, 3], dtype=tf.int32)  # batch-dim=2
# Discrete space with 4 (one-hot) slots per batch item.
s = gym.spaces.Discrete(4)
one_hot(x, s)
<tf.Tensor 'one_hot:0' shape=(2, 4) dtype=float32>
x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32)  # batch-dim=1
# MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
# per batch item.
s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
one_hot(x, s)
<tf.Tensor 'concat:0' shape=(1, 20) dtype=float32>
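The output shapes above follow from how the one-hot slots are laid out: in the MultiDiscrete case each sub-dimension gets its own one-hot block, and the blocks are concatenated along the last axis. The following is a minimal NumPy sketch of that layout, not RLlib's implementation. The helper names one_hot_discrete and one_hot_multidiscrete are hypothetical and exist only for this illustration.

```python
import numpy as np


def one_hot_discrete(x, n):
    """One-hot encode an integer array `x` with values in [0, n).

    Hypothetical helper for illustration; not part of RLlib.
    """
    x = np.asarray(x, dtype=np.int64)
    # Indexing the identity matrix with x selects one unit row per entry.
    return np.eye(n, dtype=np.float32)[x]


def one_hot_multidiscrete(x, nvec):
    """One-hot encode each column of `x`, then concatenate the blocks.

    Hypothetical helper for illustration; not part of RLlib.
    `nvec` lists the number of categories per sub-dimension.
    """
    x = np.asarray(x, dtype=np.int64)
    parts = [one_hot_discrete(x[..., i], n) for i, n in enumerate(nvec)]
    return np.concatenate(parts, axis=-1)


# Same inputs as the TensorFlow examples above.
discrete_out = one_hot_discrete([0, 3], 4)
multi_out = one_hot_multidiscrete([[0, 1, 2, 3]], [5, 4, 4, 7])

print(discrete_out.shape)  # (2, 4)
print(multi_out.shape)     # (1, 20)
# Indices of the active slots in the single MultiDiscrete row.
print(np.flatnonzero(multi_out[0]).tolist())  # [0, 6, 11, 16]
```

The active slots fall at offsets 0, 5 + 1, 9 + 2 and 13 + 3, that is, each value shifted by the total size of the blocks before it.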