Source code for ray.rllib.models.model

from collections import OrderedDict
import logging
import gym

from import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)

[docs]class Model: """This class is deprecated! Use ModelV2 instead.""" def __init__(self, input_dict, obs_space, action_space, num_outputs, options, state_in=None, seq_lens=None): # Soft-deprecate this class. All Models should use the ModelV2 # API from here on. deprecation_warning("Model", "ModelV2", error=False) assert isinstance(input_dict, dict), input_dict # Default attribute values for the non-RNN case self.state_init = [] self.state_in = state_in or [] self.state_out = [] self.obs_space = obs_space self.action_space = action_space self.num_outputs = num_outputs self.options = options self.scope = tf.get_variable_scope() self.session = tf.get_default_session() self.input_dict = input_dict if seq_lens is not None: self.seq_lens = seq_lens else: self.seq_lens = tf.placeholder( dtype=tf.int32, shape=[None], name="seq_lens") self._num_outputs = num_outputs if options.get("free_log_std"): assert num_outputs % 2 == 0 num_outputs = num_outputs // 2 ok = True try: restored = input_dict.copy() restored["obs"] = restore_original_dimensions( input_dict["obs"], obs_space) self.outputs, self.last_layer = self._build_layers_v2( restored, num_outputs, options) except NotImplementedError: ok = False # In TF 1.14, you cannot construct variable scopes in exception # handlers so we have to set the OK flag and check it here: if not ok: self.outputs, self.last_layer = self._build_layers( input_dict["obs"], num_outputs, options) if options.get("free_log_std", False): log_std = tf.get_variable( name="log_std", shape=[num_outputs], initializer=tf.zeros_initializer) self.outputs = tf.concat( [self.outputs, 0.0 * self.outputs + log_std], 1) def _build_layers(self, inputs, num_outputs, options): """Builds and returns the output and last layer of the network. Deprecated: use _build_layers_v2 instead, which has better support for dict and tuple spaces. """ raise NotImplementedError @PublicAPI def _build_layers_v2(self, input_dict, num_outputs, options): """Define the layers of a custom model. Arguments: input_dict (dict): Dictionary of input tensors, including "obs", "prev_action", "prev_reward", "is_training". num_outputs (int): Output tensor must be of size [BATCH_SIZE, num_outputs]. options (dict): Model options. Returns: (outputs, feature_layer): Tensors of size [BATCH_SIZE, num_outputs] and [BATCH_SIZE, desired_feature_size]. When using dict or tuple observation spaces, you can access the nested sub-observation batches here as well: Examples: >>> print(input_dict) {'prev_actions': <tf.Tensor shape=(?,) dtype=int64>, 'prev_rewards': <tf.Tensor shape=(?,) dtype=float32>, 'is_training': <tf.Tensor shape=(), dtype=bool>, 'obs': OrderedDict([ ('sensors', OrderedDict([ ('front_cam', [ <tf.Tensor shape=(?, 10, 10, 3) dtype=float32>, <tf.Tensor shape=(?, 10, 10, 3) dtype=float32>]), ('position', <tf.Tensor shape=(?, 3) dtype=float32>), ('velocity', <tf.Tensor shape=(?, 3) dtype=float32>)]))])} """ raise NotImplementedError
[docs] @PublicAPI def value_function(self): """Builds the value function output. This method can be overridden to customize the implementation of the value function (e.g., not sharing hidden layers). Returns: Tensor of size [BATCH_SIZE] for the value function. """ return tf.reshape( linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
[docs] @PublicAPI def custom_loss(self, policy_loss, loss_inputs): """Override to customize the loss function used to optimize this model. This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model's layers). You can find an runnable example in examples/ Arguments: policy_loss (Tensor): scalar policy loss from the policy. loss_inputs (dict): map of input placeholders for rollout data. Returns: Scalar tensor for the customized loss for this model. """ if self.loss() is not None: raise DeprecationWarning( "self.loss() is deprecated, use self.custom_loss() instead.") return policy_loss
[docs] @PublicAPI def custom_stats(self): """Override to return custom metrics from your model. The stats will be reported as part of the learner stats, i.e., info: learner: model: key1: metric1 key2: metric2 Returns: Dict of string keys to scalar tensors. """ return {}
[docs] def loss(self): """Deprecated: use self.custom_loss().""" return None
@classmethod def get_initial_state(cls, obs_space, action_space, num_outputs, options): raise NotImplementedError( "In order to use recurrent models with ModelV2, you should define " "the get_initial_state @classmethod on your custom model class.") def _validate_output_shape(self): """Checks that the model has the correct number of outputs.""" try: out = tf.convert_to_tensor(self.outputs) shape = out.shape.as_list() except Exception: raise ValueError("Output is not a tensor: {}".format(self.outputs)) else: if len(shape) != 2 or shape[1] != self._num_outputs: raise ValueError( "Expected output shape of [None, {}], got {}".format( self._num_outputs, shape))
@DeveloperAPI def flatten(obs, framework): """Flatten the given tensor.""" if framework == "tf": return tf.layers.flatten(obs) elif framework == "torch": assert torch is not None return torch.flatten(obs, start_dim=1) else: raise NotImplementedError("flatten", framework) @DeveloperAPI def restore_original_dimensions(obs, obs_space, tensorlib=tf): """Unpacks Dict and Tuple space observations into their original form. This is needed since we flatten Dict and Tuple observations in transit. Before sending them to the model though, we should unflatten them into Dicts or Tuples of tensors. Arguments: obs: The flattened observation tensor. obs_space: The flattened obs space. If this has the `original_space` attribute, we will unflatten the tensor to that shape. tensorlib: The library used to unflatten (reshape) the array/tensor. Returns: single tensor or dict / tuple of tensors matching the original observation space. """ if hasattr(obs_space, "original_space"): if tensorlib == "tf": tensorlib = tf elif tensorlib == "torch": assert torch is not None tensorlib = torch return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib) else: return obs # Cache of preprocessors, for if the user is calling unpack obs often. _cache = {} def _unpack_obs(obs, space, tensorlib=tf): """Unpack a flattened Dict or Tuple observation array/tensor. Arguments: obs: The flattened observation tensor space: The original space prior to flattening tensorlib: The library used to unflatten (reshape) the array/tensor """ if (isinstance(space, gym.spaces.Dict) or isinstance(space, gym.spaces.Tuple)): if id(space) in _cache: prep = _cache[id(space)] else: prep = get_preprocessor(space)(space) # Make an attempt to cache the result, if enough space left. if len(_cache) < 999: _cache[id(space)] = prep if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]: raise ValueError( "Expected flattened obs shape of [None, {}], got {}".format( prep.shape[0], obs.shape)) assert len(prep.preprocessors) == len(space.spaces), \ (len(prep.preprocessors) == len(space.spaces)) offset = 0 if isinstance(space, gym.spaces.Tuple): u = [] for p, v in zip(prep.preprocessors, space.spaces): obs_slice = obs[:, offset:offset + p.size] offset += p.size u.append( _unpack_obs( tensorlib.reshape(obs_slice, [-1] + list(p.shape)), v, tensorlib=tensorlib)) else: u = OrderedDict() for p, (k, v) in zip(prep.preprocessors, space.spaces.items()): obs_slice = obs[:, offset:offset + p.size] offset += p.size u[k] = _unpack_obs( tensorlib.reshape(obs_slice, [-1] + list(p.shape)), v, tensorlib=tensorlib) return u else: return obs