ray.rllib.models.modelv2.ModelV2

class ray.rllib.models.modelv2.ModelV2(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)

Defines an abstract neural network model for use with RLlib.

Custom models should extend either TFModelV2 (TensorFlow) or TorchModelV2 (PyTorch) rather than subclassing this class directly.

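Data flow:

    obs -> forward() -> model_out
           \-> value_function() -> V(s)

value_function() returns V(s) for the batch passed to the most recent forward() call, so forward() must run first.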

Methods

__init__: Initializes a ModelV2 instance.
context: Returns a context manager for the current forward pass.
custom_loss: Override to customize the loss function used to optimize this model.
forward: Calls the model with the given input tensors and state.
get_initial_state: Returns the initial recurrent state values for the model.
import_from_h5: Imports weights from an h5 file.
is_time_major: If True, data for calling this ModelV2 must be in time-major format.
last_output: Returns the last output returned from calling the model.
metrics: Override to return custom metrics from your model.
trainable_variables: Returns the list of trainable variables for this model.
value_function: Returns the value function output for the most recent forward pass.
variables: Returns the list (or a dict) of variables for this model.