import copy
import logging
from typing import Dict, List, Optional
from ray.tune.search.searcher import Searcher
from ray.tune.search.util import _set_search_properties_backwards_compatible
from ray.util.annotations import PublicAPI
logger = logging.getLogger(__name__)
@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Example:

    .. code-block:: python

        from ray.tune.search import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()
    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        """Wrap ``searcher`` and cap the number of live suggestions.

        Raises:
            AssertionError: If ``max_concurrent`` is not a positive ``int``.
            RuntimeError: If ``searcher`` is not a ``Searcher`` instance.
        """
        # Kept as an assertion so callers relying on AssertionError are
        # unaffected. The exact type check also rejects bools and other
        # int subclasses.
        assert type(max_concurrent) is int and max_concurrent > 0, (
            "`max_concurrent` must be a positive int, "
            f"got {max_concurrent!r}."
        )
        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                f"The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                f"`max_concurrent` to the search generator directly."
            )

        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # IDs of trials that received a suggestion and have not yet been
        # reported as complete to the wrapped searcher.
        self.live_trials = set()
        # Number of live trials that have not yet called back with completion.
        self.num_unfinished_live_trials = 0
        # Batch mode only: trial_id -> (result, error) held back until the
        # whole batch has completed.
        self.cached_results = {}
        # Whether this wrapper enforces the limit itself; may be switched
        # off by ``_set_searcher_max_concurrency`` below.
        self._limit_concurrency = True

        self._set_searcher_max_concurrency()
        super().__init__(metric=self.searcher.metric, mode=self.searcher.mode)

    def _set_searcher_max_concurrency(self):
        """Push ``max_concurrent`` to the wrapped searcher.

        If the wrapped searcher's ``set_max_concurrency`` returns a truthy
        value, this wrapper stops limiting concurrency itself and simply
        forwards calls.
        """
        # If the searcher has special logic for handling max concurrency,
        # we do not do anything inside the ConcurrencyLimiter
        handled_by_searcher = self.searcher.set_max_concurrency(self.max_concurrent)
        self._limit_concurrency = not handled_by_searcher

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        """Update the concurrency cap of this limiter.

        The new value is only stored here; it is not forwarded to the
        wrapped searcher by this method.

        Returns:
            Always ``True``.
        """
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Re-apply the concurrency cap, then delegate to the wrapped searcher.

        Returns:
            Whatever the backwards-compatible call to the wrapped searcher's
            ``set_search_properties`` returns.
        """
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Ask the wrapped searcher for a suggestion, honoring the limit.

        Returns:
            The wrapped searcher's suggestion, or ``None`` when the number of
            live trials has already reached ``max_concurrent``.
        """
        if not self._limit_concurrency:
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."

        if len(self.live_trials) >= self.max_concurrent:
            # Pass every value as a lazy %-argument: interpolating trial_id
            # into the format string would break on IDs containing '%'.
            logger.debug(
                "Not providing a suggestion for %s due to "
                "concurrency limit: %s/%s.",
                trial_id,
                len(self.live_trials),
                self.max_concurrent,
            )
            return None

        suggestion = self.searcher.suggest(trial_id)
        # Only real suggestions occupy a concurrency slot.
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Report a finished trial to the wrapped searcher.

        When this wrapper is not limiting concurrency, the call is forwarded
        as-is. Otherwise completions of unknown trials are ignored. In batch
        mode, completions are cached and forwarded together once every live
        trial has finished; in non-batch mode they are forwarded immediately.
        """
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(
                trial_id, result=result, error=error
            )

        if trial_id not in self.live_trials:
            return

        if not self.batch:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1
            return

        # Batch mode. Count each trial once, even if its completion is
        # reported again before the batch is flushed; the latest result wins.
        if trial_id not in self.cached_results:
            self.num_unfinished_live_trials -= 1
        self.cached_results[trial_id] = (result, error)

        if self.num_unfinished_live_trials > 0:
            return

        # Update the underlying searcher once the
        # full batch is completed.
        for done_id, (done_result, done_error) in self.cached_results.items():
            self.searcher.on_trial_complete(
                done_id, result=done_result, error=done_error
            )
            self.live_trials.remove(done_id)
        self.cached_results = {}
        self.num_unfinished_live_trials = 0

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        """Forward an intermediate result to the wrapped searcher."""
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        """Forward an already evaluated point to the wrapped searcher."""
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        """Return a deep copy of this limiter's attributes, minus ``searcher``."""
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        """Overwrite this limiter's attributes with the entries of ``state``."""
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate saving to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoring to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        """Forward a pause notification to the wrapped searcher."""
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        """Forward an unpause notification to the wrapped searcher."""
        self.searcher.on_unpause(trial_id)