"""ray.rllib.utils.tensor_dtype: dtype mappings between NumPy, torch, and TensorFlow."""

import numpy as np

from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.framework import try_import_torch, try_import_tf
from ray.util.annotations import PublicAPI

torch, _ = try_import_torch()
_, tf, _ = try_import_tf()


# (NumPy scalar type, framework dtype attribute name) pairs. torch and tf
# expose the matching dtypes under identical attribute names (torch.float32,
# tf.float32, ...), so this single table drives both forward mappings below.
# NOTE: the resulting dict keys are NumPy scalar *type* objects such as
# np.float32, not np.dtype instances. np.dtype("float32") compares equal to
# np.float32 but does not hash the same, so look keys up via `arr.dtype.type`.
_NP_TYPE_AND_NAME = (
    (np.bool_, "bool"),
    (np.uint8, "uint8"),
    (np.int8, "int8"),
    (np.int16, "int16"),
    (np.int32, "int32"),
    (np.int64, "int64"),
    (np.float16, "float16"),
    (np.float32, "float32"),
    (np.float64, "float64"),
    (np.complex64, "complex64"),
    (np.complex128, "complex128"),
)

# NumPy scalar type -> torch dtype (empty when torch could not be imported).
numpy_to_torch_dtype_dict = (
    {np_type: getattr(torch, attr) for np_type, attr in _NP_TYPE_AND_NAME}
    if torch
    else {}
)

# NumPy scalar type -> tf dtype (empty when tf could not be imported).
numpy_to_tf_dtype_dict = (
    {np_type: getattr(tf, attr) for np_type, attr in _NP_TYPE_AND_NAME}
    if tf
    else {}
)

# Reverse lookups: framework dtype -> NumPy scalar type.
torch_to_numpy_dtype_dict = {
    fw_dtype: np_type for np_type, fw_dtype in numpy_to_torch_dtype_dict.items()
}
tf_to_numpy_dtype_dict = {
    fw_dtype: np_type for np_type, fw_dtype in numpy_to_tf_dtype_dict.items()
}


@PublicAPI(stability="alpha")
def get_np_dtype(x: TensorType) -> np.dtype:
    """Returns the NumPy dtype of the given tensor or array.

    Args:
        x: A torch.Tensor, a tf.Tensor, a NumPy array, or a NumPy scalar
            (e.g. ``np.float32(1.0)``).

    Returns:
        For NumPy arrays and scalars, ``x.dtype`` (an ``np.dtype``). For
        torch/tf tensors, the NumPy scalar type looked up in
        ``torch_to_numpy_dtype_dict`` / ``tf_to_numpy_dtype_dict``
        (e.g. ``np.float32``), which compares equal to the matching np.dtype.

    Raises:
        KeyError: If a torch/tf tensor's dtype has no entry in the mapping
            tables (e.g. bfloat16, which is not mapped).
        TypeError: If ``x`` is none of the supported types.
    """
    if torch and isinstance(x, torch.Tensor):
        return torch_to_numpy_dtype_dict[x.dtype]
    # NOTE(review): tf.Variable does not appear to be a tf.Tensor subclass, so
    # variables likely fall through to TypeError - confirm whether intended.
    elif tf and isinstance(x, tf.Tensor):
        return tf_to_numpy_dtype_dict[x.dtype]
    # np.generic covers NumPy scalars (e.g. np.float32(1.0)); they carry a
    # dtype exactly like arrays and would otherwise be rejected below.
    elif isinstance(x, (np.ndarray, np.generic)):
        return x.dtype
    else:
        raise TypeError("Unsupported type: {}".format(type(x)))