Model APIs

ModelV2 API (rllib.models.modelv2.ModelV2)

All RLlib neural network models have to be provided as ModelV2 sub-classes.

class ray.rllib.models.modelv2.ModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str, framework: str)

Defines an abstract neural network model for use with RLlib.

Custom models should extend either TFModelV2 or TorchModelV2 instead of this class directly.

Data flow:
    obs -> forward() -> model_out
        \-> value_function() -> V(s)
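The data flow above can be sketched without RLlib as a tiny numpy model with a shared trunk and two heads. ToyModel and its random weights are hypothetical, not RLlib classes; the sketch only mirrors the contract that forward() returns a [BATCH, num_outputs] output plus the state, while the [BATCH] value output from the same pass is kept for value_function().

```python
import numpy as np


class ToyModel:
    """Hypothetical stand-in for a ModelV2 subclass (not an RLlib class)."""

    def __init__(self, obs_dim: int, num_outputs: int, hidden: int = 8, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Shared trunk plus two heads: one for model_out, one for V(s).
        self.w_hidden = rng.normal(size=(obs_dim, hidden)).astype(np.float32)
        self.w_logits = rng.normal(size=(hidden, num_outputs)).astype(np.float32)
        self.w_value = rng.normal(size=(hidden, 1)).astype(np.float32)
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        features = np.tanh(input_dict["obs"] @ self.w_hidden)
        model_out = features @ self.w_logits                # [BATCH, num_outputs]
        self._value_out = (features @ self.w_value)[:, 0]   # [BATCH], cached
        return model_out, state

    def value_function(self):
        # Reuses the value computed during the last forward() call.
        return self._value_out


model = ToyModel(obs_dim=4, num_outputs=2)
obs = np.ones((3, 4), dtype=np.float32)
model_out, new_state = model.forward({"obs": obs}, [], None)
values = model.value_function()
print(model_out.shape, values.shape)  # (3, 2) (3,)
```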

get_initial_state() -> List[numpy.ndarray]

Get the initial recurrent state values for the model.

Returns

List of np.array objects containing the initial hidden state of an RNN, if applicable.

Examples

>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2): 
...     # ...
...     def get_initial_state(self):
...         return [
...             np.zeros(self.cell_size, np.float32),
...             np.zeros(self.cell_size, np.float32),
...         ]
forward(input_dict: Dict[str, Any], state: List[Any], seq_lens: Any)

Call the model with the given input tensors and state.

Any complex observations (dicts, tuples, etc.) will be unpacked by __call__ before being passed to forward(). To access the flattened observation tensor, refer to input_dict[“obs_flat”].
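As a rough illustration of what a flattened observation looks like, the sketch below turns a dict observation into one flat list, one-hot encoding discrete entries. flatten_obs and the spec format are hypothetical; RLlib performs this step with its own preprocessing, whose key ordering and encoding details may differ.

```python
from typing import Dict, List, Union

# Hypothetical per-key description: an int is the size of a discrete
# component (one-hot encoded); "box" means the value is a list of floats.
SpaceSpec = Dict[str, Union[int, str]]


def flatten_obs(obs: Dict[str, Union[int, List[float]]], spec: SpaceSpec) -> List[float]:
    flat: List[float] = []
    # Iterate in sorted key order so the flat layout is deterministic.
    for key in sorted(spec):
        kind, value = spec[key], obs[key]
        if isinstance(kind, int):
            one_hot = [0.0] * kind
            one_hot[value] = 1.0
            flat.extend(one_hot)
        else:
            flat.extend(float(v) for v in value)
    return flat


obs = {"position": [0.5, -1.0], "mode": 2}
spec = {"position": "box", "mode": 3}
print(flatten_obs(obs, spec))  # [0.0, 0.0, 1.0, 0.5, -1.0]
```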

This method can be called any number of times. In eager execution, each call to forward() evaluates the model immediately. In symbolic (graph) execution, each call to forward() builds a computation graph over this model's variables, so all calls share the same weights.

Custom models should override this instead of __call__.

Parameters
  • input_dict – dictionary of input tensors, including “obs”, “obs_flat”, “prev_action”, “prev_reward”, “is_training”, “eps_id”, “agent_id”, “infos”, and “t”.

  • state – list of state tensors, each shaped like the corresponding entry returned by get_initial_state() with an added leading batch dimension

  • seq_lens – 1D tensor holding the length of each input sequence

Returns

A tuple consisting of the model output tensor of size [BATCH, num_outputs] and the list of new RNN state(s) if any.

Examples

>>> import numpy as np
>>> from ray.rllib.models.modelv2 import ModelV2
>>> class MyModel(ModelV2):
...     # ...
...     def forward(self, input_dict, state, seq_lens):
...         # Assumes self.base_model (built in __init__) returns (logits, value).
...         model_out, self._value_out = self.base_model(
...             input_dict["obs"])
...         return model_out, state
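To make the shape contract for state and seq_lens concrete, the sketch below batches an LSTM-style initial state (one h and one c vector, as in the get_initial_state() example) with numpy. cell_size and batch_size are hypothetical values; in this sketch the leading dimension of every state tensor equals the number of sequences, i.e. len(seq_lens).

```python
import numpy as np

cell_size = 16   # hypothetical LSTM cell size
batch_size = 4   # hypothetical number of sequences in the batch

# What get_initial_state() might return for an LSTM: h and c vectors.
initial_state = [np.zeros(cell_size, np.float32), np.zeros(cell_size, np.float32)]

# forward() receives each state tensor with an extra leading batch dimension.
state = [np.stack([s] * batch_size) for s in initial_state]

# One length per sequence; shorter sequences are padded up to the maximum (5).
seq_lens = np.array([3, 5, 2, 5], dtype=np.int32)

print([s.shape for s in state], seq_lens.shape)  # [(4, 16), (4, 16)] (4,)
```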
value_function() -> Any

Returns the value function output for the most recent forward pass.

Note that forward() must be called first, before this method can return anything. This method only returns the value output computed during that pass, so calling it does not cause an extra forward pass through the network.

Returns

Value estimate tensor of shape [BATCH].
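The call-order contract can be checked with a toy model that counts forward passes. CachingValueModel is hypothetical and uses plain lists instead of tensors; it only demonstrates that value_function() returns the cached result of the last forward() call and fails if forward() has not run yet.

```python
class CachingValueModel:
    """Hypothetical toy model; only illustrates the call-order contract."""

    def __init__(self):
        self.forward_calls = 0
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        self.forward_calls += 1
        obs = input_dict["obs"]
        # Toy "network": the output is the obs itself, the value is each row's sum.
        self._value_out = [sum(row) for row in obs]
        return obs, state

    def value_function(self):
        if self._value_out is None:
            raise ValueError("value_function() called before forward()")
        # Return the cached output; no new forward pass happens here.
        return self._value_out


model = CachingValueModel()
model.forward({"obs": [[1.0, 2.0], [3.0, 4.0]]}, [], None)
print(model.value_function(), model.forward_calls)  # [3.0, 7.0] 1
```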

custom_loss(policy_loss: Any, loss_inputs: Dict[str, Any]) -> Union[List[Any], Any]

Override to customize the loss function used to optimize this model.

This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).

You can find a runnable example in examples/custom_loss.py.

Parameters
  • policy_loss – List of or single policy loss(es) from the policy.

  • loss_inputs – map of input placeholders for rollout data.

Returns

List of or scalar tensor for the customized loss(es) for this model.
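Because policy_loss may be either a list of losses or a single loss, a custom loss usually has to preserve that form. Below is a minimal sketch with plain floats standing in for tensors; add_aux_loss and aux_weight are hypothetical names, and in a real custom_loss() override the auxiliary term would be computed from loss_inputs.

```python
from typing import List, Union

Loss = float  # stand-in for a framework tensor


def add_aux_loss(policy_loss: Union[List[Loss], Loss], aux_loss: Loss,
                 aux_weight: float = 0.1) -> Union[List[Loss], Loss]:
    """Add a weighted auxiliary loss to every policy loss, keeping list/scalar form."""
    if isinstance(policy_loss, list):
        return [loss + aux_weight * aux_loss for loss in policy_loss]
    return policy_loss + aux_weight * aux_loss


print(add_aux_loss([1.0, 2.0], aux_loss=5.0))  # [1.5, 2.5]
print(add_aux_loss(1.0, aux_loss=5.0))         # 1.5
```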

metrics() -> Dict[str, Any]

Override to return custom metrics from your model.

The stats will be reported as part of the learner stats, i.e., info.learner.[policy_id, e.g. “default_policy”].model.key1=metric1

Returns

The custom metrics for this model.
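A common pattern is to record a loss term inside custom_loss() and report it from metrics(). The toy class below is hypothetical and uses plain floats; following the reporting path described above, a real model's {"aux_loss": ...} would appear as info.learner.[policy_id].model.aux_loss.

```python
from typing import Any, Dict


class AuxLossModel:
    """Hypothetical toy model: records a loss term so metrics() can report it."""

    def __init__(self):
        self._aux_loss_value = 0.0

    def custom_loss(self, policy_loss, loss_inputs):
        # Toy auxiliary loss: mean of the "obs" entries in loss_inputs.
        obs = loss_inputs["obs"]
        aux = sum(obs) / len(obs)
        self._aux_loss_value = aux  # remember it for reporting
        return policy_loss + aux

    def metrics(self) -> Dict[str, Any]:
        return {"aux_loss": self._aux_loss_value}


model = AuxLossModel()
total = model.custom_loss(1.0, {"obs": [0.5, 1.5]})
print(total, model.metrics())  # 2.0 {'aux_loss': 1.0}
```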

import_from_h5(h5_file: str) -> None

Imports weights from an h5 file.

Parameters

h5_file – The h5 file name to import weights from.

Example

>>> from ray.rllib.algorithms.ppo import PPO
>>> trainer = PPO(...)
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
...     trainer.train()
last_output() -> Any

Returns the last output returned from calling the model.

context() -> contextlib.AbstractContextManager

Returns a contextmanager for the current forward pass.
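A model whose forward pass needs no special setup can return a do-nothing context manager, and callers wrap the forward pass in it. NoContextModel is hypothetical; it uses the standard library's contextlib.nullcontext, whereas a framework-specific model might instead return, for example, a TF graph context.

```python
import contextlib


class NoContextModel:
    """Hypothetical model whose forward pass needs no special context."""

    def context(self) -> contextlib.AbstractContextManager:
        # A context manager that does nothing on enter/exit.
        return contextlib.nullcontext()

    def forward(self, x):
        return [2 * v for v in x]


model = NoContextModel()
with model.context():
    out = model.forward([1, 2, 3])
print(out)  # [2, 4, 6]
```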

variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list (or a dict) of variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive str keys).

Returns

The list (or dict if as_dict is True) of all variables of this ModelV2.

trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list of trainable variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive keys).

Returns

The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.
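The difference between variables() and trainable_variables(), and the effect of as_dict, can be illustrated with a toy model. Param and ToyVarsModel are hypothetical stand-ins for framework variables (tf.Variable or torch parameters), and names such as "dense/kernel" are made up for the example.

```python
from typing import Dict, List, Union


class Param:
    """Hypothetical stand-in for a tf.Variable or torch parameter."""

    def __init__(self, value: float, trainable: bool = True):
        self.value = value
        self.trainable = trainable


class ToyVarsModel:
    """Hypothetical model exposing variables() and trainable_variables()."""

    def __init__(self):
        # Insertion-ordered mapping from descriptive names to parameters.
        self._params: Dict[str, Param] = {
            "dense/kernel": Param(0.5),
            "dense/bias": Param(0.0),
            "bn/moving_mean": Param(0.0, trainable=False),  # e.g. a BatchNorm statistic
        }

    def variables(self, as_dict: bool = False) -> Union[List[Param], Dict[str, Param]]:
        return dict(self._params) if as_dict else list(self._params.values())

    def trainable_variables(self, as_dict: bool = False) -> Union[List[Param], Dict[str, Param]]:
        # Keep only parameters that an optimizer would update.
        trainable = {k: v for k, v in self._params.items() if v.trainable}
        return trainable if as_dict else list(trainable.values())


model = ToyVarsModel()
print(len(model.variables()), list(model.trainable_variables(as_dict=True)))
# 3 ['dense/kernel', 'dense/bias']
```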

is_time_major() -> bool

If True, data for calling this ModelV2 must be in time-major format.

Returns

Whether this ModelV2 requires a time-major (TxBx…) data format.
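The two layouts differ only in which of the first two axes comes first. The numpy sketch below, with hypothetical sizes B, T and obs_dim, converts batch-major [B, T, ...] data to the time-major [T, B, ...] layout; it illustrates the formats only and is not RLlib's internal conversion code.

```python
import numpy as np

B, T, obs_dim = 2, 5, 3  # hypothetical batch size, sequence length, feature size

# Batch-major layout: [B, T, obs_dim]
batch_major = np.arange(B * T * obs_dim, dtype=np.float32).reshape(B, T, obs_dim)

# Time-major layout (TxBx...), as required when is_time_major() returns True.
time_major = np.swapaxes(batch_major, 0, 1)

print(batch_major.shape, time_major.shape)  # (2, 5, 3) (5, 2, 3)
# Same element, different index order: sequence 1, time step 4.
assert time_major[4, 1, 0] == batch_major[1, 4, 0]
```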

RLlib comes with two sub-classes for TF (keras) models and PyTorch models:

TFModelV2 (rllib.models.tf.tf_modelv2.TFModelV2)

class ray.rllib.models.tf.tf_modelv2.TFModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)

TF version of ModelV2, which should contain a tf.keras Model.

Note that this class by itself is not a valid model unless you implement forward() in a subclass.

context() -> contextlib.AbstractContextManager

Returns a contextmanager for the current TF graph.

update_ops() -> List[Any]

Return the list of update ops for this model.

For example, this should include any BatchNorm update ops.

register_variables(variables: List[Any]) -> None

Register the given list of variables with this model.

variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list (or a dict) of variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive str keys).

Returns

The list (or dict if as_dict is True) of all variables of this ModelV2.

trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list of trainable variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive keys).

Returns

The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.

TorchModelV2 (rllib.models.torch.torch_modelv2.TorchModelV2)

class ray.rllib.models.torch.torch_modelv2.TorchModelV2(obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int, model_config: dict, name: str)

Torch version of ModelV2.

Note that this class by itself is not a valid model unless you inherit from nn.Module and implement forward() in a subclass.

variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list (or a dict) of variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive str keys).

Returns

The list (or dict if as_dict is True) of all variables of this ModelV2.

trainable_variables(as_dict: bool = False) -> Union[List[Any], Dict[str, Any]]

Returns the list of trainable variables for this model.

Parameters

as_dict – Whether variables should be returned as dict-values (using descriptive keys).

Returns

The list (or dict if as_dict is True) of all trainable (tf)/requires_grad (torch) variables of this ModelV2.