Note

Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) over the subsequent minor releases leading up to Ray 3.0.

See here for more details on how to use the new API stack.

Model APIs

Base Model classes

ModelV2

Defines an abstract neural network model for use with RLlib.

TorchModelV2

Torch version of ModelV2.

TFModelV2

TF version of ModelV2, which should contain a tf.keras Model.

Feed Forward methods

forward

Call the model with the given input tensors and state.

value_function

Returns the value function output for the most recent forward pass.

last_output

Returns the last output returned from calling the model.
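The feed-forward methods work together: calling the model runs `forward`, which returns `(outputs, state)`; `value_function` then reports the value estimate from that same pass, and `last_output` returns the outputs of the most recent call. The following plain-Python sketch mirrors only that contract and does not import RLlib; the class name `TinyLinearModel` and its weights are illustrative assumptions, and lists stand in for tensors. Following ModelV2, `__call__` wraps `forward` and records the last output.

```python
from typing import Dict, List, Optional, Tuple

Vector = List[float]


class TinyLinearModel:
    """Hypothetical stand-in mirroring the ModelV2 feed-forward contract (no RLlib import)."""

    def __init__(self, weights: List[Vector], value_weights: Vector):
        self.weights = weights              # one weight row per output unit
        self.value_weights = value_weights  # weights of a linear value head
        self._value: Optional[List[float]] = None
        self._last_output: Optional[List[Vector]] = None

    def forward(
        self, input_dict: Dict[str, List[Vector]], state: List, seq_lens
    ) -> Tuple[List[Vector], List]:
        obs = input_dict["obs"]  # batch of flat observation vectors
        # Policy outputs (e.g. action logits): one dot product per output unit.
        logits = [[sum(w * x for w, x in zip(row, o)) for row in self.weights] for o in obs]
        # Cache the value estimate so value_function() can return it afterwards.
        self._value = [sum(w * x for w, x in zip(self.value_weights, o)) for o in obs]
        return logits, state  # feed-forward: the (empty) state passes through

    def value_function(self) -> List[float]:
        if self._value is None:
            raise RuntimeError("call the model before value_function()")
        return self._value

    def __call__(self, input_dict, state=None, seq_lens=None):
        outputs, state_out = self.forward(input_dict, state or [], seq_lens)
        self._last_output = outputs  # remembered for last_output()
        return outputs, state_out

    def last_output(self) -> Optional[List[Vector]]:
        return self._last_output


model = TinyLinearModel(weights=[[1.0, 0.0], [0.0, 1.0]], value_weights=[0.5, 0.5])
logits, state = model({"obs": [[2.0, 4.0]]})
print(logits)                  # [[2.0, 4.0]]
print(model.value_function())  # [3.0]
print(model.last_output())     # [[2.0, 4.0]]
```

Caching the value inside `forward` lets a single pass serve both the policy head and the value head, which is why `value_function` takes no arguments.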

Recurrent Models API

get_initial_state

Get the initial recurrent state values for the model.

is_time_major

If True, data for calling this ModelV2 must be in time-major format.
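A recurrent model supplies its starting state through `get_initial_state` and threads the updated state through every call to `forward`; `is_time_major` tells the caller which layout to feed, `[T, B, ...]` when it returns True and `[B, T, ...]` otherwise. The hypothetical sketch below uses a one-unit leaky accumulator, h_t = decay * h_(t-1) + x_t, and a simplified signature (plain nested lists in place of ModelV2's `input_dict` and `seq_lens`); `LeakySumRNN` is not an RLlib class.

```python
from typing import List, Tuple

Batch = List[List[float]]  # nested lists of scalar inputs: [B, T] or [T, B]


class LeakySumRNN:
    """Hypothetical one-unit recurrent model: h_t = decay * h_(t-1) + x_t."""

    def __init__(self, decay: float = 0.5, time_major: bool = False):
        self.decay = decay
        self.time_major = time_major

    def get_initial_state(self) -> List[float]:
        # State for a single sequence: one hidden unit, initialised to zero.
        return [0.0]

    def is_time_major(self) -> bool:
        return self.time_major

    def forward(
        self, inputs: Batch, state: List[List[float]]
    ) -> Tuple[Batch, List[List[float]]]:
        # Normalise to batch-major [B, T] so the loop below ignores layout.
        seqs = [list(s) for s in zip(*inputs)] if self.time_major else inputs
        h = list(state[0])  # one hidden value per sequence in the batch
        outputs = []
        for b, seq in enumerate(seqs):
            outs = []
            for x in seq:
                h[b] = self.decay * h[b] + x
                outs.append(h[b])
            outputs.append(outs)
        if self.time_major:
            outputs = [list(t) for t in zip(*outputs)]  # back to [T, B]
        return outputs, [h]


model = LeakySumRNN(decay=0.5)
batch = [[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]]              # batch-major: B=2, T=3
init = [[model.get_initial_state()[0] for _ in batch]]  # stack per-sequence state
out, state = model.forward(batch, init)
print(out)    # [[1.0, 1.5, 1.75], [2.0, 1.0, 0.5]]
print(state)  # [[1.75, 0.5]]
```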

Accessing variables

variables

Returns the list (or a dict) of variables for this model.

trainable_variables

Returns the list of trainable variables for this model.
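`variables` and `trainable_variables` both take an `as_dict` flag that switches between a plain list and a name-to-variable mapping; the trainable subset leaves out anything that is frozen. The sketch below imitates that behavior with plain lists standing in for tensors; `ParamStore`, the parameter names, and the frozen set are hypothetical.

```python
from typing import Dict, List, Union


class ParamStore:
    """Hypothetical stand-in for ModelV2.variables() / trainable_variables()."""

    def __init__(self):
        # Ordered name -> value mapping; the set marks which entries are frozen.
        self._params: Dict[str, List[float]] = {
            "encoder/kernel": [0.1, 0.2],
            "encoder/bias": [0.0],
            "value_head/kernel": [0.3],
        }
        self._frozen = {"encoder/kernel", "encoder/bias"}

    def variables(self, as_dict: bool = False) -> Union[List, Dict]:
        # Every variable, frozen or not.
        return dict(self._params) if as_dict else list(self._params.values())

    def trainable_variables(self, as_dict: bool = False) -> Union[List, Dict]:
        # Only the variables an optimizer is allowed to update.
        trainable = {k: v for k, v in self._params.items() if k not in self._frozen}
        return trainable if as_dict else list(trainable.values())


store = ParamStore()
print(len(store.variables()))                         # 3
print(list(store.trainable_variables(as_dict=True)))  # ['value_head/kernel']
```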

Distribution

The base class for a distribution over a random variable.
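A distribution object pairs a parameterization, typically the model's output logits, with operations that policy-gradient losses consume: sampling, log-probability, entropy, and KL divergence. The sketch below is a hypothetical pure-Python categorical distribution; its method names (`sample`, `logp`, `entropy`, `kl`) are chosen to resemble common RLlib usage, but it is not RLlib's `Distribution` class and does not reproduce its exact signatures.

```python
import math
import random
from typing import List, Optional


class CategoricalSketch:
    """Hypothetical categorical distribution built from logits (not RLlib's class)."""

    def __init__(self, logits: List[float]):
        m = max(logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(l - m) for l in logits]
        total = sum(exps)
        self.probs = [e / total for e in exps]

    def sample(self, rng: Optional[random.Random] = None) -> int:
        rng = rng or random.Random()
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

    def logp(self, value: int) -> float:
        return math.log(self.probs[value])

    def entropy(self) -> float:
        return -sum(p * math.log(p) for p in self.probs if p > 0)

    def kl(self, other: "CategoricalSketch") -> float:
        # KL(self || other) = sum_i p_i * (log p_i - log q_i)
        return sum(
            p * (math.log(p) - math.log(q))
            for p, q in zip(self.probs, other.probs)
            if p > 0
        )


uniform = CategoricalSketch([0.0, 0.0])
print(round(uniform.entropy(), 4))  # 0.6931 (= ln 2)
print(round(uniform.logp(1), 4))    # -0.6931
print(uniform.kl(uniform))          # 0.0
```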

Customization

custom_loss

Override to customize the loss function used to optimize this model.

metrics

Override to return custom metrics from your model.
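`custom_loss` receives the policy's loss (a single loss or a list of losses) together with the rollout `loss_inputs` and returns the loss to optimize in its place, while `metrics` returns a dict of statistics to report during training. A common pairing is adding an auxiliary term such as weight regularization and logging its size. The sketch below does this with plain floats; `RegularizedModelSketch`, `l2_coeff`, and the stored `_l2_loss` are hypothetical names, and a real model would operate on framework tensors.

```python
from typing import Dict, List, Union


class RegularizedModelSketch:
    """Hypothetical model adding an L2 penalty in custom_loss() and reporting it in metrics()."""

    def __init__(self, weights: List[float], l2_coeff: float = 0.01):
        self.weights = weights
        self.l2_coeff = l2_coeff
        self._l2_loss = 0.0

    def custom_loss(
        self, policy_loss: Union[float, List[float]], loss_inputs: Dict
    ) -> List[float]:
        # L2 penalty on the weights, added to every policy loss.
        self._l2_loss = self.l2_coeff * sum(w * w for w in self.weights)
        losses = policy_loss if isinstance(policy_loss, list) else [policy_loss]
        return [loss + self._l2_loss for loss in losses]

    def metrics(self) -> Dict[str, float]:
        # Report the penalty computed by the most recent custom_loss() call.
        return {"l2_loss": self._l2_loss}


model = RegularizedModelSketch(weights=[1.0, 2.0], l2_coeff=0.1)
print(model.custom_loss(1.0, loss_inputs={}))  # [1.5]
print(model.metrics())                         # {'l2_loss': 0.5}
```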