ray.rllib.models.tf.tf_modelv2.TFModelV2
- class ray.rllib.models.tf.tf_modelv2.TFModelV2(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: dict, name: str)
Bases: ModelV2
TensorFlow version of ModelV2. A subclass is expected to contain a tf.keras Model.
Note that this class by itself is not a valid model unless you implement forward() in a subclass.
Methods
- __init__: Initializes a TFModelV2 instance.
- context: Returns a contextmanager for the current TF graph.
- custom_loss: Override to customize the loss function used to optimize this model.
- forward: Call the model with the given input tensors and state.
- get_initial_state: Get the initial recurrent state values for the model.
- is_time_major: If True, data for calling this ModelV2 must be in time-major format.
- last_output: Returns the last output returned from calling the model.
- metrics: Override to return custom metrics from your model.
- register_variables: Register the given list of variables with this model.
- update_ops: Return the list of update ops for this model.
- value_function: Returns the value function output for the most recent forward pass.