ray.rllib.policy.torch_policy_v2.TorchPolicyV2.maybe_add_time_dimension

TorchPolicyV2.maybe_add_time_dimension(input_dict: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], seq_lens: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, framework: str = None)

Adds a time dimension for recurrent RLModules.

Parameters:
  • input_dict – The input dict, mapping string keys to tensors.

  • seq_lens – The sequence lengths, one entry per sequence in the batch.

  • framework – The framework to use for adding the time dimension. If None, defaults to the framework of the policy.

Returns:

The input dict, with a possibly added time dimension.
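
Example:

The following is a minimal sketch of what "adding a time dimension" typically means, not RLlib's actual implementation. It assumes a batch-major layout in which B zero-padded sequences of equal padded length T have been flattened into a leading axis of size B*T, and it uses NumPy arrays as a stand-in for torch tensors. The helper name add_time_dim_sketch is hypothetical and is not part of the RLlib API.

```python
import numpy as np


def add_time_dim_sketch(input_dict, seq_lens):
    """Reshape each [B*T, ...] array into [B, T, ...].

    Hypothetical helper for illustration only; it is not an RLlib function.
    Assumes every sequence was padded to the same length T.
    """
    batch_size = len(seq_lens)  # B: one entry per sequence
    out = {}
    for key, value in input_dict.items():
        value = np.asarray(value)
        if value.shape[0] % batch_size != 0:
            raise ValueError(
                f"Leading dim {value.shape[0]} of '{key}' is not divisible "
                f"by the number of sequences ({batch_size})."
            )
        time_size = value.shape[0] // batch_size  # T: padded sequence length
        out[key] = value.reshape((batch_size, time_size) + value.shape[1:])
    return out


# Two sequences (true lengths 3 and 2), each padded to T = 3 steps,
# with 4 features per step, stored flat as 6 rows.
flat = {"obs": np.arange(24, dtype=np.float32).reshape(6, 4)}
seq_lens = np.array([3, 2])

batched = add_time_dim_sketch(flat, seq_lens)
print(batched["obs"].shape)  # (2, 3, 4)
```

In this sketch, seq_lens is only used to determine the number of sequences B; the padded length T is inferred from the size of the leading axis. Whether the real method adds the dimension at all depends on the policy's configuration and model, which is why the return value is described as having a "possibly added" time dimension.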