import contextlib
import gymnasium as gym
import re
from typing import Dict, List, Union
from ray.util import log_once
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType
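# try_import_tf() returns a 3-tuple: the tf.compat.v1 module (tf1), the
# installed TensorFlow module itself (tf), and its major version (1 or 2).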
tf1, tf, tfv = try_import_tf()
@OldAPIStack
class TFModelV2(ModelV2):
"""TF version of ModelV2, which should contain a tf keras Model.
Note that this class by itself is not a valid model unless you
implement forward() in a subclass."""
def __init__(
self,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
num_outputs: int,
model_config: ModelConfigDict,
name: str,
):
"""Initializes a TFModelV2 instance.
Here is an example implementation for a subclass
``MyModelClass(TFModelV2)``::
def __init__(self, *args, **kwargs):
super(MyModelClass, self).__init__(*args, **kwargs)
input_layer = tf.keras.layers.Input(...)
hidden_layer = tf.keras.layers.Dense(...)(input_layer)
output_layer = tf.keras.layers.Dense(...)(hidden_layer)
value_layer = tf.keras.layers.Dense(...)(hidden_layer)
self.base_model = tf.keras.Model(
input_layer, [output_layer, value_layer])
"""
super().__init__(
obs_space, action_space, num_outputs, model_config, name, framework="tf"
)
# Deprecated: TFModelV2 instances now automatically track their variables.
self.var_list = []
if tf1.executing_eagerly():
self.graph = None
else:
self.graph = tf1.get_default_graph()
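# A minimal sketch of how a subclass typically uses `self.base_model`
# (`MyModelClass` from the docstring above and `self._value_out` are
# illustrative names, not part of this base class):
#
#     def forward(self, input_dict, state, seq_lens):
#         model_out, self._value_out = self.base_model(input_dict["obs"])
#         return model_out, state
#
#     def value_function(self):
#         return tf.reshape(self._value_out, [-1])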
def context(self) -> contextlib.AbstractContextManager:
"""Returns a contextmanager for the current TF graph."""
if self.graph:
return self.graph.as_default()
else:
return ModelV2.context(self)
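# Usage sketch: in graph mode, any additional ops/variables should be
# created under this model's graph, e.g. (hypothetical variable name):
#
#     with model.context():
#         extra_var = tf1.get_variable("extra", shape=(), trainable=False)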
def update_ops(self) -> List[TensorType]:
"""Return the list of update ops for this model.
For example, this should include any BatchNorm update ops."""
return []
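# A subclass whose layers have internal updates (e.g. BatchNorm moving
# averages) might override this roughly as follows; a graph-mode sketch,
# assuming the layers registered their updates in the usual collection:
#
#     @override(TFModelV2)
#     def update_ops(self):
#         return tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)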
def register_variables(self, variables: List[TensorType]) -> None:
"""Register the given list of variables with this model."""
if log_once("deprecated_tfmodelv2_register_variables"):
deprecation_warning(old="TFModelV2.register_variables", error=False)
self.var_list.extend(variables)
@override(ModelV2)
def variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
if as_dict:
# Old way using `register_variables`.
if self.var_list:
return {v.name: v for v in self.var_list}
# New way: Automatically determine the var tree.
else:
return self._find_sub_modules("", self.__dict__)
# Old way using `register_variables`.
if self.var_list:
return list(self.var_list)
# New way: Automatically determine the var tree.
else:
return list(self.variables(as_dict=True).values())
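# Example sketch: for a model whose `self.base_model` contains a Dense
# layer named "fc", `variables(as_dict=True)` yields keys such as
# "base_model.fc.kernel:0" (slashes in TF variable names become dots):
#
#     for key, var in model.variables(as_dict=True).items():
#         print(key, var.shape)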
@override(ModelV2)
def trainable_variables(
self, as_dict: bool = False
) -> Union[List[TensorType], Dict[str, TensorType]]:
if as_dict:
return {
k: v for k, v in self.variables(as_dict=True).items() if v.trainable
}
return [v for v in self.variables() if v.trainable]
@staticmethod
def _find_sub_modules(current_key, struct):
# Keras Model or tf.Module: key = current_key + "." + var-name
# (with '/' in the variable name replaced by '.').
if isinstance(struct, (tf.keras.models.Model, tf.Module)):
ret = {}
for var in struct.variables:
name = re.sub("/", ".", var.name)
key = current_key + "." + name
ret[key] = var
return ret
# Nested TFModelV2: include its variables in ours.
elif isinstance(struct, TFModelV2):
return {
current_key + "." + key: var
for key, var in struct.variables(as_dict=True).items()
}
# tf.Variable
elif isinstance(struct, tf.Variable):
return {current_key: struct}
# List/Tuple.
elif isinstance(struct, (tuple, list)):
ret = {}
for i, value in enumerate(struct):
sub_vars = TFModelV2._find_sub_modules(
current_key + "_{}".format(i), value
)
ret.update(sub_vars)
return ret
# Dict.
elif isinstance(struct, dict):
if current_key:
current_key += "_"
ret = {}
for key, value in struct.items():
sub_vars = TFModelV2._find_sub_modules(current_key + str(key), value)
ret.update(sub_vars)
return ret
return {}
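# Key-naming sketch for nested containers (illustrative attribute names):
# a list attribute `self.towers = [m0, m1]` of Keras models yields keys
# like "towers_0.<var-name>" and "towers_1.<var-name>"; a dict attribute
# `self.heads = {"pi": ..., "vf": ...}` yields "heads_pi.<var-name>" and
# "heads_vf.<var-name>".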