ray.rllib.models.modelv2.ModelV2
class ray.rllib.models.modelv2.ModelV2(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)
Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.
Data flow:
    obs -> forward() -> model_out
           \-> value_function() -> V(s)
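The data flow above can be illustrated with a small, framework-free sketch. It is not RLlib code: `ToyModel` is a hypothetical stand-in that uses plain Python lists instead of tensors, and its "network" (observation mean for the outputs, observation sum for V(s)) is invented purely for illustration. What it mirrors is the contract: `forward()` returns the model output together with the state and caches whatever `value_function()` needs, and `value_function()` then reports V(s) for that most recent forward pass.

```python
from typing import Dict, List, Tuple


class ToyModel:
    """Hypothetical stand-in mirroring the ModelV2 data flow (not RLlib code)."""

    def __init__(self, num_outputs: int):
        self.num_outputs = num_outputs
        # Filled in by forward(); read back by value_function().
        self._last_value: List[float] = []

    def forward(
        self,
        input_dict: Dict[str, List[List[float]]],
        state: List,
        seq_lens=None,
    ) -> Tuple[List[List[float]], List]:
        obs = input_dict["obs"]
        # Invented "network": each output equals the mean of the observation.
        means = [sum(row) / len(row) for row in obs]
        model_out = [[m] * self.num_outputs for m in means]
        # Cache V(s) for this pass; here it is simply the observation sum.
        self._last_value = [sum(row) for row in obs]
        # A non-recurrent model passes the state through unchanged.
        return model_out, state

    def value_function(self) -> List[float]:
        # Only meaningful after forward() has been called.
        return self._last_value


model = ToyModel(num_outputs=2)
model_out, state_out = model.forward({"obs": [[1.0, 3.0], [2.0, 4.0]]}, state=[])
values = model.value_function()
```

In a real custom model the same pattern would live in a subclass of TFModelV2 or TorchModelV2, with tensors in place of lists.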
Methods
- __init__: Initializes a ModelV2 instance.
- context: Returns a contextmanager for the current forward pass.
- custom_loss: Override to customize the loss function used to optimize this model.
- forward: Call the model with the given input tensors and state.
- get_initial_state: Get the initial recurrent state values for the model.
- is_time_major: If True, data for calling this ModelV2 must be in time-major format.
- last_output: Returns the last output returned from calling the model.
- metrics: Override to return custom metrics from your model.
- trainable_variables: Returns the list of trainable variables for this model.
- value_function: Returns the value function output for the most recent forward pass.
- variables: Returns the list (or a dict) of variables for this model.
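The time-major requirement mentioned above concerns axis order: batch-major data is laid out as [batch, time, ...], time-major data as [time, batch, ...]. The following is a minimal sketch of that reordering using nested Python lists. The helper `to_time_major` is hypothetical and not part of RLlib, which operates on framework tensors, and the sketch assumes all sequences have the same length.

```python
from typing import List


def to_time_major(batch_major: List[List[float]]) -> List[List[float]]:
    """Transpose nested lists from [B, T] layout to [T, B] layout.

    Assumes every sequence in the batch has the same length T.
    """
    return [list(step) for step in zip(*batch_major)]


# Two sequences (B=2) of three time steps each (T=3).
batch_major = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
time_major = to_time_major(batch_major)
```

After the call, `time_major[t][b]` holds the value that was at `batch_major[b][t]`.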