Source code for ray.tune.suggest.basic_variant

import copy
import glob
import itertools
import os
import uuid
from typing import Dict, List, Optional, Union
import warnings

from ray.tune.error import TuneError
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.variant_generator import (
    count_variants, count_spec_samples, generate_variants, format_vars,
    flatten_resolved_vars, get_preset_variants)
from import SearchAlgorithm
from ray.tune.utils.util import atomic_save, load_newest_checkpoint


class _VariantIterator:
    """Iterates over generated variants from the search space.

    This object also toggles between lazy evaluation and
    eager evaluation of samples. If lazy evaluation is enabled,
    this object cannot be serialized.

    def __init__(self, iterable, lazy_eval=False):
        self.lazy_eval = lazy_eval
        self.iterable = iterable
        self._has_next = True
        if lazy_eval:
            self.iterable = list(iterable)
            self._has_next = bool(self.iterable)

    def _load_value(self):
            self.next_value = next(self.iterable)
        except StopIteration:
            self._has_next = False

    def has_next(self):
        return self._has_next

    def __next__(self):
        if self.lazy_eval:
            current_value = self.next_value
            return current_value
        current_value = self.iterable.pop(0)
        self._has_next = bool(self.iterable)
        return current_value

class _TrialIterator:
    """Generates trials from the spec.

        uuid_prefix (str): Used in creating the trial name.
        num_samples (int): Number of samples from distribution
             (same as
        unresolved_spec (dict): Experiment specification
            that might have unresolved distributions.
        output_path (str): A specific output path within the local_dir.
        points_to_evaluate (list): Same as
        lazy_eval (bool): Whether variants should be generated
            lazily or eagerly. This is toggled depending
            on the size of the grid search.
        start (int): index at which to start counting trials.

    def __init__(self,
                 uuid_prefix: str,
                 num_samples: int,
                 unresolved_spec: dict,
                 output_path: str = "",
                 points_to_evaluate: Optional[List] = None,
                 lazy_eval: bool = False,
                 start: int = 0):
        self.parser = make_parser()
        self.num_samples = num_samples
        self.uuid_prefix = uuid_prefix
        self.num_samples_left = num_samples
        self.unresolved_spec = unresolved_spec
        self.output_path = output_path
        self.points_to_evaluate = points_to_evaluate or []
        self.num_points_to_evaluate = len(self.points_to_evaluate)
        self.counter = start
        self.lazy_eval = lazy_eval
        self.variants = None

    def create_trial(self, resolved_vars, spec):
        trial_id = self.uuid_prefix + ("%05d" % self.counter)
        experiment_tag = str(self.counter)
        # Always append resolved vars to experiment tag?
        if resolved_vars:
            experiment_tag += "_{}".format(format_vars(resolved_vars))
        self.counter += 1
        return create_trial_from_spec(

    def __next__(self):
        """Generates Trial objects with the variant generation process.

        Uses a fixed point iteration to resolve variants. All trials
        should be able to be generated at once.

        See also: `ray.tune.suggest.variant_generator`.

            Trial object

        if "run" not in self.unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(

        if self.variants and self.variants.has_next():
            # This block will be skipped upon instantiation.
            # `variants` will be set later after the first loop.
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)

        if self.points_to_evaluate:
            config = self.points_to_evaluate.pop(0)
            self.num_samples_left -= 1
            self.variants = _VariantIterator(
                get_preset_variants(self.unresolved_spec, config),
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)
        elif self.num_samples_left > 0:
            self.variants = _VariantIterator(
            self.num_samples_left -= 1
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)
            raise StopIteration

    def __iter__(self):
        return self

[docs]class BasicVariantGenerator(SearchAlgorithm): """Uses Tune's variant generation for resolving variables. This is the default search algorithm used if no other search algorithm is specified. Args: points_to_evaluate (list): Initial parameter suggestions to be run first. This is for when you already have some good parameters you want to run first to help the algorithm make better suggestions for future parameters. Needs to be a list of dicts containing the configurations. Example: .. code-block:: python from ray import tune # This will automatically use the `BasicVariantGenerator` lambda config: config["a"] + config["b"], config={ "a": tune.grid_search([1, 2]), "b": tune.randint(0, 3) }, num_samples=4) In the example above, 8 trials will be generated: For each sample (``4``), each of the grid search variants for ``a`` will be sampled once. The ``b`` parameter will be sampled randomly. The generator accepts a pre-set list of points that should be evaluated. The points will replace the first samples of each experiment passed to the ``BasicVariantGenerator``. Each point will replace one sample of the specified ``num_samples``. If grid search variables are overwritten with the values specified in the presets, the number of samples will thus be reduced. Example: .. code-block:: python from ray import tune from ray.tune.suggest.basic_variant import BasicVariantGenerator lambda config: config["a"] + config["b"], config={ "a": tune.grid_search([1, 2]), "b": tune.randint(0, 3) }, search_alg=BasicVariantGenerator(points_to_evaluate=[ {"a": 2, "b": 2}, {"a": 1}, {"b": 2} ]), num_samples=4) The example above will produce six trials via four samples: - The first sample will produce one trial with ``a=2`` and ``b=2``. - The second sample will produce one trial with ``a=1`` and ``b`` sampled randomly - The third sample will produce two trials, one for each grid search value of ``a``. It will be ``b=2`` for both of these trials. - The fourth sample will produce two trials, one for each grid search value of ``a``. ``b`` will be sampled randomly and independently for both of these trials. """ CKPT_FILE_TMPL = "basic-variant-state-{}.json" def __init__(self, points_to_evaluate: Optional[List[Dict]] = None): self._trial_generator = [] self._iterators = [] self._trial_iter = None self._finished = False self._points_to_evaluate = points_to_evaluate or [] # Unique prefix for all trials generated, e.g., trial ids start as # 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing. force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID") if force_test_uuid: self._uuid_prefix = force_test_uuid + "_" else: self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_" self._total_samples = 0 @property def total_samples(self): return self._total_samples def add_configurations( self, experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]): """Chains generator given experiment specifications. Arguments: experiments (Experiment | list | dict): Experiments to run. """ experiment_list = convert_to_experiment_list(experiments) for experiment in experiment_list: grid_vals = count_spec_samples(experiment.spec, num_samples=1) lazy_eval = grid_vals > SERIALIZATION_THRESHOLD if lazy_eval: warnings.warn( f"The number of pre-generated samples ({grid_vals}) " "exceeds the serialization threshold " f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is " "disabled. To fix this, reduce the number of " "dimensions/size of the provided grid search.") previous_samples = self._total_samples points_to_evaluate = copy.deepcopy(self._points_to_evaluate) self._total_samples += count_variants(experiment.spec, points_to_evaluate) iterator = _TrialIterator( uuid_prefix=self._uuid_prefix, num_samples=experiment.spec.get("num_samples", 1), unresolved_spec=experiment.spec, output_path=experiment.dir_name, points_to_evaluate=points_to_evaluate, lazy_eval=lazy_eval, start=previous_samples) self._iterators.append(iterator) self._trial_generator = itertools.chain(self._trial_generator, iterator) def next_trial(self): """Provides one Trial object to be queued into the TrialRunner. Returns: Trial: Returns a single trial. """ if not self._trial_iter: self._trial_iter = iter(self._trial_generator) try: return next(self._trial_iter) except StopIteration: self._trial_generator = [] self._trial_iter = None self.set_finished() return None def get_state(self): if any(iterator.lazy_eval for iterator in self._iterators): return False state = self.__dict__.copy() del state["_trial_generator"] return state def set_state(self, state): self.__dict__.update(state) for iterator in self._iterators: self._trial_generator = itertools.chain(self._trial_generator, iterator) def save_to_dir(self, dirpath, session_str): if any(iterator.lazy_eval for iterator in self._iterators): return False state_dict = self.get_state() atomic_save( state=state_dict, checkpoint_dir=dirpath, file_name=self.CKPT_FILE_TMPL.format(session_str), tmp_file_name=".tmp_generator") def has_checkpoint(self, dirpath: str): """Whether a checkpoint file exists within dirpath.""" return bool( glob.glob(os.path.join(dirpath, self.CKPT_FILE_TMPL.format("*")))) def restore_from_dir(self, dirpath: str): """Restores self + searcher + search wrappers from dirpath.""" state_dict = load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")) if not state_dict: raise RuntimeError( "Unable to find checkpoint in {}.".format(dirpath)) self.set_state(state_dict)