- ModelV2.forward(input_dict: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], state: List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], seq_lens: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor)
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) are unpacked by __call__() before being passed to forward(). To access the flattened observation tensor, use input_dict["obs_flat"].
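The relationship between a nested observation and input_dict["obs_flat"] can be illustrated with a minimal sketch that flattens a single dict/tuple observation into one vector. The helper flatten_obs is hypothetical and visits dict keys in sorted order. RLlib's own preprocessors may order or encode components differently (for example, by one-hot encoding discrete components), so this only illustrates the idea.

```python
import numpy as np

def flatten_obs(obs):
    """Recursively flatten a dict/tuple observation into a 1-D float32 vector.

    Hypothetical helper: dict keys are visited in sorted order. This is an
    illustrative convention, not RLlib's exact preprocessing logic.
    """
    if isinstance(obs, dict):
        parts = [flatten_obs(obs[k]) for k in sorted(obs)]
    elif isinstance(obs, (tuple, list)):
        parts = [flatten_obs(o) for o in obs]
    else:
        # Leaf: any scalar or array becomes a flat float32 vector.
        return np.asarray(obs, dtype=np.float32).ravel()
    return np.concatenate(parts)

obs = {"position": np.array([1.0, 2.0]), "sensors": (np.array([[3.0], [4.0]]), 5.0)}
flat = flatten_obs(obs)
print(flat)  # [1. 2. 3. 4. 5.]
```

In an actual batch, obs_flat has a leading batch dimension, so each row holds one flattened observation.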
This method can be called any number of times. In eager execution, each call to forward() evaluates the model immediately. In symbolic execution, each call to forward() creates a computation graph that operates over this model's variables (i.e., all calls share the same weights).
Custom models should override this instead of __call__.
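The calling contract described above can be mimicked without any framework. Preprocessing happens in __call__, custom logic lives in forward(), and repeated calls reuse the same weights. ToyModelV2 and LinearModel below are hypothetical plain-NumPy stand-ins, not RLlib classes. The mock __call__ only derives obs_flat by reshaping, which is much simpler than what ModelV2.__call__ actually does.

```python
import numpy as np

class ToyModelV2:
    """Hypothetical, framework-free stand-in for the ModelV2 calling contract."""

    def __call__(self, input_dict, state=None, seq_lens=None):
        # Mimic the base class: derive "obs_flat" before delegating to forward().
        obs = np.asarray(input_dict["obs"], dtype=np.float32)
        input_dict = dict(input_dict, obs_flat=obs.reshape(obs.shape[0], -1))
        return self.forward(input_dict, state or [], seq_lens)

    def forward(self, input_dict, state, seq_lens):
        raise NotImplementedError

class LinearModel(ToyModelV2):
    def __init__(self, obs_dim, num_outputs, seed=0):
        rng = np.random.default_rng(seed)
        # Weights are created once; every forward() call reuses them.
        self.w = rng.normal(size=(obs_dim, num_outputs)).astype(np.float32)

    def forward(self, input_dict, state, seq_lens):
        # Override forward(), not __call__(): read the pre-flattened observations.
        return input_dict["obs_flat"] @ self.w, state

model = LinearModel(obs_dim=4, num_outputs=2)
batch = {"obs": np.ones((3, 2, 2))}
out1, _ = model(batch)
out2, _ = model(batch)  # same weights, so same result
print(out1.shape)  # (3, 2)
```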
Parameters:
input_dict – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".
state – list of state tensors whose shapes match those returned by get_initial_state(), with an added leading batch dimension.
seq_lens – 1-D tensor holding the length of each input sequence.
Returns:
A tuple consisting of the model output tensor of shape [BATCH, num_outputs] and the list of new RNN state(s), if any.
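The hypothetical ToyRNN below is a plain-NumPy Elman RNN. It shows how state, seq_lens, and the returned tuple fit together for a recurrent model. The sketch assumes that obs_flat holds B sequences, each padded to the same length T and laid out batch-major as [B * T, obs_dim]. It also assumes that padded steps produce zero output and leave the state unchanged. These layout and masking choices are illustrative assumptions, not a description of RLlib's internal padding logic.

```python
import numpy as np

class ToyRNN:
    """Hypothetical Elman RNN illustrating forward()'s state/seq_lens contract."""

    def __init__(self, obs_dim, hidden, num_outputs, seed=0):
        rng = np.random.default_rng(seed)
        self.hidden = hidden
        self.w_in = rng.normal(scale=0.1, size=(obs_dim, hidden))
        self.w_h = rng.normal(scale=0.1, size=(hidden, hidden))
        self.w_out = rng.normal(scale=0.1, size=(hidden, num_outputs))

    def get_initial_state(self):
        # One state tensor of shape [hidden]; forward() receives it as [B, hidden].
        return [np.zeros(self.hidden)]

    def forward(self, input_dict, state, seq_lens):
        h = state[0]                           # [B, hidden]
        B = len(seq_lens)
        obs = input_dict["obs_flat"]
        T = obs.shape[0] // B                  # assumes all sequences padded to length T
        x = obs.reshape(B, T, -1)              # [B, T, obs_dim], batch-major
        outs = np.zeros((B, T, self.w_out.shape[1]))
        for t in range(T):
            valid = (t < seq_lens)[:, None]    # [B, 1]; False on padded steps
            h_new = np.tanh(x[:, t] @ self.w_in + h @ self.w_h)
            h = np.where(valid, h_new, h)      # freeze the state on padded steps
            outs[:, t] = np.where(valid, h @ self.w_out, 0.0)
        # Output is [B * T, num_outputs] plus the list of new states.
        return outs.reshape(B * T, -1), [h]

model = ToyRNN(obs_dim=3, hidden=4, num_outputs=2)
B, T = 2, 5
init = [np.stack([s] * B) for s in model.get_initial_state()]  # add batch dim
seq_lens = np.array([5, 3])
obs_flat = np.random.default_rng(1).normal(size=(B * T, 3))
out, new_state = model.forward({"obs_flat": obs_flat}, init, seq_lens)
print(out.shape, new_state[0].shape)  # (10, 2) (2, 4)
```

Because the state is frozen on padded steps, the returned state for each sequence is its hidden state after its last valid step. A caller would feed that state back in on the next call.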
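from ray.rllib.models.modelv2 import ModelV2

class MyModel(ModelV2):
    # ... (constructor elided; assumed to build self.base_model, a network
    # that maps observations to a (logits, value) pair)

    def forward(self, input_dict, state, seq_lens):
        # Run the base network and cache the value output for value_function().
        model_out, self._value_out = self.base_model(input_dict["obs"])
        # This model is stateless, so the incoming state is passed through.
        return model_out, state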