ray.rllib.utils.torch_utils.sequence_mask

ray.rllib.utils.torch_utils.sequence_mask(lengths: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, maxlen: int | None = None, dtype=None, time_major: bool = False) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Offers the same behavior as tf.sequence_mask, for torch tensors.

Thanks to Dimitris Papatheodorou (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036).

Parameters:
  • lengths – The tensor of individual lengths to mask by.

  • maxlen – The maximum length to use for the time axis. If None, use the max of lengths.

  • dtype – The torch dtype to use for the resulting mask.

  • time_major – Whether to return the mask as [B, T] (False; default) or as [T, B] (True).

Returns:

The sequence mask resulting from the given input and parameters.
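
Example:

The masking semantics described above can be sketched in plain torch. This is an illustrative re-implementation, not RLlib's actual code: the name sequence_mask_sketch is hypothetical, and the boolean default dtype is an assumption made for this sketch.

```python
import torch


def sequence_mask_sketch(lengths, maxlen=None, dtype=torch.bool, time_major=False):
    """Hypothetical stand-in that mimics tf.sequence_mask for 1-D `lengths`."""
    # If maxlen is not given, fall back to the longest sequence in the batch.
    if maxlen is None:
        maxlen = int(lengths.max())
    # Time-step indices 0..maxlen-1, shape [1, T].
    steps = torch.arange(maxlen, device=lengths.device).unsqueeze(0)
    # Step t of row b is valid iff t < lengths[b]; broadcasting yields [B, T].
    mask = steps < lengths.unsqueeze(1)
    # Optionally switch to the time-major layout [T, B].
    if time_major:
        mask = mask.t()
    return mask.to(dtype)


lengths = torch.tensor([1, 3, 2])

# Batch-major mask of shape [B, T] = [3, 3].
mask = sequence_mask_sketch(lengths)

# Time-major mask padded out to maxlen=4, shape [T, B] = [4, 3].
mask_tm = sequence_mask_sketch(lengths, maxlen=4, time_major=True)

# Typical use: average a per-timestep value over valid (non-padded) steps only.
values = torch.ones(3, 3)
masked_mean = (values * mask.float()).sum() / mask.float().sum()
```

With lengths [1, 3, 2], row b of the batch-major mask is True for its first lengths[b] time steps and False for the padded remainder.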