Model APIs
ModelV2 API (rllib.models.modelv2.ModelV2)
All RLlib neural network models have to be provided as ModelV2 sub-classes.
- class ray.rllib.models.modelv2.ModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)
Defines an abstract neural network model for use with RLlib.
Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.
- Data flow:
- obs -> forward() -> model_out
-> value_function() -> V(s)
- get_initial_state() -> List[numpy.ndarray]
Get the initial recurrent state values for the model.
- Returns
List of np.array objects containing the initial hidden state of an RNN, if applicable.
Examples
>>> def get_initial_state(self):
...     return [
...         np.zeros(self.cell_size, np.float32),
...         np.zeros(self.cell_size, np.float32),
...     ]
- forward(input_dict: Dict[str, Any], state: List[Any], seq_lens: Any)
Call the model with the given input tensors and state.
Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict["obs_flat"].
This method can be called any number of times. In eager execution, each call to forward() will eagerly evaluate the model. In symbolic execution, each call to forward creates a computation graph that operates over the variables of this model (i.e., shares weights).
Custom models should override this instead of __call__.
- Parameters
input_dict – dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training", "eps_id", "agent_id", "infos", and "t".
state – list of state tensors with sizes matching those returned by get_initial_state + the batch dimension
seq_lens – 1d tensor holding input sequence lengths
- Returns
A tuple consisting of the model output tensor of size [BATCH, num_outputs] and the list of new RNN state(s) if any.
Examples
>>> def forward(self, input_dict, state, seq_lens):
...     model_out, self._value_out = self.base_model(
...         input_dict["obs"])
...     return model_out, state
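The shape relationship between get_initial_state() and the state argument of forward() can be sketched without any framework. The snippet below is only an illustration using nested Python lists in place of tensors; `CELL_SIZE` and the helper `add_batch_dim` are hypothetical names and not part of RLlib.

```python
from typing import List

CELL_SIZE = 3  # hypothetical hidden size, chosen for illustration


def get_initial_state() -> List[List[float]]:
    # Two unbatched state vectors (e.g. an LSTM's h and c), each of shape [CELL_SIZE].
    return [[0.0] * CELL_SIZE, [0.0] * CELL_SIZE]


def add_batch_dim(initial_state: List[List[float]], batch_size: int):
    # Hypothetical helper: repeat every state vector once per batch row,
    # turning shape [CELL_SIZE] into [BATCH, CELL_SIZE].
    return [[list(vec) for _ in range(batch_size)] for vec in initial_state]


# This is the kind of structure forward() would receive as `state`.
batched_state = add_batch_dim(get_initial_state(), batch_size=4)
```

Each entry of `batched_state` corresponds to one entry returned by get_initial_state(), with a leading batch dimension added.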
- value_function() -> Any
Returns the value function output for the most recent forward pass.
Note that a forward() call must be performed before this method can return anything. Calling this method therefore does not trigger an extra forward pass through the network.
- Returns
Value estimate tensor of shape [BATCH].
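The ordering constraint between forward() and value_function() can be illustrated with a framework-free sketch. `ToyModel` below is a hypothetical stand-in, not an RLlib class; it uses plain Python lists instead of tensors and only mirrors the call contract described here: forward() caches the value output, and value_function() returns the cached result.

```python
from typing import Any, Dict, List, Tuple


class ToyModel:
    """Hypothetical stand-in that mimics the ModelV2 call contract."""

    def __init__(self, num_outputs: int):
        self.num_outputs = num_outputs
        self._value_out = None  # filled in by forward()
        self.forward_calls = 0

    def forward(self, input_dict: Dict[str, Any], state: List[Any],
                seq_lens: Any) -> Tuple[List[List[float]], List[Any]]:
        self.forward_calls += 1
        obs = input_dict["obs"]  # shape [BATCH, obs_dim] as nested lists
        # Toy "network": every output unit is the sum of the observation row.
        model_out = [[float(sum(row))] * self.num_outputs for row in obs]
        # Cache one value estimate per batch row for value_function().
        self._value_out = [sum(row) / len(row) for row in obs]
        return model_out, state

    def value_function(self) -> List[float]:
        if self._value_out is None:
            raise RuntimeError("call forward() before value_function()")
        return self._value_out  # shape [BATCH]; no extra forward pass


model = ToyModel(num_outputs=2)
model_out, state_out = model.forward({"obs": [[1.0, 3.0], [2.0, 6.0]]}, [], None)
values = model.value_function()
```

Calling value_function() on a fresh `ToyModel` raises an error, which is one way a custom model could make the ordering requirement explicit.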
- custom_loss(policy_loss: Any, loss_inputs: Dict[str, Any]) -> Union[List[Any], Any]
Override to customize the loss function used to optimize this model.
This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).
You can find a runnable example in examples/custom_loss.py.
- Parameters
policy_loss – List of or single policy loss(es) from the policy.
loss_inputs – map of input placeholders for rollout data.
- Returns
List of or scalar tensor for the customized loss(es) for this model.
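As a rough illustration of the "list of or single" convention, the sketch below adds a made-up auxiliary term to the policy loss while preserving whichever structure was passed in. It uses plain floats instead of tensors; the auxiliary term, the `AUX_WEIGHT` constant, and the `"obs"` key in `loss_inputs` are assumptions made for the example, and examples/custom_loss.py remains the authoritative reference.

```python
from typing import Any, Dict, List, Union

AUX_WEIGHT = 0.5  # hypothetical weighting of the auxiliary term


def custom_loss(policy_loss: Union[List[float], float],
                loss_inputs: Dict[str, Any]) -> Union[List[float], float]:
    # Hypothetical self-supervised term: mean squared observation value.
    obs = loss_inputs["obs"]
    count = sum(len(row) for row in obs)
    aux = sum(x * x for row in obs for x in row) / count
    if isinstance(policy_loss, list):
        # Keep the list structure: add the auxiliary term to each loss.
        return [loss + AUX_WEIGHT * aux for loss in policy_loss]
    return policy_loss + AUX_WEIGHT * aux


batch = {"obs": [[1.0, 2.0], [3.0, 0.0]]}
single_loss = custom_loss(1.0, batch)
multi_loss = custom_loss([1.0, 2.0], batch)
```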
- metrics() -> Dict[str, Any]
Override to return custom metrics from your model.
The stats will be reported as part of the learner stats, i.e., info.learner.[policy_id, e.g. "default_policy"].model.key1=metric1
- Returns
The custom metrics for this model.
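A minimal sketch of a metrics() override, together with the nesting implied by the reporting path above. `ToyModelWithMetrics`, the `aux_loss` key, and the hand-built `result` dict are hypothetical; in practice RLlib assembles the training result itself.

```python
from typing import Any, Dict


class ToyModelWithMetrics:
    """Hypothetical stand-in showing only the metrics() override."""

    def __init__(self):
        self._last_aux_loss = 0.25  # would be updated during training

    def metrics(self) -> Dict[str, Any]:
        return {"aux_loss": self._last_aux_loss}


# Hand-built nesting that mirrors info.learner.<policy_id>.model.<key>.
result = {
    "info": {
        "learner": {
            "default_policy": {"model": ToyModelWithMetrics().metrics()},
        }
    }
}
aux_loss = result["info"]["learner"]["default_policy"]["model"]["aux_loss"]
```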
- import_from_h5(h5_file: str) -> None
Imports weights from an h5 file.
- Parameters
h5_file – The h5 file name to import weights from.
Example
>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
...     trainer.train()
- context() -> contextlib.AbstractContextManager
Returns a contextmanager for the current forward pass.
- variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
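The difference between variables() and trainable_variables(), and the effect of as_dict, can be sketched with a toy registry. `ToyVariableStore` and its variable names are hypothetical and use plain lists as values; real models return framework tensors or variables.

```python
from typing import Any, Dict, List, Union


class ToyVariableStore:
    """Hypothetical stand-in for a model's variable bookkeeping."""

    def __init__(self):
        # name -> (value, trainable flag); names and values are made up.
        self._vars = {
            "dense/kernel": ([0.1, 0.2], True),
            "dense/bias": ([0.0], True),
            "batch_norm/moving_mean": ([0.5], False),
        }

    def variables(self, as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]:
        if as_dict:
            return {name: value for name, (value, _) in self._vars.items()}
        return [value for value, _ in self._vars.values()]

    def trainable_variables(self, as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]:
        # Keep only entries flagged as trainable.
        trainable = {n: v for n, (v, flag) in self._vars.items() if flag}
        return trainable if as_dict else list(trainable.values())


store = ToyVariableStore()
all_names = sorted(store.variables(as_dict=True))
trainable_names = sorted(store.trainable_variables(as_dict=True))
```

Here the non-trainable moving mean appears in variables() but not in trainable_variables().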
RLlib comes with two sub-classes for TF (keras) models and PyTorch models:
TFModelV2 (rllib.models.tf.tf_modelv2.TFModelV2)
- class ray.rllib.models.tf.tf_modelv2.TFModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)
TF version of ModelV2, which should contain a tf keras Model.
Note that this class by itself is not a valid model unless you implement forward() in a subclass.
- context() -> contextlib.AbstractContextManager
Returns a contextmanager for the current TF graph.
- update_ops() -> List[Any]
Return the list of update ops for this model.
For example, this should include any BatchNorm update ops.
- register_variables(variables: List[Any]) -> None
Register the given list of variables with this model.
- variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
TorchModelV2 (rllib.models.torch.torch_modelv2.TorchModelV2)
- class ray.rllib.models.torch.torch_modelv2.TorchModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)
Torch version of ModelV2.
Note that this class by itself is not a valid model unless you inherit from nn.Module and implement forward() in a subclass.
- variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list (or a dict) of variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive str keys).
- Returns
The list (or dict if as_dict is True) of all variables of this ModelV2.
- trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]
Returns the list of trainable variables for this model.
- Parameters
as_dict – Whether variables should be returned as dict-values (using descriptive keys).
- Returns
The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.