Source code for ray.rllib.utils.framework

import logging
import numpy as np
import os
import sys
from typing import Any, Optional

import tree  # pip install dm_tree

from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import (
    NetworkType,
    TensorShape,
    TensorStructType,
    TensorType,
)

logger = logging.getLogger(__name__)


@PublicAPI
def convert_to_tensor(
    data: TensorStructType,
    framework: str,
    device: Optional[str] = None,
):
    """Turns a (possibly nested) numpy struct into framework tensors.

    Args:
        data: The (numpy) input struct whose leaves should be converted.
        framework: The target framework. Only "torch" and "tf2" are handled.
        device: An optional device name. Only forwarded in the "torch" case.

    Returns:
        A struct with the same nesting as `data`, holding framework tensors.

    Raises:
        NotImplementedError: If `framework` is neither "torch" nor "tf2".
    """
    if framework == "torch":
        # Imported lazily so that torch is only touched when requested.
        from ray.rllib.utils.torch_utils import convert_to_torch_tensor

        return convert_to_torch_tensor(data, device=device)

    if framework == "tf2":
        _, tf, _ = try_import_tf()

        def _leaf_to_tf(leaf):
            # One positional arg only; all other tf options stay at default.
            return tf.convert_to_tensor(leaf)

        return tree.map_structure(_leaf_to_tf, data)

    raise NotImplementedError(
        f"framework={framework} not supported in `convert_to_tensor()`!"
    )


@PublicAPI
def try_import_jax(error: bool = False):
    """Attempts to import JAX and FLAX, handing back both modules.

    Args:
        error: If True, a failed import raises instead of returning Nones.

    Returns:
        A `(jax, flax)` tuple of modules, or `(None, None)` if the import
        was skipped (env var `RLLIB_TEST_NO_JAX_IMPORT` is set) or failed
        with `error=False`.

    Raises:
        ImportError: If error=True and either `jax` or `flax` cannot be
            imported.
    """
    no_modules = (None, None)

    # Test hook: pretend JAX is unavailable whenever this env var exists.
    if "RLLIB_TEST_NO_JAX_IMPORT" in os.environ:
        logger.warning("Not importing JAX for test purposes.")
        return no_modules

    try:
        import jax
        import flax
    except ImportError:
        if not error:
            return no_modules
        raise ImportError(
            "Could not import JAX! RLlib requires you to "
            "install at least one deep-learning framework: "
            "`pip install [torch|tensorflow|jax]`."
        )
    else:
        return jax, flax


@PublicAPI
def try_import_tf(error: bool = False):
    """Tries importing tf and returns the module (or None).

    Args:
        error: Whether to raise an error if tf cannot be imported.

    Returns:
        Tuple containing
        1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
        2) tf module (resulting from `import tensorflow`). Either tf1.x or
        2.x.
        3) The actually installed tf version as int: 1 or 2.
        If the import is skipped (env var `RLLIB_TEST_NO_TF_IMPORT` is set)
        or fails with `error=False`, `(None, tf_stub, None)` is returned,
        where `tf_stub` is a falsy `_TFStub` instance.

    Raises:
        ImportError: If error=True and tf is not installed.
    """
    tf_stub = _TFStub()
    # Make sure, these are reset after each test case
    # that uses them: del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None, tf_stub, None

    # Silence TF's C++ logging unless the user configured a level already.
    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Try to reuse already imported tf module. This will avoid going through
    # the initial import steps below and thereby switching off v2_behavior
    # (switching off v2 behavior twice breaks all-framework tests for eager).
    was_imported = False
    if "tensorflow" in sys.modules:
        tf_module = sys.modules["tensorflow"]
        was_imported = True
    else:
        try:
            import tensorflow as tf_module
        except ImportError:
            if error:
                raise ImportError(
                    "Could not import TensorFlow! RLlib requires you to "
                    "install at least one deep-learning framework: "
                    "`pip install [torch|tensorflow|jax]`."
                )
            return None, tf_stub, None

    # Try "reducing" tf to tf.compat.v1.
    try:
        tf1_module = tf_module.compat.v1
        tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
        # Only switch behaviors on a fresh import (see comment above).
        if not was_imported:
            tf1_module.disable_v2_behavior()
            tf1_module.enable_resource_variables()
        tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
    # No compat.v1 -> return tf as is.
    except AttributeError:
        tf1_module = tf_module

    if not hasattr(tf_module, "__version__"):
        version = 1  # sphinx doc gen
    else:
        version = 2 if "2." in tf_module.__version__[:2] else 1

    return tf1_module, tf_module, version
