Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) over the subsequent minor releases leading up to Ray 3.0.
Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack”, and both still run on the old APIs by default. You can continue to use your existing custom (old stack) classes.
See here for more details on how to use the new API stack.
Model APIs#
Base Model classes#
| ModelV2 | Defines an abstract neural network model for use with RLlib. |
| TorchModelV2 | Torch version of ModelV2. |
| TFModelV2 | TF version of ModelV2, which should contain a tf keras Model. |
Feed Forward methods#
| ModelV2.forward | Call the model with the given input tensors and state. |
| ModelV2.value_function | Returns the value function output for the most recent forward pass. |
| ModelV2.last_output | Returns the last output returned from calling the model. |
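The feed-forward methods above work together: a forward call returns the outputs plus the (possibly updated) state and caches its results, and the value function and last-output accessors then read from that cache. The following dependency-free sketch mimics that calling pattern using the `forward(input_dict, state, seq_lens)` signature. `TinyModel` and its toy arithmetic are hypothetical stand-ins, not RLlib classes, and plain lists take the place of tensors.

```python
class TinyModel:
    """Hypothetical stand-in that mimics the forward/value-function pattern.

    Not an RLlib class: it works on plain Python lists instead of tensors.
    """

    def __init__(self, num_outputs):
        self.num_outputs = num_outputs
        self._last_output = None
        self._last_value = None

    def forward(self, input_dict, state, seq_lens):
        obs_batch = input_dict["obs"]
        # Toy "network": every output is the sum of the observation features.
        outputs = [[sum(obs)] * self.num_outputs for obs in obs_batch]
        # Cache results so the accessor methods can read them later.
        self._last_output = outputs
        self._last_value = [sum(obs) / len(obs) for obs in obs_batch]
        # A feed-forward model passes the state through unchanged.
        return outputs, state

    def value_function(self):
        # Only meaningful after a forward pass has populated the cache.
        if self._last_value is None:
            raise RuntimeError("call forward() first")
        return self._last_value

    def last_output(self):
        return self._last_output


model = TinyModel(num_outputs=2)
batch = {"obs": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}
logits, state_out = model.forward(batch, state=[], seq_lens=None)
values = model.value_function()  # one value per batch item
```

Because the value function reads cached results, calling it before any forward pass has no well-defined answer; the sketch raises an error in that case, which is a design choice of this example rather than documented RLlib behavior.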
Recurrent Models API#
| ModelV2.get_initial_state | Get the initial recurrent state values for the model. |
| ModelV2.is_time_major | If True, data for calling this ModelV2 must be in time-major format. |
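Time-major format puts the time axis first, so sequence data is laid out as [T, B, ...] instead of the batch-major [B, T, ...]. As a minimal illustration of the difference, the sketch below transposes a nested list from one layout to the other. `to_time_major` is a hypothetical helper written for this example, not part of RLlib.

```python
def to_time_major(batch_major):
    """Transpose a [B][T] nested list into [T][B]."""
    return [list(step) for step in zip(*batch_major)]


# Two sequences (B=2) with three time steps each (T=3).
batch_major = [
    ["a0", "a1", "a2"],
    ["b0", "b1", "b2"],
]
time_major = to_time_major(batch_major)
# Each row of time_major now holds one time step across all sequences.
```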
Accessing variables#
| ModelV2.variables | Returns the list (or a dict) of variables for this model. |
| ModelV2.trainable_variables | Returns the list of trainable variables for this model. |
| Distribution | The base class for distribution over a random variable. |
Customization#
| ModelV2.custom_loss | Override to customize the loss function used to optimize this model. |
| ModelV2.metrics | Override to return custom metrics from your model. |
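The two customization hooks are typically overridden together: the loss hook receives the policy loss and returns a modified one, and the metrics hook reports whatever the model wants logged. The sketch below shows the pattern with an L2 weight penalty. `BaseModel` and `L2RegularizedModel` are hypothetical stand-ins that use plain floats, and the `custom_loss(policy_loss, loss_inputs)` and `metrics()` signatures are assumed here to mirror the ModelV2 hooks.

```python
class BaseModel:
    """Hypothetical stand-in for a model base class (not RLlib's ModelV2)."""

    def custom_loss(self, policy_loss, loss_inputs):
        # Default behavior: leave the policy loss untouched.
        return policy_loss

    def metrics(self):
        # Default behavior: report nothing extra.
        return {}


class L2RegularizedModel(BaseModel):
    """Adds an L2 weight penalty to the loss and reports it as a metric."""

    def __init__(self, weights, l2_coeff=0.01):
        self.weights = weights
        self.l2_coeff = l2_coeff
        self._l2_penalty = 0.0

    def custom_loss(self, policy_loss, loss_inputs):
        # Penalty = coefficient * sum of squared weights.
        self._l2_penalty = self.l2_coeff * sum(w * w for w in self.weights)
        return policy_loss + self._l2_penalty

    def metrics(self):
        # Expose the most recently computed penalty for logging.
        return {"l2_penalty": self._l2_penalty}


model = L2RegularizedModel(weights=[1.0, 2.0, 2.0], l2_coeff=0.5)
total_loss = model.custom_loss(policy_loss=1.0, loss_inputs={})
```

Storing the penalty on the instance lets the metrics hook report the same quantity that was added to the loss without recomputing it.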