import logging
import warnings
from copy import copy
from inspect import signature
from math import isclose
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import numpy as np
# Backwards compatibility
from ray.util.annotations import DeveloperAPI, PublicAPI, RayDeprecationWarning
# ``np.random.Generator`` was added in numpy>=1.17 but numpy>=1.16 is still
# supported, so older versions get a placeholder type and the legacy flag.
if hasattr(np.random, "Generator"):
    np_random_generator = np.random.Generator
    LEGACY_RNG = False
else:

    class np_random_generator:
        """Placeholder type used when ``np.random.Generator`` is unavailable."""

    LEGACY_RNG = True

# Module-level logger.
logger = logging.getLogger(__name__)

# Sentinel for missing parameters.
_MISSING = object()
def _warn_for_base() -> None:
    """Emit a ``RayDeprecationWarning`` about the obsolete ``base`` argument.

    ``stacklevel=2`` attributes the warning to the function that called this
    helper rather than to the helper itself.
    """
    message = (
        "The `base` argument is deprecated. "
        "Please remove it as it is not actually needed in this method."
    )
    warnings.warn(message, category=RayDeprecationWarning, stacklevel=2)
class _BackwardsCompatibleNumpyRng:
    """Adapter hiding the API differences between numpy's legacy
    ``RandomState`` and the newer ``Generator`` random number generators.

    Unknown attributes are forwarded to the wrapped generator. When the
    wrapped generator is a legacy one, the new-style method names
    ``integers`` and ``random`` are translated to ``randint`` and ``rand``.
    """

    # Class-level default so attribute lookup never falls through to
    # ``__getattr__`` for ``_rng`` itself.
    _rng = None

    def __init__(
        self,
        generator_or_seed: Optional[
            Union["np_random_generator", np.random.RandomState, int]
        ] = None,
    ):
        """Wrap an existing generator or build a new one from a seed.

        Args:
            generator_or_seed: ``None`` (use the global ``np.random`` module),
                an existing ``RandomState``/``Generator`` instance to wrap
                as-is, or any other value, which is passed as the seed to a
                newly created generator.
        """
        usable_as_is = generator_or_seed is None or isinstance(
            generator_or_seed, (np.random.RandomState, np_random_generator)
        )
        if usable_as_is:
            self._rng = generator_or_seed
            return
        # Only the selected constructor is looked up, so old numpy versions
        # never touch the missing ``default_rng`` attribute.
        factory = np.random.RandomState if LEGACY_RNG else np.random.default_rng
        self._rng = factory(generator_or_seed)

    @property
    def legacy_rng(self) -> bool:
        """Whether the wrapped object is anything but a new-style generator."""
        is_new_style = isinstance(self._rng, np_random_generator)
        return not is_new_style

    @property
    def rng(self):
        """The wrapped generator, or the ``np.random`` module if none is set."""
        # don't set self._rng to np.random to avoid pickling issues
        if self._rng is None:
            return np.random
        return self._rng

    def __getattr__(self, name: str) -> Any:
        """Forward attribute access to the wrapped generator."""
        # https://numpy.org/doc/stable/reference/random/new-or-different.html
        if self.legacy_rng:
            legacy_names = {"integers": "randint", "random": "rand"}
            name = legacy_names.get(name, name)
        return getattr(self.rng, name)
