Model APIs
ModelV2 API (rllib.models.modelv2.ModelV2)
All RLlib neural network models have to be provided as ModelV2 sub-classes.
- class ray.rllib.models.modelv2.ModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)
Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.
- Data flow:
- obs -> forward() -> model_out
-> value_function() -> V(s)
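The data flow above can be sketched without any deep learning framework. The `ToyModel` class below is a hypothetical stand-in, not an RLlib class: it only mirrors the method names and the order of calls, with plain Python lists in place of tensors and made-up arithmetic in place of real network layers.

```python
from typing import Dict, List, Tuple

Batch = List[List[float]]


class ToyModel:
    """Framework-free stand-in that mirrors the ModelV2 data flow."""

    def __init__(self, num_outputs: int) -> None:
        self.num_outputs = num_outputs
        self._value_out: List[float] = []

    def forward(
        self, input_dict: Dict[str, Batch], state: list, seq_lens=None
    ) -> Tuple[Batch, list]:
        obs = input_dict["obs"]
        # Toy "policy head": repeat each row's mean num_outputs times,
        # giving an output of shape [BATCH, num_outputs].
        model_out = [[sum(row) / len(row)] * self.num_outputs for row in obs]
        # Toy "value head": cache one scalar per row for value_function().
        self._value_out = [sum(row) for row in obs]
        return model_out, state

    def value_function(self) -> List[float]:
        # Reads the cached result; no second pass over the observations.
        return self._value_out


model = ToyModel(num_outputs=2)
model_out, state_out = model.forward({"obs": [[1.0, 3.0], [2.0, 4.0]]}, state=[])
values = model.value_function()
```

In a real model, `forward()` would run the network layers and `value_function()` would return the cached output of a value branch computed during that same pass.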
- get_initial_state() -> List[Union[numpy.array, tf.Tensor, torch.Tensor]]
Get the initial recurrent state values for the model.
- Returns
List of np.array (for tf) or Tensor (for torch) objects containing the initial hidden state of an RNN, if applicable.
Examples
>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2):
...     # ...
...     def get_initial_state(self):
...         return [
...             np.zeros(self.cell_size, np.float32),
...             np.zeros(self.cell_size, np.float32),
...         ]
- forward(input_dict: Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]], state: List[Union[numpy.array, tf.Tensor, torch.Tensor]], seq_lens: Union[numpy.array, tf.Tensor, torch.Tensor]) -> Tuple[Union[numpy.array, tf.Tensor, torch.Tensor], List[Union[numpy.array, tf.Tensor, torch.Tensor]]]
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() eagerly evaluates the model. In symbolic execution, each call to forward() creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
- Parameters
input_dict – Dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".
state – List of state tensors with sizes matching those returned by get_initial_state(), plus the batch dimension.
seq_lens – 1d tensor holding input sequence lengths
- Returns
A tuple consisting of the model output tensor of size [BATCH, num_outputs] and the list of new RNN state(s) if any.
Examples
>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2):
...     # ...
...     def forward(self, input_dict, state, seq_lens):
...         model_out, self._value_out = self.base_model(
...             input_dict["obs"])
...         return model_out, state
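The relationship between get_initial_state() and the state argument of forward() can be illustrated with a small, framework-free sketch. Everything here is hypothetical (the `CELL_SIZE` constant, the `add_batch_dim` helper, the toy recurrence); the only point is that the initial state is unbatched, while the state passed to and returned from forward() carries a batch dimension.

```python
from typing import List

CELL_SIZE = 3  # hypothetical hidden size


def get_initial_state() -> List[List[float]]:
    # Two unbatched state vectors (e.g. an LSTM's h and c), each [CELL_SIZE].
    return [[0.0] * CELL_SIZE, [0.0] * CELL_SIZE]


def add_batch_dim(state: List[List[float]], batch_size: int):
    # Hypothetical helper: tile each vector to shape [batch_size, CELL_SIZE].
    return [[list(vec) for _ in range(batch_size)] for vec in state]


def forward(input_dict, state, seq_lens):
    h, c = state
    obs = input_dict["obs"]
    # Toy recurrence: add each row's observation sum to its hidden units.
    new_h = [[x + sum(o) for x in row] for row, o in zip(h, obs)]
    # Pretend num_outputs == 2 and read the output off the hidden state.
    model_out = [row[:2] for row in new_h]
    # The new state list has the same structure and shapes as the input state.
    return model_out, [new_h, c]


batched_state = add_batch_dim(get_initial_state(), batch_size=2)
model_out, state_out = forward(
    {"obs": [[1.0], [2.0]]}, batched_state, seq_lens=[1, 1]
)
```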
- value_function() -> Union[numpy.array, tf.Tensor, torch.Tensor]
Returns the value function output for the most recent forward pass.
Note that a forward() call has to be performed first, before this method can return anything. Calling this method therefore does not cause an extra forward pass through the network.
- Returns
Value estimate tensor of shape [BATCH].
- custom_loss(policy_loss: Union[numpy.array, tf.Tensor, torch.Tensor], loss_inputs: Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], numpy.array, tf.Tensor, torch.Tensor]
Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).
You can find a runnable example in examples/custom_loss.py.
- Parameters
policy_loss – List of or single policy loss(es) from the policy.
loss_inputs – map of input placeholders for rollout data.
- Returns
List of or scalar tensor for the customized loss(es) for this model.
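As a rough, framework-free sketch of the idea, the function below adds a made-up auxiliary term to the policy loss and handles both the single-loss and list-of-losses cases. The `aux_weight` parameter and the auxiliary term itself are assumptions for illustration only; a real override would compute the extra loss from tensors of the chosen framework.

```python
from typing import Dict, List, Union

Loss = Union[float, List[float]]


def custom_loss(
    policy_loss: Loss,
    loss_inputs: Dict[str, List[List[float]]],
    aux_weight: float = 0.5,  # hypothetical weighting of the extra term
) -> Loss:
    obs = loss_inputs["obs"]
    # Hypothetical self-supervised term: mean squared observation value.
    count = sum(len(row) for row in obs)
    aux = sum(x * x for row in obs for x in row) / count
    if isinstance(policy_loss, list):
        # One policy loss per optimizer: add the extra term to each.
        return [loss + aux_weight * aux for loss in policy_loss]
    return policy_loss + aux_weight * aux


batch = {"obs": [[1.0, 2.0], [3.0, 4.0]]}
single = custom_loss(1.0, batch)
multiple = custom_loss([1.0, 2.0], batch)
```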
- metrics() -> Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]
Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e., info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1
- Returns
The custom metrics for this model.
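The reporting path described above can be pictured as a nested dictionary. The snippet below is only a mock of that nesting: the `report` helper is hypothetical and the metric value is invented. It does not reproduce how RLlib actually assembles its result dict.

```python
from typing import Dict


def metrics() -> Dict[str, float]:
    # Hypothetical custom metric computed by the model.
    return {"key1": 0.25}


def report(policy_id: str, model_metrics: Dict[str, float]) -> dict:
    # Mock of the nesting: info -> learner -> policy_id -> model -> key.
    return {"info": {"learner": {policy_id: {"model": model_metrics}}}}


result = report("default_policy", metrics())
key1 = result["info"]["learner"]["default_policy"]["model"]["key1"]
```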
- import_from_h5(h5_file: str) -> None
Imports weights from an h5 file.
- Parameters
h5_file – The h5 file name to import weights from.
Example
>>> from ray.rllib.algorithms.ppo import PPO
>>> trainer = PPO(...)
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
...     trainer.train()
- last_output() -> Union[numpy.array, tf.Tensor, torch.Tensor]
Returns the last output returned from calling the model.
- context() -> contextlib.AbstractContextManager
Returns a context manager for the current forward pass.
- variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
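The difference between the list and dict return forms, and between all variables and trainable ones, can be sketched with a small stand-in. `ToyVariableStore` and its variable names are hypothetical and not part of RLlib; plain lists of floats stand in for framework variables.

```python
from typing import Dict, List, Tuple, Union

Var = List[float]


class ToyVariableStore:
    """Stand-in for the variables()/trainable_variables() contract."""

    def __init__(self) -> None:
        # Hypothetical variable names mapped to (values, trainable?) pairs.
        self._vars: Dict[str, Tuple[Var, bool]] = {
            "fc/kernel": ([0.1, 0.2], True),
            "fc/bias": ([0.0], True),
            "bn/moving_mean": ([0.5], False),
        }

    def variables(self, as_dict: bool = False) -> Union[List[Var], Dict[str, Var]]:
        if as_dict:
            # Descriptive str keys mapped to the variables themselves.
            return {name: values for name, (values, _) in self._vars.items()}
        return [values for values, _ in self._vars.values()]

    def trainable_variables(
        self, as_dict: bool = False
    ) -> Union[List[Var], Dict[str, Var]]:
        # Keep only the variables flagged as trainable.
        trainable = {n: v for n, (v, flag) in self._vars.items() if flag}
        return trainable if as_dict else list(trainable.values())


store = ToyVariableStore()
all_vars = store.variables()
trainable_by_name = store.trainable_variables(as_dict=True)
```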
RLlib comes with two sub-classes for TF (keras) models and PyTorch models:
TFModelV2 (rllib.models.tf.tf_modelv2.TFModelV2)
- class ray.rllib.models.tf.tf_modelv2.TFModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)
TF version of ModelV2, which should contain a tf keras Model.
Note that this class by itself is not a valid model unless you implement forward() in a subclass.
- context() -> contextlib.AbstractContextManager
Returns a context manager for the current TF graph.
- update_ops() -> List[Union[numpy.array, tf.Tensor, torch.Tensor]]
Return the list of update ops for this model.
For example, this should include any BatchNorm update ops.
- register_variables(variables: List[Union[numpy.array, tf.Tensor, torch.Tensor]]) -> None
Register the given list of variables with this model.
- variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
TorchModelV2 (rllib.models.torch.torch_modelv2.TorchModelV2)
- class ray.rllib.models.torch.torch_modelv2.TorchModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)
Torch version of ModelV2.
Note that this class by itself is not a valid model unless you inherit from nn.Module and implement forward() in a subclass.
- variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Union[numpy.array, tf.Tensor, torch.Tensor]], Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.