ray.rllib.models.modelv2.ModelV2
- class ray.rllib.models.modelv2.ModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)
Bases: object
Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.
- Data flow:
    obs -> forward() -> model_out
                  \-> value_function() -> V(s)
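The data flow above can be illustrated with a plain-Python sketch. This is a hypothetical toy class, not RLlib's implementation, and it does not import Ray. It mirrors the documented contract: forward(input_dict, state, seq_lens) returns the model output together with the (here unchanged) state, and value_function() returns V(s) for the most recent forward pass. The class name ToyModel, the "obs" key, and the fixed weights are assumptions chosen for illustration. A real custom model would subclass TFModelV2 or TorchModelV2 and operate on tensors.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class ToyModel:
    """Hypothetical stand-in for a ModelV2 subclass (plain Python, no RLlib)."""

    def __init__(self, obs_dim: int, num_outputs: int) -> None:
        self.num_outputs = num_outputs
        # Illustrative fixed weights: one row of obs_dim weights per output unit.
        self.policy_w = [[i + j for j in range(obs_dim)] for i in range(num_outputs)]
        # Illustrative value-head weights.
        self.value_w = [0.5] * obs_dim
        # Cache of the last observation batch, used by value_function().
        self._last_obs: Optional[List[List[float]]] = None

    def forward(
        self,
        input_dict: Dict[str, List[List[float]]],
        state: List,
        seq_lens: Optional[Sequence[int]],
    ) -> Tuple[List[List[float]], List]:
        obs = input_dict["obs"]  # assumed shape: [B, obs_dim]
        # Remember this batch so value_function() refers to this forward pass.
        self._last_obs = obs
        model_out = [
            [sum(w * x for w, x in zip(row, o)) for row in self.policy_w]
            for o in obs
        ]  # shape: [B, num_outputs]
        # Feed-forward model: the (empty) state is passed through unchanged.
        return model_out, state

    def value_function(self) -> List[float]:
        if self._last_obs is None:
            raise RuntimeError("value_function() called before forward()")
        # One scalar V(s) per batch row of the most recent forward pass.
        return [sum(w * x for w, x in zip(self.value_w, o)) for o in self._last_obs]


# Usage: call forward() first, then read V(s) for the same batch.
model = ToyModel(obs_dim=2, num_outputs=3)
out, state = model.forward({"obs": [[1.0, 2.0], [0.0, 1.0]]}, [], None)
values = model.value_function()
```

Caching the batch inside forward() is one way to satisfy "value function output for the most recent forward pass". It also explains why value_function() is only meaningful after forward() has been called.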
Methods
- __init__(obs_space, action_space, ...): Initializes a ModelV2 instance.
- context(): Returns a context manager for the current forward pass.
- custom_loss(policy_loss, loss_inputs): Override to customize the loss function used to optimize this model.
- forward(input_dict, state, seq_lens): Calls the model with the given input tensors and state.
- get_initial_state(): Gets the initial recurrent state values for the model.
- import_from_h5(h5_file): Imports weights from an h5 file.
- is_time_major(): If True, data for calling this ModelV2 must be in time-major format.
- last_output(): Returns the last output returned from calling the model.
- metrics(): Override to return custom metrics from your model.
- trainable_variables([as_dict]): Returns the list of trainable variables for this model.
- value_function(): Returns the value function output for the most recent forward pass.
- variables([as_dict]): Returns the list (or a dict) of variables for this model.
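Time-major format concerns how batches of sequences are laid out. Batch-major data has shape [B, T, ...] (batch first), while time-major data has shape [T, B, ...] (time first), so converting between the two swaps the first two axes. The sketch below uses nested Python lists in place of tensors. The helper name to_time_major is hypothetical. With real TF or Torch tensors you would transpose the first two dimensions instead.

```python
from typing import List, TypeVar

Item = TypeVar("Item")


def to_time_major(batch_major: List[List[Item]]) -> List[List[Item]]:
    """Swap the first two axes: [B, T, ...] -> [T, B, ...].

    Assumes every sequence in the batch is already padded to the same length T.
    """
    if not batch_major:
        return []
    seq_len = len(batch_major[0])
    if any(len(seq) != seq_len for seq in batch_major):
        raise ValueError("all sequences must be padded to the same length")
    # Row t of the result holds timestep t from every sequence in the batch.
    return [[seq[t] for seq in batch_major] for t in range(seq_len)]


# Two sequences (B=2), each with three timesteps (T=3).
batch = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
time_major = to_time_major(batch)
```

Here time_major is [["a0", "b0"], ["a1", "b1"], ["a2", "b2"]]. Applying the same swap again restores the batch-major layout.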