ray.rllib.models.torch.torch_modelv2.TorchModelV2
- class ray.rllib.models.torch.torch_modelv2.TorchModelV2(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: dict, name: str)
Bases: ModelV2
Torch version of ModelV2.
Note that this class by itself is not a valid model: you must subclass it together with nn.Module and implement forward() in that subclass.
Methods
Initialize a TorchModelV2.
Returns a context manager for the current forward pass.
Override to customize the loss function used to optimize this model.
Call the model with the given input tensors and state.
Get the initial recurrent state values for the model.
If True, data for calling this ModelV2 must be in time-major format.
Returns the last output returned from calling the model.
Override to return custom metrics from your model.
Returns the value function output for the most recent forward pass.