# Fake module for tf.
class _TFStub:
    """Falsy placeholder handed out in place of the `tensorflow` module."""

    def __init__(self) -> None:
        self.keras = _KerasStub()

    def __bool__(self):
        # if tf should return False
        return False


# Fake module for tf.keras.
class _KerasStub:
    """Placeholder for `tf.keras`, exposing only a fake `Model` class."""

    def __init__(self) -> None:
        self.Model = _FakeTfClassStub


# Fake classes under keras (e.g for tf.keras.Model)
class _FakeTfClassStub:
    """Fake base class: can be inherited from, but not instantiated."""

    def __init__(self, *a, **kw):
        raise ImportError("Could not import `tensorflow`. Try pip install tensorflow.")


@DeveloperAPI
def tf_function(tf_module):
    """Conditional decorator for @tf.function.

    Use @tf_function(tf) instead to avoid errors if tf is not installed.

    Args:
        tf_module: The tf module to use. May be None or a falsy stand-in
            (such as `_TFStub`) if tf is not available.

    Returns:
        A decorator that returns the function unchanged if tf is unavailable
        or executing eagerly, and `tf_module.function(func)` otherwise.
    """

    # The actual decorator to use (pass in `tf` (which could be None)).
    def decorator(func):
        # If tf not installed -> return function as is (won't be used anyways).
        # `not tf_module` covers both None and the falsy `_TFStub`, which has
        # no `executing_eagerly` attribute and would otherwise raise here.
        if not tf_module or tf_module.executing_eagerly():
            return func
        # If tf installed, return @tf.function-decorated function.
        return tf_module.function(func)

    return decorator
@PublicAPI
def try_import_tfp(error: bool = False):
    """Tries importing tfp and returns the module (or None).

    Args:
        error: Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module, or None if the import was skipped (env var
        `RLLIB_TEST_NO_TF_IMPORT` is set) or failed with `error=False`.

    Raises:
        ImportError: If error=True and tfp is not installed.
    """
    # Note: The same env var as for plain tf is used to disable tfp here.
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test purposes.")
        return None

    try:
        import tensorflow_probability as tfp

        return tfp
    except ImportError:
        if error:
            # Re-raise the original import error untouched.
            raise
        return None
# Fake module for torch.nn. class _NNStub: def __init__(self, *a, **kw): # Fake nn.functional module within torch.nn. self.functional = None self.Module = _FakeTorchClassStub self.parallel = _ParallelStub() # Fake class for e.g. torch.nn.Module to allow it to be inherited from. class _FakeTorchClassStub: def __init__(self, *a, **kw): raise ImportError("Could not import `torch`. Try pip install torch.") class _ParallelStub: def __init__(self, *a, **kw): self.DataParallel = _FakeTorchClassStub self.DistributedDataParallel = _FakeTorchClassStub
@PublicAPI
def try_import_torch(error: bool = False):
    """Tries importing torch and returns the module (or None).

    Args:
        error: Whether to raise an error if torch cannot be imported.

    Returns:
        Tuple consisting of the torch- AND torch.nn modules. If the import
        is skipped (env var `RLLIB_TEST_NO_TORCH_IMPORT` is set) or fails
        with `error=False`, the result of `_torch_stubs()` is returned
        instead.

    Raises:
        ImportError: If error=True and PyTorch is not installed.
    """
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing PyTorch for test purposes.")
        return _torch_stubs()

    try:
        import torch
        import torch.nn as nn

        return torch, nn
    except ImportError:
        if error:
            raise ImportError(
                "Could not import PyTorch! RLlib requires you to "
                "install at least one deep-learning framework: "
                "`pip install [torch|tensorflow|jax]`."
            )
        return _torch_stubs()
