Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.
See the RLlib documentation on the new API stack for more details on how to use it.
Model APIs
Base Model classes
Defines an abstract neural network model for use with RLlib.
Torch version of ModelV2.
TF version of ModelV2, which should contain a tf.keras Model.
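The classes above pair one framework-agnostic base class with a subclass per deep-learning framework. The following is a minimal, framework-free sketch of that pattern: `ToyModelV2` and `ToyLinearModel` are hypothetical stand-ins that use Python's `abc` module to enforce the interface, and they assume nothing about how RLlib's own classes are implemented. Plain nested lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from typing import List, Tuple


class ToyModelV2(ABC):
    """Hypothetical framework-agnostic base class (not RLlib's ModelV2)."""

    def __init__(self, num_outputs: int, name: str) -> None:
        self.num_outputs = num_outputs
        self.name = name

    @abstractmethod
    def forward(self, input_dict: dict, state: List, seq_lens) -> Tuple[List[List[float]], List]:
        """Map a batch of observations to outputs and the next state."""

    @abstractmethod
    def value_function(self) -> List[float]:
        """Return value estimates for the most recent forward pass."""


class ToyLinearModel(ToyModelV2):
    """Hypothetical 'framework' subclass: a fixed linear layer in pure Python."""

    def __init__(self, num_outputs: int, name: str, weights: List[List[float]]) -> None:
        super().__init__(num_outputs, name)
        self.weights = weights  # shape [num_outputs][obs_dim]
        self._values: List[float] = []

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]  # shape [batch][obs_dim]
        outputs = [
            [sum(w * x for w, x in zip(row, o)) for row in self.weights]
            for o in obs
        ]
        # Toy value head: the mean of each output row, cached for value_function().
        self._values = [sum(out) / len(out) for out in outputs]
        return outputs, state

    def value_function(self):
        return self._values


model = ToyLinearModel(2, "linear", weights=[[1.0, 0.0], [0.0, 2.0]])
outputs, state = model.forward({"obs": [[1.0, 3.0]]}, [], None)
print(outputs)                 # [[1.0, 6.0]]
print(model.value_function())  # [3.5]
```

Because `forward()` and `value_function()` are abstract in this sketch, instantiating `ToyModelV2` directly raises `TypeError`; a subclass must implement both before it can be built.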
Feed Forward methods
Call the model with the given input tensors and state.
Returns the value function output for the most recent forward pass.
Returns the last output returned from calling the model.
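These three methods are linked by call order: calling the model runs the forward pass once, and the value-function and last-output accessors then report results from that most recent pass rather than computing anything new. A pure-Python sketch of that ordering, assuming a hypothetical `ToyPolicyModel` whose policy head doubles each feature and whose value head sums them; neither head reflects how RLlib computes outputs or values.

```python
from typing import Dict, List, Optional, Tuple

Batch = List[List[float]]


class ToyPolicyModel:
    """Hypothetical model illustrating the call -> forward -> value_function ordering."""

    def __init__(self) -> None:
        self._last_output: Optional[Batch] = None
        self._value_out: Optional[List[float]] = None

    def forward(self, input_dict: Dict[str, Batch], state: List, seq_lens) -> Tuple[Batch, List]:
        obs = input_dict["obs"]
        # Toy policy head: double every feature.
        logits = [[2.0 * x for x in row] for row in obs]
        # Cache the value head's output so value_function() can return it later.
        self._value_out = [sum(row) for row in obs]
        return logits, state

    def __call__(self, input_dict, state=None, seq_lens=None):
        # Calling the model delegates to forward() and remembers the result.
        outputs, new_state = self.forward(input_dict, state or [], seq_lens)
        self._last_output = outputs
        return outputs, new_state

    def value_function(self) -> List[float]:
        if self._value_out is None:
            raise RuntimeError("call the model before value_function()")
        return self._value_out

    def last_output(self) -> Optional[Batch]:
        return self._last_output


model = ToyPolicyModel()
logits, _ = model({"obs": [[1.0, 2.0], [3.0, 4.0]]})
print(logits)                  # [[2.0, 4.0], [6.0, 8.0]]
print(model.value_function())  # [3.0, 7.0]
```

Caching the value output during the forward pass means the value head shares that pass's computation instead of re-running the network on the same inputs.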
Recurrent Models API
Get the initial recurrent state values for the model.
If True, data for calling this ModelV2 must be in time-major format.
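For recurrent models, the initial state seeds the first forward pass, and each pass returns an updated state to carry into the next one. Time-major data is laid out as `[time, batch, ...]` rather than the batch-major `[batch, time, ...]`. The sketch below is a hypothetical pure-Python "RNN" whose hidden state is a running sum; `ToyRunningSumRNN`, the `transpose` helper, and the simplified `forward(inputs, state)` signature are illustrative assumptions, not RLlib's interface.

```python
from typing import List, Tuple

Batch2D = List[List[float]]


def transpose(rows: Batch2D) -> Batch2D:
    """Swap the first two axes: [batch][time] <-> [time][batch]."""
    return [list(col) for col in zip(*rows)]


class ToyRunningSumRNN:
    """Hypothetical recurrent model whose hidden state is a running sum."""

    def __init__(self, time_major: bool = False) -> None:
        self.time_major = time_major

    def is_time_major(self) -> bool:
        return self.time_major

    def get_initial_state(self) -> List[float]:
        # One entry per state component; this toy has a single scalar component.
        return [0.0]

    def forward(self, inputs: Batch2D, state: List[float]) -> Tuple[Batch2D, List[float]]:
        """inputs: [B][T], or [T][B] if time-major; state: one hidden value per sequence."""
        seqs = transpose(inputs) if self.time_major else inputs
        outputs, new_state = [], []
        for seq, h in zip(seqs, state):
            steps = []
            for x in seq:
                h = h + x  # toy recurrence: h_t = h_{t-1} + x_t
                steps.append(h)
            outputs.append(steps)
            new_state.append(h)
        # Return outputs in the same layout the inputs arrived in.
        return (transpose(outputs) if self.time_major else outputs), new_state


batch_major = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]  # 2 sequences x 3 time steps
model = ToyRunningSumRNN(time_major=True)
# Tile the initial state once per sequence (done by the caller in this sketch).
state = model.get_initial_state() * len(batch_major)
out, state = model.forward(transpose(batch_major), state)
print(out)    # [[1.0, 10.0], [3.0, 30.0], [6.0, 60.0]]
print(state)  # [6.0, 60.0]
```

Passing the returned `state` into the next `forward()` call continues each sequence's running sum, which is the behavior a real recurrent model relies on across rollout fragments.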
Accessing variables
Returns the list (or a dict) of variables for this model.
Returns the list of trainable variables for this model.
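The two accessors differ in whether non-trainable parameters (such as frozen normalization statistics) are included, and either can return a flat list or a name-keyed dict. A pure-Python sketch using a hypothetical `ToyParamStore`; the `as_dict` keyword mirrors the flag on the ModelV2 methods, while the parameter names and storage scheme are assumptions made for illustration.

```python
from typing import Dict, List, Tuple, Union

Param = List[float]


class ToyParamStore:
    """Hypothetical model that records each parameter and whether it is trainable."""

    def __init__(self) -> None:
        # name -> (value, trainable)
        self._params: Dict[str, Tuple[Param, bool]] = {
            "fc1.weight": ([0.5, -0.2], True),
            "fc1.bias": ([0.1], True),
            "obs_norm.mean": ([0.0, 0.0], False),  # frozen running statistics
        }

    def _select(self, trainable_only: bool, as_dict: bool) -> Union[List[Param], Dict[str, Param]]:
        selected = {
            name: value
            for name, (value, trainable) in self._params.items()
            if trainable or not trainable_only
        }
        return selected if as_dict else list(selected.values())

    def variables(self, as_dict: bool = False):
        # Every parameter, trainable or not.
        return self._select(trainable_only=False, as_dict=as_dict)

    def trainable_variables(self, as_dict: bool = False):
        # Only parameters an optimizer should update.
        return self._select(trainable_only=True, as_dict=as_dict)


store = ToyParamStore()
print(len(store.variables()))                          # 3
print(list(store.trainable_variables(as_dict=True)))   # ['fc1.weight', 'fc1.bias']
```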
The base class for distributions over a random variable.
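A distribution class turns a model's raw outputs, such as action logits, into an object that can sample actions and compute their log-probabilities. Below is a minimal pure-Python categorical distribution; `ToyCategorical` and its `sample`/`logp`/`entropy` method set are assumptions chosen for illustration, so check the RLlib API reference for the actual constructors and methods of its distribution classes.

```python
import math
import random
from typing import List, Optional


class ToyCategorical:
    """Hypothetical categorical distribution parameterized by logits."""

    def __init__(self, logits: List[float]) -> None:
        # Numerically stable softmax: subtract the max logit before exponentiating.
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        self.probs = [e / total for e in exps]

    def sample(self, rng: Optional[random.Random] = None) -> int:
        rng = rng or random.Random()
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

    def logp(self, action: int) -> float:
        return math.log(self.probs[action])

    def entropy(self) -> float:
        return -sum(p * math.log(p) for p in self.probs if p > 0.0)


dist = ToyCategorical([0.0, 0.0])  # two equally likely actions
print(round(dist.logp(0), 4))      # -0.6931
print(round(dist.entropy(), 4))    # 0.6931
action = dist.sample(random.Random(0))  # either 0 or 1
```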
Customization
Override to customize the loss function used to optimize this model.
Override to return custom metrics from your model.
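The loss hook receives the loss the policy computed and returns a (possibly modified) replacement, which makes it a place to add auxiliary terms; the metrics hook returns extra named values for logging. The sketch below adds a hypothetical L2 weight penalty in pure Python, with plain floats standing in for tensors. The `policy_loss` and `loss_inputs` parameter names are assumed to follow ModelV2's `custom_loss()` signature, while `ToyModelWithAuxLoss`, the penalty, and the `l2_loss` metric name are assumptions.

```python
from typing import Dict, List


class ToyModelWithAuxLoss:
    """Hypothetical model that adds an L2 penalty to every policy loss term."""

    def __init__(self, weights: List[float], l2_coeff: float = 0.01) -> None:
        self.weights = weights
        self.l2_coeff = l2_coeff
        self._last_l2 = 0.0

    def custom_loss(self, policy_loss: List[float], loss_inputs: Dict) -> List[float]:
        # Auxiliary term computed from the model's own parameters.
        self._last_l2 = self.l2_coeff * sum(w * w for w in self.weights)
        # Return one adjusted loss per original term.
        return [loss + self._last_l2 for loss in policy_loss]

    def metrics(self) -> Dict[str, float]:
        # Report the most recent auxiliary loss so it can be logged with other stats.
        return {"l2_loss": self._last_l2}


model = ToyModelWithAuxLoss(weights=[1.0, 2.0], l2_coeff=0.5)
print(model.custom_loss([1.0], loss_inputs={}))  # [3.5]
print(model.metrics())                           # {'l2_loss': 2.5}
```

Returning a list keeps one adjusted term per original loss, which suits policies that use several optimizers; a single-loss policy passes and receives a one-element list in this sketch.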