# Source code for ray.tune.search.concurrency_limiter

import copy
import logging
from typing import Dict, Optional, List

from ray.tune.search.searcher import Searcher
from ray.tune.search.util import _set_search_properties_backwards_compatible
from ray.util.annotations import PublicAPI


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Raises:
        AssertionError: If ``max_concurrent`` is not a positive ``int``.
        RuntimeError: If ``searcher`` is not a ``Searcher`` instance.

    Example:

    .. code-block:: python

        from ray.tune.search import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()
    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        # Exact type check: ``type(True) is int`` is False, so bools are
        # rejected here as well.
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # IDs of trials that were handed a suggestion and are not yet released.
        self.live_trials = set()
        # Number of live trials that have not reported completion yet.
        self.num_unfinished_live_trials = 0
        # Batch mode only: trial_id -> (result, error) awaiting the flush.
        self.cached_results = {}
        self._limit_concurrency = True

        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                "The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                "`max_concurrent` to the search generator directly."
            )

        self._set_searcher_max_concurrency()

        super().__init__(metric=self.searcher.metric, mode=self.searcher.mode)

    def _set_searcher_max_concurrency(self):
        """Push ``max_concurrent`` to the wrapped searcher.

        ``_limit_concurrency`` ends up ``False`` when the searcher's
        ``set_max_concurrency`` returns a truthy value.
        """
        # If the searcher has special logic for handling max concurrency,
        # we do not do anything inside the ConcurrencyLimiter
        self._limit_concurrency = not self.searcher.set_max_concurrency(
            self.max_concurrent
        )

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        """Store a new ``max_concurrent`` on this wrapper and return ``True``.

        The value is only recorded here; this method does not forward it to
        the wrapped searcher.
        """
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Re-apply the concurrency limit, then delegate to the searcher.

        Returns:
            Whatever ``_set_search_properties_backwards_compatible`` returns
            for the wrapped searcher's ``set_search_properties``.
        """
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Ask the wrapped searcher for a suggestion, honouring the limit.

        Returns:
            The searcher's suggestion, or ``None`` when ``max_concurrent``
            trials are already live.

        Raises:
            AssertionError: If ``trial_id`` is already a live trial.
        """
        if not self._limit_concurrency:
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."

        if len(self.live_trials) >= self.max_concurrent:
            # Pass every value as a lazy %-argument: interpolating the trial
            # id into the format string itself would break on IDs with "%".
            logger.debug(
                "Not providing a suggestion for %s due to "
                "concurrency limit: %s/%s.",
                trial_id,
                len(self.live_trials),
                self.max_concurrent,
            )
            return None

        suggestion = self.searcher.suggest(trial_id)
        # Only real suggestions occupy a concurrency slot.
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Report a finished trial to the wrapped searcher.

        Without batching the completion is forwarded immediately and the
        trial's slot is released. With ``batch=True`` completions are cached
        and forwarded together once no live trial is left unfinished.
        Completions for trials that are not live are ignored.
        """
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(trial_id, result=result, error=error)

        if trial_id not in self.live_trials:
            return

        if not self.batch:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1
            return

        self.cached_results[trial_id] = (result, error)
        self.num_unfinished_live_trials -= 1
        if self.num_unfinished_live_trials > 0:
            # Batch still has trials running; keep the result cached.
            return

        # Update the underlying searcher once the
        # full batch is completed.
        for cached_id, (cached_result, cached_error) in self.cached_results.items():
            self.searcher.on_trial_complete(
                cached_id, result=cached_result, error=cached_error
            )
            self.live_trials.remove(cached_id)
        self.cached_results = {}
        self.num_unfinished_live_trials = 0

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        """Forward an intermediate result to the wrapped searcher."""
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        """Forward an already evaluated point to the wrapped searcher."""
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        """Return a deep copy of this wrapper's attributes, minus ``searcher``."""
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        """Overwrite this wrapper's attributes with the entries of ``state``."""
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate saving to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoring to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        """Forward a pause notification to the wrapped searcher."""
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        """Forward an unpause notification to the wrapped searcher."""
        self.searcher.on_unpause(trial_id)