ray.rllib.models.torch.torch_modelv2.TorchModelV2.forward

TorchModelV2.forward(input_dict: Dict[str, numpy.ndarray | jnp.ndarray | tf.Tensor | torch.Tensor], state: List[numpy.ndarray | jnp.ndarray | tf.Tensor | torch.Tensor], seq_lens: numpy.ndarray | jnp.ndarray | tf.Tensor | torch.Tensor)

Call the model with the given input tensors and state.

Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
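For example, with a Dict observation space, input_dict["obs"] keeps the original nested structure, while input_dict["obs_flat"] is a single concatenated tensor. The following is a hedged sketch only; the component names, shapes, and the base_model attribute are illustrative assumptions, not part of the API:

def forward(self, input_dict, state, seq_lens):
    # Nested view of an assumed Dict observation space: one tensor per component.
    pos = input_dict["obs"]["position"]   # e.g. [BATCH, 3]
    vel = input_dict["obs"]["velocity"]   # e.g. [BATCH, 3]
    # Flattened view: all components concatenated along the last axis.
    flat = input_dict["obs_flat"]         # e.g. [BATCH, 6]
    model_out = self.base_model(flat)     # assumed torch network
    return model_out, state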

This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward() creates a computation graph that operates over the variables of this model (i.e., shares weights).

Custom models should override this instead of __call__.

Parameters:
  • input_dict – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".

  • state – list of state tensors with sizes matching those returned by get_initial_state() plus the batch dimension.

  • seq_lens – 1-D tensor holding the input sequence lengths.

Returns:

A tuple consisting of the model output tensor of size [BATCH, num_outputs] and the list of new RNN state(s) if any.
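For a recurrent model, the state list carries the RNN hidden state(s) and seq_lens describes how the flat batch folds into padded sequences. The sketch below is a hedged illustration of that shape handling only; self.rnn and self.logits are assumed attributes, and RLlib's recurrent model base classes (and its add_time_dimension helper) implement this pattern in full:

import torch

def forward(self, input_dict, state, seq_lens):
    # Hedged sketch: self.rnn (an nn.GRU built with batch_first=True) and
    # self.logits (an nn.Linear) are assumed to be set up in the constructor.
    flat = input_dict["obs_flat"].float()              # [BATCH, obs_dim]
    # Fold the padded flat batch back into [B, T, obs_dim].
    max_seq_len = flat.shape[0] // seq_lens.shape[0]
    inputs = flat.reshape(-1, max_seq_len, flat.shape[-1])
    # state arrives as [B, hidden]; nn.GRU expects [num_layers, B, hidden].
    rnn_out, h = self.rnn(inputs, torch.unsqueeze(state[0], 0))
    logits = self.logits(rnn_out)                      # [B, T, num_outputs]
    # Flatten time back out to [BATCH, num_outputs] and return the new state.
    return logits.reshape(-1, self.num_outputs), [torch.squeeze(h, 0)]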

A minimal non-recurrent example (self.base_model is assumed to be built in the constructor):

import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class MyModel(TorchModelV2, nn.Module):
    # ... (the constructor builds self.base_model, a torch network whose
    # second output is the value-branch prediction)
    def forward(self, input_dict, state, seq_lens):
        # Run the observation through the base network and cache the value
        # output so that value_function() can return it after this call.
        model_out, self._value_out = self.base_model(input_dict["obs"])
        return model_out, state
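The value-branch output is cached in self._value_out so that value_function() can return it after the forward pass. RLlib invokes the model through __call__, which flattens complex observations and then delegates to forward(). A hedged usage sketch (obs_batch and the constructed model are assumed to exist; RLlib normally builds the input_dict and calls the model for you):

# Hedged usage sketch: obs_batch is an assumed torch tensor of observations.
logits, state_out = model({"obs": obs_batch}, state=[], seq_lens=None)
value_estimates = model.value_function()  # returns the cached self._value_out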