Deep Learning Framework (tf vs torch) Utilities

ray.rllib.utils.framework.try_import_jax(error: bool = False)

Tries to import JAX and FLAX and returns both modules (or a pair of Nones).

Parameters

error – Whether to raise an error if JAX/FLAX cannot be imported.

Returns

Tuple containing the jax and flax modules.

Raises

ImportError – If error=True and JAX or FLAX cannot be imported.
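The return convention above (both modules, or a pair of Nones, with an ImportError only when error=True) is a common optional-dependency pattern. The following is a minimal sketch of that pattern using only `importlib`; the helper name `try_import_pair` is hypothetical, and this is not RLlib's actual implementation.

```python
import importlib
from types import ModuleType
from typing import Optional, Tuple


def try_import_pair(
    first: str, second: str, error: bool = False
) -> Tuple[Optional[ModuleType], Optional[ModuleType]]:
    """Hypothetical helper: import two modules, or return (None, None).

    Mirrors the documented contract of try_import_jax(), which returns
    (jax, flax) or Nones and raises only if error=True.
    """
    try:
        mod_a = importlib.import_module(first)
        mod_b = importlib.import_module(second)
    except ImportError:
        if error:
            raise ImportError(
                f"Could not import {first!r} and {second!r}; please install them."
            )
        # Always return two values so callers can unpack unconditionally.
        return None, None
    return mod_a, mod_b


# Usage: unpack the pair and check for None before touching either module.
jax, flax = try_import_pair("jax", "flax")
if jax is None:
    print("JAX/FLAX not available; falling back to another framework.")
```

Returning a fixed-size tuple even on failure lets calling code use a single unpacking statement at import time and defer the "is it installed?" check to the point of use.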

ray.rllib.utils.framework.try_import_tf(error: bool = False)

Tries to import TensorFlow (tf) and returns the tf modules plus the installed version (or Nones).

Parameters

error – Whether to raise an error if tf cannot be imported.

Returns

Tuple containing: 1) the tf1.x module (either tf2.x's compat.v1 or tf1.x itself); 2) the tf module resulting from import tensorflow (either tf1.x or tf2.x); 3) the actually installed tf version as an int: 1 or 2.

Raises

ImportError – If error=True and tf is not installed.
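To make the three-element return value concrete, here is a sketch of how such a tuple could be derived from an already imported module object. It assumes (not confirmed by this page) that the tf1.x module is taken from compat.v1 whenever that attribute exists and that the version is read from `__version__`. The name `split_tf_module` is hypothetical, and `SimpleNamespace` objects stand in for TensorFlow so the sketch runs without it installed.

```python
from types import SimpleNamespace
from typing import Any, Tuple


def split_tf_module(tf_module: Any) -> Tuple[Any, Any, int]:
    """Hypothetical helper: build a (tf1, tf, version) tuple from tf_module."""
    # Prefer the tf1.x-compatible API exposed under compat.v1; if the
    # module has no compat.v1 (e.g. an old tf1.x), use the module itself.
    compat = getattr(tf_module, "compat", None)
    tf1_module = getattr(compat, "v1", None) or tf_module
    # Reduce the version string (e.g. "2.15.0") to the major version 1 or 2.
    major = int(str(tf_module.__version__).split(".")[0])
    version = 2 if major >= 2 else 1
    return tf1_module, tf_module, version


# Fake modules standing in for tf2.x (with compat.v1) and an old tf1.x.
v1_api = SimpleNamespace(name="compat.v1")
fake_tf2 = SimpleNamespace(__version__="2.15.0", compat=SimpleNamespace(v1=v1_api))
fake_tf1 = SimpleNamespace(__version__="1.12.0")

tf1, tf, tfv = split_tf_module(fake_tf2)
```

With RLlib installed, the documented function is unpacked the same way: `tf1, tf, tfv = try_import_tf()`.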

ray.rllib.utils.framework.tf_function(tf_module)

Conditional decorator for @tf.function.

Use @tf_function(tf) instead of @tf.function to avoid errors if tf is not installed.
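The conditional-decorator idea can be sketched as follows. This assumes the decorator returns the function unchanged when `tf_module` is None and otherwise applies `tf_module.function`. The name `tf_function_sketch` is hypothetical, and a fake module with a `function` attribute replaces TensorFlow so the example runs without it.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def tf_function_sketch(tf_module: Optional[Any]) -> Callable[[Callable], Callable]:
    """Hypothetical stand-in for a conditional @tf.function decorator."""

    def decorator(func: Callable) -> Callable:
        # No TensorFlow available: leave the function undecorated.
        if tf_module is None:
            return func
        # TensorFlow available: wrap the function with tf_module.function.
        return tf_module.function(func)

    return decorator


# Case 1: tf is not installed, so the try-import returned None.
tf = None

@tf_function_sketch(tf)
def add(a, b):
    return a + b


# Case 2: a fake "tf" whose function() only marks the wrapped function.
def _fake_tf_function(func):
    func.traced = True
    return func

fake_tf = SimpleNamespace(function=_fake_tf_function)

@tf_function_sketch(fake_tf)
def mul(a, b):
    return a * b
```

Because the decorator factory receives the (possibly None) module as an argument, the decoration is evaluated safely at import time even when TensorFlow is absent.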

ray.rllib.utils.framework.try_import_tfp(error: bool = False)

Tries to import TensorFlow Probability (tfp) and returns the module (or None).

Parameters

error – Whether to raise an error if tfp cannot be imported.

Returns

The tfp module.

Raises

ImportError – If error=True and tfp is not installed.

ray.rllib.utils.framework.try_import_torch(error: bool = False)

Tries to import torch and returns the torch and torch.nn modules (torch is None if PyTorch cannot be imported).

Parameters

error – Whether to raise an error if torch cannot be imported.

Returns

Tuple consisting of the torch and torch.nn modules.

Raises

ImportError – If error=True and PyTorch is not installed.

ray.rllib.utils.framework.get_variable(value: Any, framework: str = 'tf', trainable: bool = False, tf_name: str = 'unnamed-variable', torch_tensor: bool = False, device: Optional[str] = None, shape: Optional[Union[Tuple[int], List[int]]] = None, dtype: Optional[Union[numpy.array, tf.Tensor, torch.Tensor]] = None) -> Any

Creates a tf variable, a torch tensor, or a python primitive.

Parameters
  • value – The initial value to use. In the non-tf case, this is returned as is (unless a torch tensor is requested via torch_tensor=True). In the tf case, this may also be a tf initializer object.

  • framework – One of “tf”, “torch”, or None.

  • trainable – Whether the generated variable should be trainable (tf) or have requires_grad set (torch) (default: False).

  • tf_name – For framework=“tf”: An optional name for the tf.Variable.

  • torch_tensor – For framework=“torch”: Whether to actually create a torch.Tensor or just return a python value (default).

  • device – An optional torch device to use for the created torch tensor.

  • shape – An optional shape to use only if value has no shape of its own (e.g. if it is an initializer without an explicit value).

  • dtype – An optional dtype to use only if value has no dtype of its own (e.g. if it is an initializer without an explicit value). This should always be a numpy dtype (e.g. np.float32, np.int64).

Returns

A framework-specific variable (tf.Variable, torch.Tensor, or python primitive).