ray.rllib.utils.torch_utils.sequence_mask
- ray.rllib.utils.torch_utils.sequence_mask(lengths: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, maxlen: int | None = None, dtype=None, time_major: bool = False) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor
Offers the same behavior as tf.sequence_mask, but for torch tensors.
Thanks to Dimitris Papatheodorou (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/39036).
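For reference, tf.sequence_mask marks position t of sequence b as valid when t < lengths[b]. The sketch below reproduces those semantics with NumPy broadcasting, including the maxlen, dtype, and time_major options listed under Parameters. It is only an illustration under that assumption: sequence_mask_np is a hypothetical helper, not part of RLlib, and it is not RLlib's torch implementation.

```python
import numpy as np


def sequence_mask_np(lengths, maxlen=None, dtype=bool, time_major=False):
    """Hypothetical NumPy illustration of tf.sequence_mask semantics."""
    lengths = np.asarray(lengths)
    # If maxlen is not given, size the time axis to the longest sequence.
    if maxlen is None:
        maxlen = int(lengths.max())
    # Broadcast [T] against [B, 1]: entry [b, t] is True iff t < lengths[b].
    mask = np.arange(maxlen) < lengths[..., None]
    # Optionally transpose the [B, T] mask to time-major [T, B].
    if time_major:
        mask = mask.T
    return mask.astype(dtype)


lengths = np.array([1, 3, 2])
print(sequence_mask_np(lengths).astype(int))
# [[1 0 0]
#  [1 1 1]
#  [1 1 0]]
```

Passing maxlen larger than the longest sequence only appends all-False columns, and time_major=True returns the transpose of the default layout.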
- Parameters:
lengths – The tensor of individual lengths to mask by.
maxlen – The maximum length to use for the time axis. If None, use the max of lengths.
dtype – The torch dtype to use for the resulting mask.
time_major – Whether to return the mask as [B, T] (False; default) or as [T, B] (True).
- Returns:
The sequence mask resulting from the given input and parameters.
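A common use for such a mask is excluding padded timesteps when averaging a per-timestep quantity, for example a loss of shape [B, T] computed on zero-padded sequences. The NumPy sketch below builds the default batch-major [B, T] mask inline by broadcasting and averages only the valid entries. masked_mean and the example values are hypothetical, chosen for illustration; this is not RLlib code.

```python
import numpy as np


def masked_mean(values, lengths):
    """Mean of values[b, t] over valid steps t < lengths[b] (illustrative)."""
    values = np.asarray(values, dtype=float)
    lengths = np.asarray(lengths)
    # [B, T] boolean mask in the default (batch-major) layout.
    mask = np.arange(values.shape[1]) < lengths[:, None]
    # Zero out padded steps, then divide by the number of valid steps.
    return float(np.sum(values * mask) / np.sum(mask))


losses = np.array([[1.0, 9.0, 9.0],   # only the first step is valid
                   [2.0, 4.0, 6.0],   # all three steps are valid
                   [3.0, 5.0, 9.0]])  # the first two steps are valid
print(masked_mean(losses, [1, 3, 2]))  # 3.5
```

Dividing by the mask sum rather than by B * T keeps the padding values (the 9.0 entries above) from biasing the average.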