Note

From Ray 2.6.0 onwards, RLlib is adopting a new stack for training and model customization, gradually replacing the ModelV2 API and some convoluted parts of the Policy API with the RLModule API.

Model APIs#

Base Model classes#

ModelV2

Defines an abstract neural network model for use with RLlib.

TorchModelV2

Torch version of ModelV2.

TFModelV2

TensorFlow version of ModelV2, which should contain a tf.keras Model.

Feed Forward methods#

forward

Call the model with the given input tensors and state.

value_function

Returns the value function output for the most recent forward pass.

last_output

Returns the last output returned from calling the model.

Recurrent Models API#

get_initial_state

Get the initial recurrent state values for the model.

is_time_major

If True, data for calling this ModelV2 must be in time-major format.

Accessing variables#

variables

Returns the list (or a dict) of variables for this model.

trainable_variables

Returns the list of trainable variables for this model.

Customization#

custom_loss

Override to customize the loss function used to optimize this model.

metrics

Override to return custom metrics from your model.