ray.rllib.utils.numpy.one_hot

ray.rllib.utils.numpy.one_hot(x: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | int, depth: int = 0, on_value: float = 1.0, off_value: float = 0.0, dtype: type = numpy.float32) → numpy.ndarray

One-hot utility function for numpy.

Thanks to qianyizhang: https://gist.github.com/qianyizhang/07ee1c15cad08afb03f5de69349efc30.

Parameters:
  • x – The input to be one-hot encoded.

  • depth – The number of classes to encode, i.e. the size of the last axis of the output.

  • on_value – The value written at the encoded index. Default: 1.0.

  • off_value – The value written at all other positions. Default: 0.0.

  • dtype – The dtype of the returned array. Default: numpy.float32.

Returns:

The one-hot encoded equivalent of the input array.
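
To illustrate how the parameters above interact, here is a minimal NumPy-only sketch. The function `one_hot_sketch` is a hypothetical stand-in written for this example. It is not RLlib's implementation. It assumes that `x` holds non-negative integer indices and that `depth=0` means the depth is inferred from the largest index in `x`.

```python
import numpy as np


def one_hot_sketch(x, depth=0, on_value=1.0, off_value=0.0, dtype=np.float32):
    """Hypothetical stand-in that mirrors the documented parameters."""
    x = np.asarray(x, dtype=np.int64)
    # Assumption: depth=0 means "infer the depth from the largest index in x".
    if depth == 0:
        depth = int(x.max()) + 1
    # Fill the output with off_value; the new last axis has size `depth`.
    out = np.full(x.shape + (depth,), off_value, dtype=dtype)
    # Write on_value at the position selected by each index along the last axis.
    np.put_along_axis(out, x[..., None], on_value, axis=-1)
    return out


encoded = one_hot_sketch(np.array([0, 2, 1]), depth=4)
print(encoded.shape)
print(encoded)
```

The output gains one trailing axis of size `depth`, so an input of shape (3,) becomes an array of shape (3, 4). Under the same assumptions, a plain int such as `one_hot_sketch(3, depth=5)` yields a 1-D array of length 5.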