# Type alias for the values used as ``random_state`` arguments in this module:
# ``None``, an already wrapped generator, a raw numpy generator (new-style or
# legacy ``RandomState``), or an ``int``.
RandomState = Union[
    None, _BackwardsCompatibleNumpyRng, np_random_generator, np.random.RandomState, int
]
@DeveloperAPI
class Domain:
    """Base class to specify a type and valid range to sample parameters from.

    This base class is implemented by parameter spaces, like float ranges
    (``Float``), integer ranges (``Integer``), or categorical variables
    (``Categorical``). The ``Domain`` object contains information about
    valid values (e.g. minimum and maximum values), and exposes methods that
    allow specification of specific samplers (e.g. ``uniform()`` or
    ``loguniform()``).
    """

    # Sampler explicitly attached to this domain, if any.
    sampler = None
    # Sampler class instantiated on demand when no sampler is attached.
    default_sampler_cls = None

    def cast(self, value):
        """Cast value to domain type (the base class returns it unchanged)."""
        return value

    def set_sampler(self, sampler, allow_override=False):
        """Attach ``sampler`` to this domain.

        Args:
            sampler: The sampler to attach.
            allow_override: If False, refuse to replace an existing sampler.

        Raises:
            ValueError: If a sampler is already set and ``allow_override``
                is False.
        """
        replacing_forbidden = self.sampler and not allow_override
        if replacing_forbidden:
            template = (
                "You can only choose one sampler for parameter "
                "domains. Existing sampler for parameter {}: "
                "{}. Tried to add {}"
            )
            raise ValueError(
                template.format(self.__class__.__name__, self.sampler, sampler)
            )
        self.sampler = sampler

    def get_sampler(self):
        """Return the attached sampler, or a fresh default sampler instance."""
        return self.sampler or self.default_sampler_cls()

    def sample(
        self,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Draw sample(s) from this domain using its sampler.

        Args:
            config: Config dict (or list of dicts) forwarded to the sampler.
            size: Number of samples requested from the sampler.
            random_state: Generator, seed or ``None``; wrapped in a
                ``_BackwardsCompatibleNumpyRng`` unless it already is one.

        Returns:
            Whatever the sampler's ``sample()`` returns.
        """
        if isinstance(random_state, _BackwardsCompatibleNumpyRng):
            rng = random_state
        else:
            rng = _BackwardsCompatibleNumpyRng(random_state)
        return self.get_sampler().sample(
            self, config=config, size=size, random_state=rng
        )

    def is_grid(self):
        """Return True if the attached sampler is a ``Grid`` sampler."""
        return isinstance(self.sampler, Grid)

    def is_function(self):
        """Return True for function domains; the base class returns False."""
        return False

    def is_valid(self, value: Any):
        """Returns True if `value` is a valid value in this domain."""
        raise NotImplementedError

    @property
    def domain_str(self):
        """Human-readable description of the domain's range."""
        return "(unknown)"
@DeveloperAPI
class Sampler:
    """Base class for samplers; subclasses provide ``sample()``."""
    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Draw sample(s) from ``domain``.

        Not implemented in the base class.

        Args:
            domain: The domain to sample from.
            config: Optional config dict or list of config dicts.
            size: Number of samples requested.
            random_state: Generator, seed or ``None``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
@DeveloperAPI
class BaseSampler(Sampler):
    """Sampler base class whose string representation is ``"Base"``."""
    def __str__(self):
        return "Base"
@DeveloperAPI
class Uniform(Sampler):
    """Base class for uniform samplers.

    It does not implement ``sample()`` itself; the domain classes in this
    module subclass it (e.g. ``Float._Uniform``) to provide the sampling.
    """
    def __str__(self):
        return "Uniform"
@DeveloperAPI
class LogUniform(Sampler):
    """Base class for log-uniform samplers.

    It does not implement ``sample()`` itself; the domain classes in this
    module subclass it (e.g. ``Float._LogUniform``) to provide the sampling.
    """
    def __init__(self, base: object = _MISSING):
        # ``base`` is never stored or used; passing any value only triggers
        # the deprecation warning.
        if base is not _MISSING:
            _warn_for_base()
    def __str__(self):
        return "LogUniform"
@DeveloperAPI
class Normal(Sampler):
    """Base class for samplers drawing from a normal distribution.

    Args:
        mean: Mean of the normal distribution. Defaults to 0.
        sd: Standard deviation of the normal distribution. Has to be strictly
            greater than 0. Defaults to 1.

    Raises:
        AssertionError: If ``sd`` is not strictly greater than 0.
    """

    def __init__(self, mean: float = 0.0, sd: float = 1.0):
        # The default for ``sd`` used to be 0.0, which could never satisfy the
        # check below, so ``Normal()`` without arguments always failed. 1.0
        # matches the defaults of ``Float.normal()`` and ``randn()``.
        self.mean = mean
        self.sd = sd
        assert self.sd > 0, "SD has to be strictly greater than 0"

    def __str__(self):
        return "Normal"
@DeveloperAPI
class Grid(Sampler):
    """Dummy sampler used for grid search"""

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Reject sampling: grid search values are never drawn at random.

        Raises:
            RuntimeError: Always.
        """
        # The error used to be *returned*, which handed the exception object
        # back to the caller as if it were a sampled value.
        raise RuntimeError("Do not call `sample()` on grid.")
@DeveloperAPI
class Float(Domain):
    """Floating point parameter domain bounded by ``lower`` and ``upper``.

    A bound passed as ``None`` is stored as ``-inf`` (lower) or ``inf``
    (upper). Samples are drawn uniformly unless another sampler is chosen via
    ``uniform()``, ``loguniform()``, ``normal()`` or ``quantized()``.
    """

    class _Uniform(Uniform):
        """Draws values uniformly between the domain bounds."""

        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > float("-inf"), "Uniform needs a lower bound"
            assert domain.upper < float("inf"), "Uniform needs a upper bound"
            items = random_state.uniform(domain.lower, domain.upper, size=size)
            # A single sample is unwrapped and cast to a plain float.
            return items if len(items) > 1 else domain.cast(items[0])

    class _LogUniform(LogUniform):
        """Draws values whose logarithm is uniform between the log-bounds."""

        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
            assert (
                0 < domain.upper < float("inf")
            ), "LogUniform needs a upper bound greater than 0"
            logmin = np.log(domain.lower)
            logmax = np.log(domain.upper)
            # Sample uniformly in log-space, then map back.
            items = np.exp(random_state.uniform(logmin, logmax, size=size))
            return items if len(items) > 1 else domain.cast(items[0])

    class _Normal(Normal):
        """Draws values from a normal distribution with ``mean`` and ``sd``."""

        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # A falsy bound (e.g. 0) is accepted here in addition to +-inf.
            assert not domain.lower or domain.lower == float(
                "-inf"
            ), "Normal sampling does not allow a lower value bound."
            assert not domain.upper or domain.upper == float(
                "inf"
            ), "Normal sampling does not allow a upper value bound."
            items = random_state.normal(self.mean, self.sd, size=size)
            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, lower: Optional[float], upper: Optional[float]):
        """
        Args:
            lower: Lower bound, or ``None`` for no lower bound (``-inf``).
            upper: Upper bound, or ``None`` for no upper bound (``inf``).
        """
        # Need to explicitly check for None
        self.lower = lower if lower is not None else float("-inf")
        self.upper = upper if upper is not None else float("inf")

    def cast(self, value):
        """Convert ``value`` to a Python float."""
        return float(value)

    def uniform(self):
        """Return a copy of this domain that samples uniformly.

        Raises:
            ValueError: If either bound is missing (infinite), or if this
                domain already has a sampler.
        """
        if not self.lower > float("-inf"):
            raise ValueError(
                "Uniform requires a lower bound. Make sure to set the "
                "`lower` parameter of `Float()`."
            )
        if not self.upper < float("inf"):
            raise ValueError(
                "Uniform requires a upper bound. Make sure to set the "
                "`upper` parameter of `Float()`."
            )
        new = copy(self)
        new.set_sampler(self._Uniform())
        return new

    def loguniform(self, base: object = _MISSING):
        """Return a copy of this domain that samples log-uniformly.

        Args:
            base: Deprecated and unused; passing it only emits a warning.

        Raises:
            ValueError: If the lower bound is not greater than 0, if the upper
                bound is not a finite value greater than 0, or if this domain
                already has a sampler.
        """
        if base is not _MISSING:
            _warn_for_base()
        if not self.lower > 0:
            raise ValueError(
                "LogUniform requires a lower bound greater than 0. "
                f"Got: {self.lower}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        if not 0 < self.upper < float("inf"):
            # Report the offending *upper* bound (this used to print the
            # lower bound by mistake).
            raise ValueError(
                "LogUniform requires a upper bound greater than 0. "
                f"Got: {self.upper}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        new = copy(self)
        new.set_sampler(self._LogUniform())
        return new

    def normal(self, mean=0.0, sd=1.0):
        """Return a copy of this domain that samples from a normal distribution.

        Args:
            mean: Mean of the normal distribution.
            sd: Standard deviation of the normal distribution.
        """
        new = copy(self)
        new.set_sampler(self._Normal(mean, sd))
        return new

    def quantized(self, q: float):
        """Return a copy whose samples are rounded to multiples of ``q``.

        The current sampler (or the default one) is wrapped in a
        ``Quantized`` sampler.

        Raises:
            ValueError: If a finite bound is not (approximately) divisible
                by ``q``.
        """
        if self.lower > float("-inf") and not isclose(
            self.lower / q, round(self.lower / q)
        ):
            raise ValueError(
                f"Your lower variable bound {self.lower} is not divisible by "
                f"quantization factor {q}."
            )
        if self.upper < float("inf") and not isclose(
            self.upper / q, round(self.upper / q)
        ):
            raise ValueError(
                f"Your upper variable bound {self.upper} is not divisible by "
                f"quantization factor {q}."
            )
        new = copy(self)
        # Wrapping replaces the existing sampler, hence the explicit override.
        new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
        return new

    def is_valid(self, value: float):
        """Return True if ``value`` lies within ``[lower, upper]``."""
        return self.lower <= value <= self.upper

    @property
    def domain_str(self):
        """The bounds formatted as ``(lower, upper)``."""
        return f"({self.lower}, {self.upper})"
@DeveloperAPI
class Integer(Domain):
    """Integer parameter domain bounded by ``lower`` and ``upper``.

    Samples are drawn uniformly unless another sampler is chosen via
    ``uniform()``, ``loguniform()`` or ``quantized()``.
    """

    class _Uniform(Uniform):
        """Draws integers uniformly via the generator's ``integers`` method."""

        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            items = random_state.integers(domain.lower, domain.upper, size=size)
            # A single sample is unwrapped and cast to a plain int.
            return items if len(items) > 1 else domain.cast(items[0])

    class _LogUniform(LogUniform):
        """Draws integers whose logarithm is uniform between the log-bounds."""

        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
            assert (
                0 < domain.upper < float("inf")
            ), "LogUniform needs a upper bound greater than 0"
            logmin = np.log(domain.lower)
            logmax = np.log(domain.upper)
            # Sample uniformly in log-space, map back, then floor to integers.
            items = np.exp(random_state.uniform(logmin, logmax, size=size))
            items = np.floor(items).astype(int)
            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, lower, upper):
        """
        Args:
            lower: Lower bound of the domain.
            upper: Upper bound of the domain.
        """
        self.lower = lower
        self.upper = upper

    def cast(self, value):
        """Convert ``value`` to a Python int."""
        return int(value)

    def quantized(self, q: int):
        """Return a copy whose samples are rounded to multiples of ``q``.

        The current sampler (or the default one) is wrapped in a
        ``Quantized`` sampler.
        """
        new = copy(self)
        # Wrapping replaces the existing sampler, hence the explicit override.
        new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
        return new

    def uniform(self):
        """Return a copy of this domain that samples uniformly.

        Raises:
            ValueError: If this domain already has a sampler.
        """
        new = copy(self)
        new.set_sampler(self._Uniform())
        return new

    def loguniform(self, base: object = _MISSING):
        """Return a copy of this domain that samples log-uniformly.

        Args:
            base: Deprecated and unused; passing it only emits a warning.

        Raises:
            ValueError: If the lower bound is not greater than 0, if the upper
                bound is not a finite value greater than 0, or if this domain
                already has a sampler.
        """
        if base is not _MISSING:
            _warn_for_base()
        if not self.lower > 0:
            raise ValueError(
                "LogUniform requires a lower bound greater than 0. "
                f"Got: {self.lower}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        if not 0 < self.upper < float("inf"):
            # Report the offending *upper* bound (this used to print the
            # lower bound by mistake).
            raise ValueError(
                "LogUniform requires a upper bound greater than 0. "
                f"Got: {self.upper}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        new = copy(self)
        new.set_sampler(self._LogUniform())
        return new

    def is_valid(self, value: int):
        """Return True if ``value`` lies within ``[lower, upper]``."""
        return self.lower <= value <= self.upper

    @property
    def domain_str(self):
        """The bounds formatted as ``(lower, upper)``."""
        return f"({self.lower}, {self.upper})"
@DeveloperAPI
class Categorical(Domain):
    """Parameter domain consisting of a fixed list of categories."""

    class _Uniform(Uniform):
        """Picks categories uniformly at random."""

        def sample(
            self,
            domain: "Categorical",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            options = domain.categories
            # Draw positions rather than calling .choice() on the categories
            # themselves, which would coerce them to a single dtype.
            positions = np.arange(len(options))
            picked = random_state.choice(positions, size=size)
            items = [options[pos] for pos in picked]
            if len(items) > 1:
                return items
            return domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, categories: Sequence):
        """
        Args:
            categories: The allowed values; stored as a new list.
        """
        self.categories = list(categories)

    def uniform(self):
        """Return a copy of this domain that samples categories uniformly."""
        clone = copy(self)
        clone.set_sampler(self._Uniform())
        return clone

    def grid(self):
        """Return a copy of this domain marked for grid search."""
        clone = copy(self)
        clone.set_sampler(Grid())
        return clone

    def __len__(self):
        """Number of categories."""
        return len(self.categories)

    def __getitem__(self, item):
        """Index into the list of categories."""
        return self.categories[item]

    def is_valid(self, value: Any):
        """Return True if ``value`` is one of the categories."""
        return value in self.categories

    @property
    def domain_str(self):
        """The list of categories formatted as a string."""
        return f"{self.categories}"
@DeveloperAPI
class Function(Domain):
    """Domain whose values are produced by calling a user-provided function.

    The function must be callable with either one positional argument (the
    config) or with no arguments at all.
    """

    class _CallSampler(BaseSampler):
        """Sampler that obtains values by calling the domain's function."""

        def __try_fn(self, domain: "Function", config: Dict[str, Any]):
            """Call ``domain.func`` with ``config``.

            If that raises ``AttributeError`` or ``KeyError``, the function is
            called again with the config wrapped under a ``"config"`` key in
            an ``_UnresolvedAccessGuard``, and a deprecation warning is logged.
            """
            try:
                return domain.func(config)
            except (AttributeError, KeyError):
                from ray.tune.search.variant_generator import _UnresolvedAccessGuard

                spec = _UnresolvedAccessGuard({"config": config})
                result = domain.func(spec)
                logger.warning(
                    "sample_from functions that take a spec dict are "
                    "deprecated. Please update your function to work with "
                    "the config dict directly."
                )
                return result

        def sample(
            self,
            domain: "Function",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            """Call the domain's function ``size`` times and collect the results.

            When ``config`` is a list, the i-th call receives ``config[i]``;
            otherwise every call receives ``config`` itself. A single result
            is returned unwrapped.
            """
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            if not domain.pass_config:
                items = [domain.func() for _ in range(size)]
            elif isinstance(config, list):
                items = [self.__try_fn(domain, config[i]) for i in range(size)]
            else:
                items = [self.__try_fn(domain, config) for _ in range(size)]
            if len(items) > 1:
                return items
            return domain.cast(items[0])

    default_sampler_cls = _CallSampler

    def __init__(self, func: Callable):
        """
        Args:
            func: Callable accepting either a single positional argument
                or no arguments.

        Raises:
            ValueError: If ``func`` can be called neither with one positional
                argument nor without arguments.
        """
        sig = signature(func)
        # Whether we should pass `config` when calling `func`.
        try:
            sig.bind({})
        except TypeError:
            pass_config = False
        else:
            pass_config = True
        if not pass_config:
            try:
                sig.bind()
            except TypeError as exc:
                raise ValueError(
                    "The function passed to a `Function` parameter must be "
                    "callable with either 0 or 1 parameters."
                ) from exc
        self.pass_config = pass_config
        self.func = func

    def is_function(self):
        """Always True for function domains."""
        return True

    def is_valid(self, value: Any):
        """Always True: the function is user-defined, so nothing is assumed."""
        return True

    @property
    def domain_str(self):
        """The function's representation followed by ``()``."""
        return f"{self.func}()"
@DeveloperAPI
class Quantized(Sampler):
    """Sampler wrapper that rounds sampled values to multiples of ``q``."""

    def __init__(self, sampler: Sampler, q: Union[float, int]):
        """
        Args:
            sampler: The sampler instance whose samples get quantized.
            q: The quantization increment.
        """
        self.sampler = sampler
        self.q = q
        assert self.sampler, "Quantized() expects a sampler instance"

    def get_sampler(self):
        """Return the wrapped sampler."""
        return self.sampler

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Sample from the wrapped sampler and round to multiples of ``q``.

        With ``q == 1`` the wrapped sampler's result is returned untouched.
        Otherwise the wrapped sampler draws from a copy of ``domain`` whose
        bounds are moved inwards to the nearest multiples of ``q``.

        Returns:
            A list of quantized values if the rounded result is a numpy
            array, otherwise a single value cast to the domain type.
        """
        if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
            random_state = _BackwardsCompatibleNumpyRng(random_state)
        inner = self.sampler
        step = self.q
        if step == 1:
            return inner.sample(domain, config, size, random_state=random_state)
        # Work on a copy so the caller's domain keeps its original bounds.
        snapped_domain = copy(domain)
        snapped_domain.lower = np.ceil(domain.lower / step) * step
        snapped_domain.upper = np.floor(domain.upper / step) * step
        raw = inner.sample(snapped_domain, config, size, random_state=random_state)
        quantized = np.round(np.divide(raw, step)) * step
        if isinstance(quantized, np.ndarray):
            return list(quantized)
        return domain.cast(quantized)
@PublicAPI
def sample_from(func: Callable[[Dict], Any]):
    """Specify that tune should sample configuration values from this function.

    Arguments:
        func: A callable to draw a sample from.

    Returns:
        A ``Function`` domain wrapping ``func``.
    """
    return Function(func)
@PublicAPI
def choice(categories: Sequence):
    """Sample a categorical value.

    Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
    ``np.random.choice([1, 2])``

    Args:
        categories: The values to choose from.

    Returns:
        A ``Categorical`` domain with a uniform sampler.
    """
    return Categorical(categories).uniform()
@PublicAPI
def randint(lower: int, upper: int):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is exclusive.

    Sampling from ``tune.randint(0, 10)`` is equivalent to sampling from
    ``np.random.randint(0, 10)``

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    Args:
        lower: Lower bound (inclusive).
        upper: Upper bound (exclusive).

    Returns:
        An ``Integer`` domain with a uniform sampler.
    """
    return Integer(lower, upper).uniform()
@PublicAPI
def lograndint(lower: int, upper: int, base: object = _MISSING):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is exclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    Args:
        lower: Lower bound (inclusive).
        upper: Upper bound (exclusive).
        base: Deprecated and unused; passing it only emits a warning.

    Returns:
        An ``Integer`` domain with a log-uniform sampler.
    """
    if base is not _MISSING:
        _warn_for_base()
    return Integer(lower, upper).loguniform()
@PublicAPI
def qrandint(lower: int, upper: int, q: int = 1):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    Args:
        lower: Lower bound (inclusive).
        upper: Upper bound (inclusive).
        q: Quantization number. Defaults to 1.

    Returns:
        An ``Integer`` domain with a quantized uniform sampler.
    """
    return Integer(lower, upper).uniform().quantized(q)
@PublicAPI
def qlograndint(lower: int, upper: int, q: int, base: object = _MISSING):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.

    Args:
        lower: Lower bound (inclusive).
        upper: Upper bound (inclusive).
        q: Quantization number.
        base: Deprecated and unused; passing it only emits a warning.

    Returns:
        An ``Integer`` domain with a quantized log-uniform sampler.
    """
    if base is not _MISSING:
        _warn_for_base()
    return Integer(lower, upper).loguniform().quantized(q)
@PublicAPI
def randn(mean: float = 0.0, sd: float = 1.0):
    """Sample a float value normally with ``mean`` and ``sd``.

    Args:
        mean: Mean of the normal distribution. Defaults to 0.
        sd: SD of the normal distribution. Defaults to 1.

    Returns:
        An unbounded ``Float`` domain with a normal sampler.
    """
    return Float(None, None).normal(mean, sd)
@PublicAPI
def qrandn(mean: float, sd: float, q: float):
    """Sample a float value normally with ``mean`` and ``sd``.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.

    Args:
        mean: Mean of the normal distribution.
        sd: SD of the normal distribution.
        q: Quantization number. The result will be rounded to an
            integer increment of this value.

    Returns:
        An unbounded ``Float`` domain with a quantized normal sampler.
    """
    return Float(None, None).normal(mean, sd).quantized(q)