def _torch_stubs():
    """Returns the `(torch, nn)` stand-ins used when torch is unavailable.

    Returns:
        Tuple of None (for torch) and a `_NNStub` instance (for torch.nn).
    """
    nn = _NNStub()
    return None, nn


@PublicAPI
def update_target_network(
    *,
    main_net: NetworkType,
    target_net: NetworkType,
    tau: float,
    framework: str = "torch",
) -> None:
    """Updates a target network (from a "main" network) using Polyak averaging.

    new_target_net_weight = (
        tau * main_net_weight + (1.0 - tau) * current_target_net_weight
    )

    Args:
        main_net: The nn.Module to update from.
        target_net: The target network to update.
        tau: The tau value to use in the Polyak averaging formula.
        framework: The framework specifier. Must be "torch" or "tf2".
    """
    assert framework in ["torch", "tf2"]

    # Pick the framework-specific implementation lazily.
    if framework == "torch":
        from ray.rllib.utils.torch_utils import update_target_network as _update_target
    else:
        from ray.rllib.utils.tf_utils import update_target_network as _update_target

    _update_target(main_net=main_net, target_net=target_net, tau=tau)


@DeveloperAPI
def get_variable(
    value: Any,
    framework: str = "tf",
    trainable: bool = False,
    tf_name: str = "unnamed-variable",
    torch_tensor: bool = False,
    device: Optional[str] = None,
    shape: Optional[TensorShape] = None,
    dtype: Optional[TensorType] = None,
) -> Any:
    """Creates a tf variable, a torch tensor, or a python primitive.

    Args:
        value: The initial value to use. In the non-tf case, this will
            be returned as is. In the tf case, this could be a tf-Initializer
            object.
        framework: One of "tf", "tf2", "torch", or None.
        trainable: Whether the generated variable should be
            trainable (tf)/require_grad (torch) or not (default: False).
        tf_name: For framework="tf"/"tf2": An optional name for the
            tf.Variable.
        torch_tensor: For framework="torch": Whether to actually create
            a torch.tensor, or just a python value (default).
        device: An optional torch device to use for
            the created torch tensor.
        shape: An optional shape to use iff `value`
            does not have any (e.g. if it's an initializer w/o explicit
            value).
        dtype: An optional dtype to use iff `value` does
            not have any (e.g. if it's an initializer w/o explicit value).
            This should always be a numpy dtype (e.g. np.float32, np.int64).

    Returns:
        A framework-specific variable (tf.Variable, torch.tensor, or
        python primitive).
    """
    if framework in ["tf2", "tf"]:
        import tensorflow as tf

        # Only infer a dtype if none was given. An explicit `is None` check is
        # used (rather than truthiness), so that falsy-but-valid dtype objects
        # passed in by the caller are not silently discarded.
        if dtype is None:
            if isinstance(value, float):
                fallback_dtype = tf.float32
            elif isinstance(value, int):
                fallback_dtype = tf.int32
            else:
                fallback_dtype = None
            dtype = getattr(value, "dtype", fallback_dtype)

        return tf.compat.v1.get_variable(
            tf_name,
            initializer=value,
            dtype=dtype,
            trainable=trainable,
            # Only pass `shape` on if the caller actually provided one.
            **({} if shape is None else {"shape": shape}),
        )
    elif framework == "torch" and torch_tensor is True:
        torch, _ = try_import_torch()
        if not isinstance(value, np.ndarray):
            value = np.array(value)
        var_ = torch.from_numpy(value)
        # Cast only for the explicitly supported dtypes; otherwise keep the
        # dtype that resulted from the numpy conversion.
        if dtype in [torch.float32, np.float32]:
            var_ = var_.float()
        elif dtype in [torch.int32, np.int32]:
            var_ = var_.int()
        elif dtype in [torch.float64, np.float64]:
            var_ = var_.double()

        if device:
            var_ = var_.to(device)
        var_.requires_grad = trainable
        return var_
    # torch or None: Return python primitive.
    return value


@Deprecated(
    old="rllib/utils/framework.py::get_activation_fn",
    new="rllib/models/utils.py::get_activation_fn",
    error=True,
)
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    pass