ray.rllib.models.modelv2.ModelV2.forward

ModelV2.forward(input_dict: Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]], state: List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]], seq_lens: Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor])

Call the model with the given input tensors and state.

Any complex observations (dicts, tuples, etc.) are unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, use input_dict["obs_flat"].
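To make the idea of a "flattened" observation concrete, here is a minimal sketch that collapses a nested dict/tuple observation into one flat feature list. The helper flatten_obs is hypothetical and only illustrates the concept; it does not reproduce RLlib's own preprocessing (for example, how discrete components are encoded), and visiting dict keys in sorted order is an assumption made so the output is deterministic.

```python
from typing import Any, List


def flatten_obs(obs: Any) -> List[float]:
    """Recursively flatten a nested dict/tuple/list observation.

    Hypothetical helper for illustration only; RLlib does its own
    flattening internally. Dict keys are visited in sorted order
    (an assumption for this sketch).
    """
    if isinstance(obs, dict):
        flat: List[float] = []
        for key in sorted(obs):
            flat.extend(flatten_obs(obs[key]))
        return flat
    if isinstance(obs, (tuple, list)):
        flat = []
        for item in obs:
            flat.extend(flatten_obs(item))
        return flat
    # Leaf value: treat it as a single scalar feature.
    return [float(obs)]


obs = {"position": (1.0, 2.0), "health": 0.5}
print(flatten_obs(obs))  # [0.5, 1.0, 2.0]
```

A model that only needs a fixed-size vector can read the flat form, while a model that cares about the structure can keep reading the unpacked input_dict["obs"].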

This method can be called any number of times. In eager execution, each call to forward() evaluates the model immediately. In symbolic execution, each call to forward() builds a computation graph over this model's variables, so all calls share the same weights.

Custom models should override this instead of __call__.
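The division of labor between __call__ and forward() can be sketched in plain Python. ToyModelBase and LinearModel below are hypothetical stand-ins, not RLlib classes: the base class does the preprocessing in __call__ (here just copying "obs" to "obs_flat", a simplifying assumption) and then delegates to forward(), which is the only method the custom model overrides. Repeated calls reuse the same self.weights, mirroring the weight sharing described above.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]  # row-major [BATCH, features]


class ToyModelBase:
    """Hypothetical stand-in for a ModelV2-style base class (not RLlib code)."""

    def __call__(self, input_dict: Dict, state: List, seq_lens: List[int]) -> Tuple[Matrix, List]:
        # Preprocessing happens here, not in forward(). Copying "obs" to
        # "obs_flat" is a simplifying assumption for this sketch.
        input_dict = dict(input_dict)  # do not mutate the caller's dict
        input_dict.setdefault("obs_flat", input_dict["obs"])
        return self.forward(input_dict, state, seq_lens)

    def forward(self, input_dict: Dict, state: List, seq_lens: List[int]) -> Tuple[Matrix, List]:
        raise NotImplementedError


class LinearModel(ToyModelBase):
    """Custom model: overrides forward(), not __call__."""

    def __init__(self, weights: Matrix):
        # weights[j][i] maps input feature i to output unit j; every call reuses them.
        self.weights = weights

    def forward(self, input_dict, state, seq_lens):
        out = [
            [sum(w * x for w, x in zip(row, obs)) for row in self.weights]
            for obs in input_dict["obs_flat"]
        ]
        # Feedforward model: no RNN state, so the (empty) state is passed through.
        return out, state


model = LinearModel([[1.0, 0.0], [0.5, 0.5]])
batch = {"obs": [[2.0, 4.0], [0.0, 1.0]]}
out, state = model(batch, [], [])
print(out)  # [[2.0, 3.0], [0.0, 0.5]]
```

Keeping preprocessing in __call__ means every custom model receives inputs in the same form without reimplementing that logic.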

Parameters
  • input_dict – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".

  • state – list of state tensors, with sizes matching those returned by get_initial_state() plus a leading batch dimension

  • seq_lens – 1-D tensor holding the input sequence lengths

Returns

A tuple of the model output tensor, with shape [BATCH, num_outputs], and the list of new RNN state tensors, if any.
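The state and return contract can be illustrated with a hypothetical one-unit recurrent model written in plain Python (not RLlib code). get_initial_state() returns state sizes without a batch dimension; the caller adds the batch dimension, forward() returns an output of shape [BATCH, num_outputs] (num_outputs = 1 here) plus the new state list, and that new state is fed back on the next call. The update rule h' = 0.5 * h + sum(obs) is an arbitrary choice for the sketch, and treating each batch row as a single time step (so seq_lens is ignored) is a simplifying assumption.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]  # row-major [BATCH, features]


def get_initial_state() -> List[List[float]]:
    # Hypothetical toy model: one state tensor of size [1] (no batch dim).
    return [[0.0]]


def forward(input_dict: Dict[str, Matrix], state: List[Matrix],
            seq_lens: List[int]) -> Tuple[Matrix, List[Matrix]]:
    # state[0] has shape [BATCH, 1]: the initial state size plus a batch dim.
    (hidden,) = state
    new_hidden = [
        [0.5 * h_row[0] + sum(obs_row)]
        for h_row, obs_row in zip(hidden, input_dict["obs_flat"])
    ]
    # Model output has shape [BATCH, num_outputs] with num_outputs = 1.
    model_out = [[h[0]] for h in new_hidden]
    # seq_lens is unused in this one-step sketch; a real RNN would use it
    # to handle padded time steps.
    return model_out, [new_hidden]


batch_size = 2
# Add the batch dimension to every initial state tensor.
init = [[list(s) for _ in range(batch_size)] for s in get_initial_state()]
out, new_state = forward({"obs_flat": [[1.0, 1.0], [3.0, 0.0]]}, init, [1, 1])
print(out)   # [[2.0], [3.0]]
out2, _ = forward({"obs_flat": [[0.0, 0.0], [0.0, 0.0]]}, new_state, [1, 1])
print(out2)  # [[1.0], [1.5]]
```

The second call shows why the new state must be returned: without it, the model could not carry information from one step to the next.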

Examples

>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2):
...     # ... (__init__ is assumed to create self.base_model)
...     def forward(self, input_dict, state, seq_lens):
...         model_out, self._value_out = self.base_model(
...             input_dict["obs"])
...         return model_out, state