Source code for ray.tune.execution.trial_runner

from typing import Any, Dict, List, Optional, Union, Tuple, Set

from datetime import datetime
import json
import logging
import os
from pathlib import Path
import time
import traceback
import warnings

import ray
from ray.air._internal.uri_utils import URI
from ray.air.config import CheckpointConfig
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.exceptions import RayTaskError
from ray.tune.error import _TuneStopTrialError, _TuneRestoreError
from ray.tune.execution.experiment_state import (
    _ExperimentCheckpointManager,
    _find_newest_experiment_checkpoint,
    _experiment_checkpoint_exists,
)
from ray.util import get_node_ip_address
from ray.tune import TuneError
from ray.tune.callback import CallbackList, Callback
from ray.tune.experiment import Experiment
from ray.tune.execution.insufficient_resources_manager import (
    _InsufficientResourcesManager,
)
from ray.tune.execution.ray_trial_executor import (
    RayTrialExecutor,
    _ExecutorEventType,
    _ExecutorEvent,
)
from ray.tune.result import (
    DEBUG_METRICS,
    DEFAULT_METRIC,
    DONE,
    TIME_THIS_ITER_S,
    RESULT_DUPLICATE,
    SHOULD_CHECKPOINT,
)
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.stopper import NoopStopper, Stopper
from ray.tune.search import BasicVariantGenerator, SearchAlgorithm
from ray.tune.syncer import SyncConfig, get_node_to_storage_syncer
from ray.tune.experiment import Trial
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.utils.log import Verbosity, has_verbosity
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
from ray.tune.web_server import TuneServer
from ray.util.annotations import DeveloperAPI, Deprecated
from ray.util.debug import log_once

MAX_DEBUG_TRIALS = 20

logger = logging.getLogger(__name__)


class _TuneControllerBase:
    """A TrialRunner implements the event loop for scheduling trials on Ray.

    .. code-block:: python

        runner = TrialRunner()
        runner.add_trial(Trial(...))
        runner.add_trial(Trial(...))
        while not runner.is_finished():
            runner.step()

    The main job of TrialRunner is scheduling trials to efficiently use cluster
    resources, without overloading the cluster.

    While Ray itself provides resource management for tasks and actors, this is
    not sufficient when scheduling trials that may instantiate multiple actors.
    This is because if insufficient resources are available, concurrent trials
    could deadlock waiting for new resources to become available. Furthermore,
    oversubscribing the cluster could degrade training performance, leading to
    misleading benchmark results.

    Args:
        search_alg: SearchAlgorithm for generating
            Trial objects.
        scheduler: Defaults to FIFOScheduler.
        experiment_path: Path where global experiment state checkpoints
            are saved and restored from.
        sync_config: See :class:`~ray.tune.syncer.SyncConfig`.
            Within sync config, the `upload_dir` specifies cloud storage, and
            experiment state checkpoints will be synced to the `remote_checkpoint_dir`:
            `{sync_config.upload_dir}/{experiment_name}`.
        experiment_dir_name: Experiment directory name.
            See :class:`~ray.tune.experiment.Experiment`.
        stopper: Custom class for stopping whole experiments. See ``Stopper``.
        resume: see `tune.py:run`.
        server_port: Port number for launching TuneServer.
        fail_fast: Finishes as soon as a trial fails if True.
            If fail_fast='raise' is provided, Tune will automatically
            raise the exception received by the Trainable. fail_fast='raise'
            can easily leak resources and should be used with caution.
        checkpoint_period: Trial runner checkpoint periodicity in
            seconds. Defaults to ``"auto"``, which adjusts checkpointing
            time so that at most 5% of the time is spent on writing
            checkpoints.
        callbacks: List of callbacks that will be called at different
            times in the training loop. Must be instances of the
            ``ray.tune.execution.trial_runner.Callback`` class.
        metric: Metric used to check received results. If a result is
            reported without this metric, an error will be raised. The error
            can be avoided by not providing a metric or by setting the env
            variable ``TUNE_DISABLE_STRICT_METRIC_CHECKING=1``.

    """

    CKPT_FILE_TMPL = "experiment_state-{}.json"
    RAISE = "RAISE"

    def __init__(
        self,
        *,
        search_alg: Optional[SearchAlgorithm] = None,
        placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
        scheduler: Optional[TrialScheduler] = None,
        experiment_path: Optional[str] = None,
        sync_config: Optional[SyncConfig] = None,
        experiment_dir_name: Optional[str] = None,
        stopper: Optional[Stopper] = None,
        resume: Union[str, bool] = False,
        server_port: Optional[int] = None,
        fail_fast: bool = False,
        checkpoint_period: Union[str, int] = None,
        callbacks: Optional[List[Callback]] = None,
        metric: Optional[str] = None,
        trial_checkpoint_config: Optional[CheckpointConfig] = None,
    ):
        self._search_alg = search_alg or BasicVariantGenerator()
        self._placeholder_resolvers = placeholder_resolvers
        self._scheduler_alg = scheduler or FIFOScheduler()
        self._callbacks = CallbackList(callbacks or [])
        self._insufficient_resources_manager = _InsufficientResourcesManager()
        self._pending_trial_queue_times = {}

        self._max_pending_trials = _get_max_pending_trials(self._search_alg)

        self._sync_config = sync_config or SyncConfig()

        self._experiment_dir_name = experiment_dir_name

        # Rename for better code readability
        local_experiment_path = experiment_path
        remote_experiment_path = None

        if self._sync_config.upload_dir and self._experiment_dir_name:
            remote_experiment_path = str(
                URI(self._sync_config.upload_dir) / self._experiment_dir_name
            )

        self._local_experiment_path = local_experiment_path
        self._remote_experiment_path = remote_experiment_path

        self._metric = metric

        self._total_time = 0
        self._iteration = 0
        self._has_errored = False
        self._fail_fast = fail_fast
        if isinstance(self._fail_fast, str):
            self._fail_fast = self._fail_fast.upper()
            if self._fail_fast == self.RAISE:
                warnings.warn(
                    "fail_fast='raise' detected. Be careful when using this "
                    "mode as resources (such as Ray processes, "
                    "file descriptors, and temporary files) may not be "
                    "cleaned up properly. To use "
                    "a safer mode, use fail_fast=True."
                )
            else:
                raise ValueError(
                    "fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}."
                )

        self._print_trial_errors = bool(
            int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1"))
        )

        self._server = None
        self._server_port = server_port
        if server_port is not None:
            self._server = TuneServer(self, self._server_port)

        self._trials: List[Trial] = []
        self._live_trials: Set[Trial] = set()  # Set of non-terminated trials
        self._cached_trial_decisions = {}
        self._queued_trial_decisions = {}

        self._stop_queue = []
        self._should_stop_experiment = False  # used by TuneServer

        if self._local_experiment_path:
            os.makedirs(self._local_experiment_path, exist_ok=True)

        self._stopper = stopper or NoopStopper()

        self._start_time = time.time()
        self._last_checkpoint_time = -float("inf")

        self._session_str = datetime.fromtimestamp(self._start_time).strftime(
            "%Y-%m-%d_%H-%M-%S"
        )

        if checkpoint_period is None:
            checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")

        self._checkpoint_period = checkpoint_period
        self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig()
        self._checkpoint_manager = self._create_checkpoint_manager()

        self._resumed = False
        resume_config = self._checkpoint_manager.resume(resume_type=resume)

        if resume_config:
            try:
                self.resume(
                    resume_unfinished=resume_config.resume_unfinished,
                    resume_errored=resume_config.resume_errored,
                    restart_errored=resume_config.restart_errored,
                )
                self._resumed = True
            except Exception as e:
                if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
                    logger.error(str(e))
                logger.exception("Runner restore failed.")
                if self._fail_fast:
                    raise
                logger.info("Restarting experiment.")
        else:
            logger.debug("Starting a new experiment.")

    def _wrapped(self):
        raise RuntimeError

    @property
    def resumed(self):
        return self._resumed

    @property
    def search_alg(self):
        return self._search_alg

    @property
    def scheduler_alg(self):
        return self._scheduler_alg

    def setup_experiments(
        self, experiments: List[Experiment], total_num_samples: int
    ) -> None:
        """Obtains any necessary information from experiments.

        Mainly used to setup callbacks.

        Args:
            experiments: List of Experiments
                to use.
            total_num_samples: Total number of samples
                factoring in grid search samplers.
        """
        experiment = experiments[0]
        spec = experiment.public_spec if experiment else {}
        spec["total_num_samples"] = total_num_samples
        self._callbacks.setup(**spec)

    def end_experiment_callbacks(self) -> None:
        """Calls ``on_experiment_end`` method in callbacks."""
        self._callbacks.on_experiment_end(trials=self._trials)

    @Deprecated("Use `TrialRunner.experiment_state_path` instead.")
    @property
    def checkpoint_file(self) -> str:
        return self.experiment_state_path

    @property
    def experiment_state_file_name(self) -> str:
        return self.CKPT_FILE_TMPL.format(self._session_str)

    @property
    def experiment_state_path(self) -> str:
        return os.path.join(
            self._local_experiment_path, self.experiment_state_file_name
        )

    def _create_checkpoint_manager(self):
        return _ExperimentCheckpointManager(
            local_checkpoint_dir=self._local_experiment_path,
            remote_checkpoint_dir=self._remote_experiment_path,
            checkpoint_period=self._checkpoint_period,
            sync_config=self._sync_config,
            sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep,
        )

    @property
    def _remote_checkpoint_dir(self):
        if self._sync_config.upload_dir and self._experiment_dir_name:
            return str(URI(self._sync_config.upload_dir) / self._experiment_dir_name)
        return None

    @classmethod
    def checkpoint_exists(cls, directory: str) -> bool:
        if not os.path.exists(directory):
            return False

        return _experiment_checkpoint_exists(directory)
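
    # Example (illustrative sketch): callers can guard a resume attempt with
    # this check. The directory below is hypothetical.
    #
    #     if TrialRunner.checkpoint_exists("/tmp/ray_results/my_experiment"):
    #         runner = TrialRunner(
    #             experiment_path="/tmp/ray_results/my_experiment", resume=True
    #         )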

    def save_to_dir(self, experiment_dir: Optional[str] = None):
        """Save TrialRunner state to experiment directory.

        Accepts an ``experiment_dir`` argument which defaults to the
        local checkpoint directory.

        This method will save the trial runner state, the searcher state,
        and the callback states into the experiment directory.
        """
        experiment_dir = experiment_dir or self._local_experiment_path

        # Get state from trial executor and runner
        runner_state = {
            # Trials
            "checkpoints": list(self._get_trial_checkpoints().values()),
            # Experiment data
            "runner_data": self.__getstate__(),
            # Metadata
            "stats": {
                "start_time": self._start_time,
                "timestamp": self._last_checkpoint_time,
            },
        }

        tmp_file_name = os.path.join(experiment_dir, ".tmp_experiment_state")

        with open(tmp_file_name, "w") as f:
            json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)

        os.replace(
            tmp_file_name,
            os.path.join(experiment_dir, self.experiment_state_file_name),
        )

        self._search_alg.save_to_dir(
            self._local_experiment_path, session_str=self._session_str
        )
        self._callbacks.save_to_dir(
            self._local_experiment_path, session_str=self._session_str
        )
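
    # Example (illustrative sketch, assuming `runner` is an instance of this
    # class): the state file written above can be inspected directly; its keys
    # follow the `runner_state` dict constructed in this method.
    #
    #     with open(runner.experiment_state_path) as f:
    #         state = json.load(f, cls=TuneFunctionDecoder)
    #     assert set(state) == {"checkpoints", "runner_data", "stats"}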

    def restore_from_dir(self, experiment_dir: Optional[str] = None) -> List[Trial]:
        """Restore TrialRunner state from experiment directory.

        Accepts an ``experiment_dir`` argument which defaults to the
        local checkpoint directory.

        This method will restore the trial runner state, the searcher state,
        and the callback states. It will then parse the trial states
        and return them as a list of Trial objects.
        """
        experiment_dir = experiment_dir or self._local_experiment_path

        # Update local checkpoint dir
        self._local_experiment_path = experiment_dir

        # Find newest state file
        newest_state_path = _find_newest_experiment_checkpoint(
            self._local_experiment_path
        )

        if not newest_state_path:
            raise ValueError(
                f"Tried to resume experiment from directory "
                f"`{self._local_experiment_path}`, but no "
                f"experiment checkpoint data was found."
            )

        # Set checkpoint file to load
        logger.warning(
            f"Attempting to resume experiment from {self._local_experiment_path}. "
            "This will ignore any new changes to the specification."
        )
        logger.info(
            "Using the newest experiment state file found within the "
            f"experiment directory: {Path(newest_state_path).name}"
        )

        # Actually load data
        with open(newest_state_path, "r") as f:
            runner_state = json.load(f, cls=TuneFunctionDecoder)

        # 1. Restore trial runner state
        self.__setstate__(runner_state["runner_data"])

        # 2. Restore search algorithm and callback state
        if self._search_alg.has_checkpoint(self._local_experiment_path):
            self._search_alg.restore_from_dir(self._local_experiment_path)

        if self._callbacks.can_restore(self._local_experiment_path):
            self._callbacks.restore_from_dir(self._local_experiment_path)

        # 3. Load trials
        trials = []
        for trial_json_state in runner_state["checkpoints"]:
            trial = Trial.from_json_state(trial_json_state)

            # The following properties may be updated on restoration
            # Ex: moved local/cloud experiment directory
            # ATTN: Set `local_experiment_path` to update trial checkpoints!
            trial.local_experiment_path = self._local_experiment_path
            trial.remote_experiment_path = self._remote_experiment_path
            trial.sync_config = self._sync_config
            trial.experiment_dir_name = self._experiment_dir_name

            # Avoid creating logdir in client mode for returned trial results,
            # since the dir might not be creatable locally.
            # TODO(ekl) this is kind of a hack.
            if not ray.util.client.ray.is_connected():
                trial.init_local_path()  # Create logdir if it does not exist

            trials.append(trial)

        return trials

    def checkpoint(self, force: bool = False, wait: bool = False):
        """Saves execution state to `self._local_experiment_path`.

        Overwrites the current session checkpoint, which starts when self
        is instantiated. Checkpointing is throttled according to
        `self._checkpoint_period`.

        Also automatically saves the search algorithm to the local
        checkpoint dir.

        Args:
            force: Forces a checkpoint despite checkpoint_period.
            wait: Wait until syncing to cloud has finished.

        """
        with warn_if_slow(
            "experiment_checkpoint",
            message="Checkpointing the experiment state took "
            "{duration:.3f} s, which may be a performance "
            "bottleneck. Please ensure the "
            "`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
            "something significantly higher than this duration "
            "to ensure compute time is mostly spent on the main "
            "training loop.",
            # No backlog warning if forced checkpoint as we wait
            # for previous sync to finish.
            disable=self._checkpoint_manager.auto_checkpoint_enabled or force or wait,
        ):
            self._checkpoint_manager.checkpoint(
                save_fn=self.save_to_dir, force=force, wait=wait
            )
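
    # Example (illustrative sketch): the throttling period can be pinned via
    # the `TUNE_GLOBAL_CHECKPOINT_S` environment variable read in `__init__`
    # (seconds, or "auto"), or a checkpoint can be forced for a given runner:
    #
    #     os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "60"  # set before __init__
    #     runner.checkpoint(force=True, wait=True)  # block until sync finishes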

    def resume(
        self,
        resume_unfinished: bool = True,
        resume_errored: bool = False,
        restart_errored: bool = False,
    ):
        """Resumes all checkpointed trials from previous run.

        Requires user to manually re-register their objects. Also stops
        all ongoing trials.
        """
        trials = self.restore_from_dir()

        # Set trial statuses according to the resume configuration
        for trial in sorted(trials, key=lambda t: t.last_update_time, reverse=True):
            trial_to_add = trial
            if trial.status == Trial.ERROR:
                if resume_errored:
                    # Keep trial ID on resume
                    trial_to_add.error_filename = None
                    trial_to_add.pickled_error_filename = None
                    trial_to_add.set_status(Trial.PENDING)
                    trial_to_add.restore_path = trial.checkpoint.dir_or_data
                elif restart_errored:
                    trial_to_add = trial.reset()
                    trial_to_add.restore_path = None
            elif trial.status != Trial.TERMINATED and not resume_unfinished:
                trial_to_add.status = Trial.TERMINATED
            self.add_trial(trial_to_add)
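
    # Example (illustrative sketch): the flag combinations handled above,
    # assuming `runner` points at a checkpointed experiment directory.
    #
    #     runner.resume(resume_unfinished=True)  # continue unfinished trials
    #     runner.resume(resume_errored=True)     # retry errored trials from
    #                                            # their latest checkpoints
    #     runner.resume(restart_errored=True)    # restart errored trials
    #                                            # from scratch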

    def update_pending_trial_resources(
        self, resources: Union[dict, PlacementGroupFactory]
    ):
        """Update trial resources when resuming from checkpoint.

        Only pending trials are updated.
        """
        assert resources
        if isinstance(resources, dict) and "gpu" not in resources:
            resources["gpu"] = 0
        for trial in self._trials:
            if trial.status == Trial.PENDING:
                trial.update_resources(resources=resources)

    def is_finished(self):
        """Returns whether all trials have finished running."""
        # The checks here are partly redundant but optimized for quick
        # evaluation. Specifically, if there are live trials, we check
        # these live trials first. Only if all of the live trials have
        # finished do we loop over all trials for a final check.
        trials_done = (
            len(self._live_trials) == 0
            or all(trial.is_finished() for trial in self._live_trials)
        ) and all(trial.is_finished() for trial in self._trials)
        return trials_done and self._search_alg.is_finished()

    def get_trial(self, tid):
        trial = [t for t in self._trials if t.trial_id == tid]
        return trial[0] if trial else None

    def get_trials(self):
        """Returns the list of trials managed by this TrialRunner.

        Note that the caller usually should not mutate trial state directly.
        """
        return self._trials

    def get_live_trials(self):
        """Returns the set of trials that are not in Trial.TERMINATED state."""
        return self._live_trials

    def _get_trial_checkpoints(self) -> Dict[str, str]:
        raise NotImplementedError

    def _mark_trial_to_checkpoint(self, trial: Trial):
        raise NotImplementedError

    def _set_trial_status(self, trial: Trial, status: str):
        raise NotImplementedError

    def _cleanup_trials(self):
        raise NotImplementedError

    def add_trial(self, trial: Trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial: Trial to queue.
        """
        # If the config map has had all the references replaced with placeholders,
        # resolve them before adding the trial.
        if self._placeholder_resolvers:
            trial.resolve_config_placeholders(self._placeholder_resolvers)

        # With trial.config resolved, create placement group factory if needed.
        trial.create_placement_group_factory()

        self._trials.append(trial)
        if trial.status != Trial.TERMINATED:
            self._live_trials.add(trial)
        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self._wrapped(), trial)
        self._mark_trial_to_checkpoint(trial)

    def _used_resources_string(self) -> str:
        raise NotImplementedError

    def debug_string(self, delim="\n"):
        from ray.tune.progress_reporter import _trial_progress_str

        result_keys = [list(t.last_result) for t in self.get_trials() if t.last_result]
        metrics = set().union(*result_keys)
        messages = [
            self._scheduler_alg.debug_string(),
            self._used_resources_string(),
            _trial_progress_str(self.get_trials(), metrics, force_table=True),
        ]
        return delim.join(messages)

    def step(self):
        raise NotImplementedError

    def _maybe_execute_queued_decision(self, trial):
        # `self._queued_trial_decisions` now contains a final decision
        # based on all results
        final_decision = self._queued_trial_decisions.pop(trial.trial_id, None)
        if final_decision:
            logger.debug(
                f"Executing final queued decision for {trial}: {final_decision}"
            )
            self._execute_action(trial, final_decision)

    def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
        raise NotImplementedError

    def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True):
        raise NotImplementedError

    def _stop_experiment_if_needed(self):
        """Stops all trials."""
        fail_fast = self._fail_fast and self._has_errored
        if self._stopper.stop_all() or fail_fast or self._should_stop_experiment:
            self._search_alg.set_finished()
            [
                self._schedule_trial_stop(t)
                for t in self._trials
                if t.status not in {Trial.ERROR, Trial.TERMINATED}
            ]

    ###
    # FAILURE

    def _process_trial_failure(
        self, trial: Trial, exception: Optional[Union[TuneError, RayTaskError]] = None
    ):
        """Handle trial failure.

        Attempt trial recovery if possible, clean up state otherwise.

        Args:
            trial: Failed trial.
            exception: Exception prior to invoking this method.
        """
        self._has_errored = True
        if trial.status == Trial.RUNNING:
            if trial.should_recover():
                self._try_recover(trial, exc=exception)
            else:
                self._scheduler_alg.on_trial_error(self, trial)
                self._search_alg.on_trial_complete(trial.trial_id, error=True)
                self._callbacks.on_trial_error(
                    iteration=self._iteration, trials=self._trials, trial=trial
                )
                self._schedule_trial_stop(trial, exception=exception)

    ###
    # STOP

    def stop_trial(self, trial):
        """The canonical implementation of stopping a trial.

        Trials may be in any external status when this function is called.
        If the trial is in state PENDING or PAUSED, calls `on_trial_remove`
        for the scheduler and `on_trial_complete()` for the search algorithm.
        If the trial is in state RUNNING, calls `on_trial_complete` for both
        the scheduler and the search algorithm. The caller must ensure that
        there is no outstanding future to be handled for the trial; if there
        is, the future will be discarded.
        """
        try:
            if trial.status in [Trial.ERROR, Trial.TERMINATED]:
                return
            elif trial.status in [Trial.PENDING, Trial.PAUSED]:
                self._scheduler_alg.on_trial_remove(self, trial)
                self._search_alg.on_trial_complete(trial.trial_id)
            elif trial.status is Trial.RUNNING:
                # By this time trial.last_result should have been
                # updated already.
                self._scheduler_alg.on_trial_complete(
                    self, trial, flatten_dict(trial.last_result)
                )
                self._search_alg.on_trial_complete(
                    trial.trial_id, result=flatten_dict(trial.last_result)
                )
            self._callbacks.on_trial_complete(
                iteration=self._iteration, trials=self._trials, trial=trial
            )
            self._schedule_graceful_trial_stop(trial)
            self._live_trials.discard(trial)
        except Exception as e:
            logger.exception("Trial %s: Error stopping trial.", trial)
            if self._fail_fast == self.RAISE:
                raise
            if isinstance(e, TuneError):
                self._process_trial_failure(trial, exception=e)
            else:
                self._process_trial_failure(
                    trial, _TuneStopTrialError(traceback.format_exc())
                )

    def _schedule_graceful_trial_stop(self, trial: Trial):
        raise NotImplementedError

    ###
    # TRAIN

    def _schedule_trial_train(self, trial: Trial):
        raise NotImplementedError

    def _on_training_result(self, trial, result):
        if not isinstance(result, list):
            result = [result]
        with warn_if_slow("process_trial_result"):
            self._process_trial_results(trial, result)
        self._maybe_execute_queued_decision(trial)

    def _process_trial_results(self, trial, results):
        logger.debug(f"Processing trial results for trial {trial}: {results}")
        with warn_if_slow(
            "process_trial_results",
            message="Processing trial results took {duration:.3f} s, "
            "which may be a performance bottleneck. Please consider "
            "reporting results less frequently to Ray Tune.",
        ):
            for i, result in enumerate(results):
                with warn_if_slow("process_trial_result"):
                    decision = self._process_trial_result(trial, result)
                if decision is None:
                    # If we didn't get a decision, this means a
                    # non-training future (e.g. a save) was scheduled,
                    # and we do not process any further results for now.
                    if i < len(results) - 1:
                        if log_once("trial_runner_buffer_checkpoint"):
                            logger.warning(
                                f"Trial {trial} has a non-training future "
                                f"scheduled but {len(results) - i} results "
                                f"left to process. This means that a "
                                f"checkpoint was requested, but buffered "
                                f"training was continued before it was "
                                f"saved. Consider using non-buffered "
                                f"training by setting the env variable "
                                f"`TUNE_RESULT_BUFFER_LENGTH=1`."
                            )
                elif decision == TrialScheduler.STOP:
                    # If the decision is to stop the trial,
                    # ignore all results that came after that.
                    break

    def _process_trial_result(self, trial, result):
        result.update(trial_id=trial.trial_id)
        is_duplicate = RESULT_DUPLICATE in result
        force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            result = trial.last_result
            result.update(done=True)

        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        flat_result = flatten_dict(result)
        self._validate_result_metrics(flat_result)

        if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result):
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self._wrapped(), trial, flat_result
                )
        if decision == TrialScheduler.STOP:
            result.update(done=True)
        else:
            # Only updating search alg if the trial is not to be stopped.
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id, flat_result)

        # If this is not a duplicate result, the callbacks should
        # be informed about the result.
        if not is_duplicate:
            with warn_if_slow("callbacks.on_trial_result"):
                self._callbacks.on_trial_result(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial,
                    result=result.copy(),
                )
            trial.update_last_result(result)
            # Include in next experiment checkpoint
            self._mark_trial_to_checkpoint(trial)

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

        if trial.is_saving:
            logger.debug(f"Caching trial decision for trial {trial}: {decision}")
            # Cache decision to execute on after the save is processed.
            # This prevents changing the trial's state or kicking off
            # another training step prematurely.
            self._cached_trial_decisions[trial.trial_id] = decision
            return None
        else:
            self._queue_decision(trial, decision)
            return decision

    def _validate_result_metrics(self, result):
        """
        Check if any of the required metrics was not reported
        in the last result. If the only items are ``done`` or any of
        DEBUG_METRICS, this means that no result was ever received and
        the trial just returned. This is also okay and will not raise
        an error.

        This will ignore checking for the DEFAULT_METRIC.
        """
        if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (
            len({k for k in result if k not in list(DEBUG_METRICS) + [DONE]}) > 1
        ):
            base_metric = self._metric if self._metric != DEFAULT_METRIC else None
            scheduler_metric = (
                self._scheduler_alg.metric
                if self._scheduler_alg.metric != DEFAULT_METRIC
                else None
            )
            search_metrics = (
                self._search_alg.metric
                if self._search_alg.metric != DEFAULT_METRIC
                else None
            )

            if isinstance(search_metrics, str):
                search_metrics = [search_metrics]

            if base_metric and base_metric not in result:
                report_metric = base_metric
                location = "tune.TuneConfig()"
            elif scheduler_metric and scheduler_metric not in result:
                report_metric = scheduler_metric
                location = type(self._scheduler_alg).__name__
            elif search_metrics and any(
                search_metric not in result for search_metric in search_metrics
            ):
                report_metric = list(
                    filter(
                        lambda search_metric: search_metric not in result,
                        search_metrics,
                    )
                )
                if len(report_metric) == 1:
                    report_metric = report_metric[0]
                location = type(self._search_alg).__name__
            else:
                report_metric = None
                location = None

            if report_metric:
                raise ValueError(
                    "Trial returned a result which did not include the "
                    "specified metric(s) `{}` that `{}` expects. "
                    "Make sure your calls to `tune.report()` include the "
                    "metric, or set the "
                    "TUNE_DISABLE_STRICT_METRIC_CHECKING "
                    "environment variable to 1. Result: {}".format(
                        report_metric, location, result
                    )
                )
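
    # Example (illustrative sketch): with a configured metric such as "loss",
    # a result like `{"accuracy": 0.9, "done": False}` raises the ValueError
    # above. The check can be turned off globally before the run starts:
    #
    #     os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
    #
    # Results containing only `done` and/or DEBUG_METRICS keys are exempt,
    # since they indicate that no real result was reported.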

    ###
    # SAVE

    def _schedule_trial_save(
        self,
        trial: Trial,
        storage: CheckpointStorage = CheckpointStorage.PERSISTENT,
        result: Optional[Dict] = None,
    ) -> Optional[_TrackedCheckpoint]:
        raise NotImplementedError

    def _on_saving_result(self, trial, checkpoint_value: Union[ray.ObjectRef, str]):
        with warn_if_slow("process_trial_save") as _profile:
            self._process_trial_save(trial, checkpoint_value)
        with warn_if_slow("callbacks.on_trial_save"):
            self._callbacks.on_trial_save(
                iteration=self._iteration, trials=self._trials, trial=trial
            )
        if _profile.too_slow and trial.sync_on_checkpoint:
            # TODO(ujvl): Suggest using cloud checkpointing once
            #  API has converged.

            msg = (
                "Consider turning off forced head-worker trial "
                "checkpoint syncs by setting sync_on_checkpoint=False"
                ". Note that this may result in faulty trial "
                "restoration if a failure occurs while the checkpoint "
                "is being synced from the worker to the head node."
            )

            if trial.location.hostname and (
                trial.location.hostname != get_node_ip_address()
            ):
                if log_once("tune_head_worker_checkpoint"):
                    logger.warning(msg)

        self._maybe_execute_queued_decision(trial)

    def _process_trial_save(
        self, trial: Trial, checkpoint_value: Union[ray.ObjectRef, str]
    ):
        """Processes a trial save.

        Acts on the decision cached during the last `_process_trial_result` call.

        Args:
            trial: Trial being saved.
        """
        logger.debug("Trial %s: Processing trial save.", trial)

        try:
            trial.saving_to.dir_or_data = checkpoint_value
            self._callbacks.on_checkpoint(
                iteration=self._iteration,
                trials=self._trials,
                trial=trial,
                checkpoint=trial.saving_to,
            )
            trial.on_checkpoint(trial.saving_to)
            self._checkpoint_manager.on_trial_checkpoint(trial)
            if trial.checkpoint.storage_mode != CheckpointStorage.MEMORY:
                self._mark_trial_to_checkpoint(trial)
        except Exception:
            logger.exception(
                "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
            )
            if self._fail_fast == TrialRunner.RAISE:
                raise

        trial.saving_to = None
        decision = self._cached_trial_decisions.pop(trial.trial_id, None)
        if decision and checkpoint_value:
            self._queue_decision(trial, decision)

    ###
    # RESTORE

    def _schedule_trial_restore(self, trial: Trial):
        raise NotImplementedError

    def _on_restoring_result(self, trial):
        with warn_if_slow("process_trial_restore"):
            self._process_trial_restore(trial)
        with warn_if_slow("callbacks.on_trial_restore"):
            self._callbacks.on_trial_restore(
                iteration=self._iteration, trials=self._trials, trial=trial
            )

    def _process_trial_restore(self, trial: Trial):
        """Processes a trial restore.

        Args:
            trial: Trial being restored.
        """
        logger.debug("Trial %s: Processing trial restore.", trial)
        trial.on_restore()
        logger.debug("Trial %s: Restore processed successfully", trial)
        self._set_trial_status(trial, Trial.RUNNING)
        self._schedule_trial_train(trial)
        self._live_trials.add(trial)

    ###
    # EXPORT
    def _schedule_trial_export(self, trial: Trial):
        raise NotImplementedError

    def _queue_decision(self, trial, decision):
        # Get old decision, setting it to the current decision if it isn't set
        old_decision = self._queued_trial_decisions.setdefault(trial.trial_id, decision)

        # Stopping always takes precedence. If we decided to stop, just quit
        if old_decision is TrialScheduler.STOP:
            return

        # The old decision wasn't STOP. We update the decision only if it is
        # STOP or PAUSE. The action will only be CONTINUE if it was set by
        # the first received result and was never updated after that.
        if decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE:
            self._queued_trial_decisions[trial.trial_id] = decision
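
    # Example (illustrative sketch): the precedence implemented above, for a
    # hypothetical trial `t`.
    #
    #     self._queue_decision(t, TrialScheduler.CONTINUE)  # initial decision
    #     self._queue_decision(t, TrialScheduler.PAUSE)     # upgraded to PAUSE
    #     self._queue_decision(t, TrialScheduler.STOP)      # upgraded to STOP
    #     self._queue_decision(t, TrialScheduler.CONTINUE)  # STOP is kept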

    def _execute_action(self, trial: Trial, decision: str):
        """Executes action based on decision.

        Args:
            trial: Trial to act on.
            decision: Scheduling decision to undertake.
        """
        if decision == TrialScheduler.CONTINUE:
            self._schedule_trial_train(trial)
        elif decision == TrialScheduler.PAUSE:
            self.pause_trial(trial)
        elif decision == TrialScheduler.STOP:
            self.stop_trial(trial)
        elif decision == TrialScheduler.NOOP:
            pass
        else:
            raise ValueError("Invalid decision: {}".format(decision))

    def _checkpoint_trial_if_needed(self, trial, force=False):
        """Checkpoints trial based off trial.last_result."""
        if trial.should_checkpoint() or force:
            # Save trial runtime if possible.
            if trial.runner:
                self._schedule_trial_save(trial, storage=CheckpointStorage.PERSISTENT)

    def _try_recover(self, trial: Trial, exc: Union[TuneError, RayTaskError]):
        """Tries to recover trial.

        Notifies SearchAlgorithm and Scheduler if failure to recover.

        Args:
            trial: Trial to recover.
            exc: Exception prior to invoking this method.
        """
        self._cached_trial_decisions.pop(trial.trial_id, None)
        # Reset this in case the trial was in saving status when it crashed.
        if trial.is_saving:
            trial.saving_to = None
        if trial.is_restoring and exc:
            exc = _TuneRestoreError(exc)
        self._schedule_trial_stop(trial, exception=exc)

        logger.debug("Trial %s: Notifying Scheduler and requeueing.", trial)
        self._requeue_trial(trial)

    def _requeue_trial(self, trial):
        """Notification to TrialScheduler and requeue trial.

        This does not notify the SearchAlgorithm because the function
        evaluation is still in progress.

        """
        self._scheduler_alg.on_trial_error(self, trial)
        self._set_trial_status(trial, status=Trial.PENDING)

        # TODO(rliaw): Right now, this pushes the trial to the end of queue
        # because restoration can be expensive. However, this is not
        # ideal since it just hides the issue - a better fix would
        # be to use an actor table to detect the IP of the Trainable
        # and rsync the files there.
        # See https://github.com/ray-project/ray/issues/5168
        self._trials.pop(self._trials.index(trial))
        self._trials.append(trial)
        self._live_trials.add(trial)

        with warn_if_slow("scheduler.on_trial_add"):
            self._scheduler_alg.on_trial_add(self._wrapped(), trial)

    def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool:
        """Adds next trials to queue if possible.

        Note that the timeout is currently unexposed to the user.

        Args:
            blocking: Blocks until either a trial is available
                or is_finished (timeout or search algorithm finishes).
            timeout: Seconds before blocking times out.

        Returns:
            Boolean indicating if a new trial was created or not.
        """
        trial = self._search_alg.next_trial()
        if blocking and not trial:
            start = time.time()
            # Checking `is_finished` instead of _search_alg.is_finished
            # is fine because blocking only occurs if all trials are
            # finished and search_algorithm is not yet finished
            while (
                not trial and not self.is_finished() and time.time() - start < timeout
            ):
                logger.debug("Blocking for next trial...")
                trial = self._search_alg.next_trial()
                time.sleep(1)

        if trial:
            self.add_trial(trial)
            return True

        return False

    def request_stop_trial(self, trial):
        self._stop_queue.append(trial)

    def request_stop_experiment(self):
        self._should_stop_experiment = True

    def _process_stop_requests(self):
        while self._stop_queue:
            t = self._stop_queue.pop()
            self.stop_trial(t)
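
    # Example (illustrative sketch): these hooks are what the TuneServer
    # (started in `__init__` when `server_port` is set) calls into; they can
    # also be invoked directly, e.g. with a hypothetical trial id.
    #
    #     runner.request_stop_trial(runner.get_trial("trial_id"))
    #     runner.request_stop_experiment()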

    def pause_trial(self, trial: Trial, should_checkpoint: bool = True):
        """Pause a trial and reset the necessary state variables for resuming later.

        Args:
            trial: Trial to pause.
            should_checkpoint: Whether or not an in-memory checkpoint should be created
                for this paused trial. Defaults to True.
        """
        # NOTE: The cached trial decision is not needed since we will overrule this
        # decision with PAUSE.
        self._cached_trial_decisions.pop(trial.trial_id, None)
        self._schedule_trial_pause(trial)

    def cleanup(self):
        """Cleanup trials and callbacks."""
        self._cleanup_trials()
        self.end_experiment_callbacks()

    def __getstate__(self):
        """Gets state for trial.

        Note that this is not used as a pickling override as
        does not have all fields.
        """
        state = self.__dict__.copy()
        for k in [
            "_trials",
            "_live_trials",
            "_stop_queue",
            "_server",
            "_search_alg",
            "_placeholder_resolvers",
            "_scheduler_alg",
            "_pending_trial_queue_times",
            "_callbacks",
            "_checkpoint_manager",
            "_local_experiment_path",
            "_remote_experiment_path",
            "_sync_config",
            "_experiment_dir_name",
            "_insufficient_resources_manager",
        ]:
            del state[k]
        state["launch_web_server"] = bool(self._server)
        return state

    def __setstate__(self, state):
        launch_web_server = state.pop("launch_web_server")

        # Use session_str from the previous checkpoint if it does not exist
        session_str = state.pop("_session_str")
        self.__dict__.setdefault("_session_str", session_str)
        # Use start_time from the previous checkpoint if it does not exist
        start_time = state.pop("_start_time")
        self.__dict__.setdefault("_start_time", start_time)

        self.__dict__.update(state)
        self._checkpoint_manager = self._create_checkpoint_manager()

        if launch_web_server:
            self._server = TuneServer(self, self._server_port)


@DeveloperAPI
class TrialRunner(_TuneControllerBase):
    """A TrialRunner implements the event loop for scheduling trials on Ray.

    .. code-block:: python

        runner = TrialRunner()
        runner.add_trial(Trial(...))
        runner.add_trial(Trial(...))
        while not runner.is_finished():
            runner.step()
            print(runner.debug_string())

    The main job of TrialRunner is scheduling trials to efficiently use cluster
    resources, without overloading the cluster.

    While Ray itself provides resource management for tasks and actors, this is
    not sufficient when scheduling trials that may instantiate multiple actors.
    This is because if insufficient resources are available, concurrent trials
    could deadlock waiting for new resources to become available. Furthermore,
    oversubscribing the cluster could degrade training performance, leading to
    misleading benchmark results.

    Args:
        search_alg: SearchAlgorithm for generating
            Trial objects.
        scheduler: Defaults to FIFOScheduler.
        experiment_path: Path where global experiment state checkpoints
            are saved and restored from.
        experiment_dir_name: Experiment directory name.
            See :class:`~ray.tune.experiment.Experiment`.
        sync_config: See :class:`~ray.tune.syncer.SyncConfig`.
            Within sync config, the `upload_dir` specifies cloud storage, and
            experiment state checkpoints will be synced to the `remote_checkpoint_dir`:
            `{sync_config.upload_dir}/{experiment_name}`.
        stopper: Custom class for stopping whole experiments. See ``Stopper``.
        resume: see `tune.py:run`.
        server_port: Port number for launching TuneServer.
        fail_fast: Finishes as soon as a trial fails if True.
            If fail_fast='raise' is provided, Tune will automatically
            raise the exception received by the Trainable. fail_fast='raise'
            can easily leak resources and should be used with caution.
        checkpoint_period: Trial runner checkpoint periodicity in
            seconds. Defaults to ``"auto"``, which adjusts checkpointing
            time so that at most 5% of the time is spent on writing
            checkpoints.
        trial_executor: Defaults to RayTrialExecutor.
        callbacks: List of callbacks that will be called at different
            times in the training loop. Must be instances of the
            ``ray.tune.execution.trial_runner.Callback`` class.
        metric: Metric used to check received results. If a result is
            reported without this metric, an error will be raised. The error
            can be avoided by not providing a metric or by setting the env
            variable ``TUNE_DISABLE_STRICT_METRIC_CHECKING=1``.

    """

    def __init__(
        self,
        *,
        search_alg: Optional[SearchAlgorithm] = None,
        placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
        scheduler: Optional[TrialScheduler] = None,
        experiment_path: Optional[str] = None,
        experiment_dir_name: Optional[str] = None,
        sync_config: Optional[SyncConfig] = None,
        stopper: Optional[Stopper] = None,
        resume: Union[str, bool] = False,
        server_port: Optional[int] = None,
        fail_fast: bool = False,
        trial_executor: Optional[RayTrialExecutor] = None,
        checkpoint_period: Union[str, int] = None,
        callbacks: Optional[List[Callback]] = None,
        metric: Optional[str] = None,
        trial_checkpoint_config: Optional[CheckpointConfig] = None,
        # Deprecated
        local_checkpoint_dir: Optional[str] = None,
    ):
        if local_checkpoint_dir:
            if experiment_path:
                raise ValueError(
                    "Only one of `local_checkpoint_dir` or `experiment_path` "
                    "can be passed to `TrialRunner()`."
                )

            warnings.warn(
                "The `local_checkpoint_dir` argument is deprecated and will be "
                "removed in the future. Use `experiment_path` instead."
            )

            experiment_path = local_checkpoint_dir

        self.trial_executor = trial_executor or RayTrialExecutor()

        super().__init__(
            search_alg=search_alg,
            placeholder_resolvers=placeholder_resolvers,
            scheduler=scheduler,
            experiment_path=experiment_path,
            experiment_dir_name=experiment_dir_name,
            sync_config=sync_config,
            stopper=stopper,
            resume=resume,
            server_port=server_port,
            fail_fast=fail_fast,
            checkpoint_period=checkpoint_period,
            callbacks=callbacks,
            metric=metric,
            trial_checkpoint_config=trial_checkpoint_config,
        )

        self.trial_executor.setup(
            max_pending_trials=self._max_pending_trials,
            # TODO(ml-team): Remove these in 2.6.
            trainable_kwargs={
                "sync_timeout": self._sync_config.sync_timeout,
                "custom_syncer": get_node_to_storage_syncer(self._sync_config),
            },
        )

    def _wrapped(self):
        return TrialRunnerWrapper(
            self,
            self.trial_executor,
            runner_whitelist_attr={"search_alg", "get_trials", "_set_trial_status"},
            executor_whitelist_attr={"has_resources_for_trial", "pause_trial", "save"},
        )

    def _used_resources_string(self) -> str:
        return self.trial_executor.debug_string()

    def _get_trial_checkpoints(self) -> Dict[str, str]:
        return self.trial_executor.get_checkpoints()

    def _mark_trial_to_checkpoint(self, trial: Trial):
        self.trial_executor.mark_trial_to_checkpoint(trial)

    def _set_trial_status(self, trial: Trial, status: str):
        self.trial_executor.set_status(trial, status=status)

    def _reconcile_live_trials(self):
        """Loop through live trials and remove if terminated."""
        for trial in list(self._live_trials):
            # Only for TERMINATED trials. ERRORed trials might be retried.
            if trial.status == Trial.TERMINATED:
                self._live_trials.remove(trial)

    def _cleanup_trials(self):
        self.trial_executor.cleanup()

    def _update_trial_queue_and_get_next_trial(self) -> Optional[Trial]:
        """Adds suggested trials to the live queue of trials (they start as
        PENDING trials).

        Returns:
            next_trial: Trial
        """
        wait_for_trial = True  # wait for new trials when all trials are finished
        num_pending_trials = 0
        for trial in self._live_trials:
            if not trial.is_finished():
                wait_for_trial = False
                if trial.status == Trial.PENDING:
                    num_pending_trials += 1

        if not self._search_alg.is_finished():
            # Create pending trials until it fails.
            while num_pending_trials < self._max_pending_trials:
                if not self._update_trial_queue(blocking=wait_for_trial):
                    break
                wait_for_trial = False  # wait at most one trial
                num_pending_trials += 1

        with warn_if_slow("choose_trial_to_run"):
            return self._scheduler_alg.choose_trial_to_run(self._wrapped())
    def step(self):
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().
        """
        if self.is_finished():
            raise TuneError("Called step when all trials finished?")
        with warn_if_slow("on_step_begin"):
            self.trial_executor.on_step_begin()
        with warn_if_slow("callbacks.on_step_begin"):
            self._callbacks.on_step_begin(
                iteration=self._iteration, trials=self._trials
            )

        next_trial = self._update_trial_queue_and_get_next_trial()
        if next_trial:
            logger.debug(f"Got new trial to run: {next_trial}")

        self._wait_and_handle_event(next_trial)

        self._stop_experiment_if_needed()

        try:
            self.checkpoint()
        except Exception as e:
            logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
        self._iteration += 1

        if self._server:
            with warn_if_slow("server"):
                self._process_stop_requests()

            if self.is_finished():
                self._server.shutdown()

        self._reconcile_live_trials()

        with warn_if_slow("on_step_end"):
            self.trial_executor.on_step_end(search_ended=self._search_alg.is_finished())
        with warn_if_slow("callbacks.on_step_end"):
            self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)
    def _wait_and_handle_event(self, next_trial: Optional[Trial]):
        try:
            # Single wait of entire tune loop.
            event = self.trial_executor.get_next_executor_event(
                self._live_trials, next_trial is not None
            )
            if event.type == _ExecutorEventType.PG_READY:
                self._on_pg_ready(next_trial)
            elif event.type == _ExecutorEventType.NO_RUNNING_TRIAL_TIMEOUT:
                self._insufficient_resources_manager.on_no_available_trials(
                    self.get_trials()
                )
            elif event.type == _ExecutorEventType.YIELD:
                pass
            else:
                assert event.type in (
                    _ExecutorEventType.TRAINING_RESULT,
                    _ExecutorEventType.SAVING_RESULT,
                    _ExecutorEventType.RESTORING_RESULT,
                )
                trial = event.trial
                result = event.result
                if _ExecutorEvent.KEY_EXCEPTION in result:
                    self._on_executor_error(
                        trial, event.type, result[_ExecutorEvent.KEY_EXCEPTION]
                    )
                elif event.type == _ExecutorEventType.RESTORING_RESULT:
                    self._on_restoring_result(trial)
                else:
                    assert event.type in (
                        _ExecutorEventType.SAVING_RESULT,
                        _ExecutorEventType.TRAINING_RESULT,
                    ), f"Unexpected future type - {event.type}"
                    if event.type == _ExecutorEventType.TRAINING_RESULT:
                        self._on_training_result(
                            trial, result[_ExecutorEvent.KEY_FUTURE_RESULT]
                        )
                    else:
                        self._on_saving_result(
                            trial, result[_ExecutorEvent.KEY_FUTURE_RESULT]
                        )
        except Exception as e:
            if isinstance(e, TuneError) or self._fail_fast == self.RAISE:
                raise e
            else:
                raise TuneError(traceback.format_exc())

    def _on_pg_ready(self, next_trial: Optional[Trial]):
        def _start_trial(trial: Trial) -> bool:
            """Helper function to start trial and call callbacks"""
            with warn_if_slow("start_trial"):
                if self.trial_executor.start_trial(trial):
                    self._callbacks.on_trial_start(
                        iteration=self._iteration, trials=self._trials, trial=trial
                    )
                    return True
                return False

        assert next_trial is not None
        logger.debug(f"Trying to start trial: {next_trial}")

        trial_started = _start_trial(next_trial)
        if not trial_started and next_trial.status != Trial.ERROR:
            # Only try to start another trial if previous trial startup
            # did not error (e.g. it just didn't start because its
            # placement group is not ready, yet).
            # Without this clause, this test fails:
            # test_trial_runner_pg.py::
            #     TrialRunnerPlacementGroupHeterogeneousTest::
            #     testResourceDeadlock
            next_trial = self.trial_executor.get_ready_trial()

            if next_trial is not None:
                # Must be able to start.
                assert _start_trial(next_trial)

    def _on_executor_error(
        self, trial, event_type: _ExecutorEventType, e: Union[RayTaskError, TuneError]
    ):
        error_msg = f"Trial {trial}: Error happened when processing {str(event_type)}."
        if self._fail_fast == self.RAISE:
            raise e
        else:
            if self._print_trial_errors:
                logger.error(error_msg, exc_info=e)
            self._process_trial_failure(trial, exception=e)

    def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
        return self.trial_executor.stop_trial(
            trial, error=bool(exception), exc=exception
        )

    def _schedule_graceful_trial_stop(self, trial: Trial):
        self._schedule_trial_export(trial)
        self._schedule_trial_stop(trial)

    def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True):
        self.trial_executor.pause_trial(trial, should_checkpoint=should_checkpoint)

    def _schedule_trial_train(self, trial: Trial):
        self.trial_executor.continue_training(trial)

    def _schedule_trial_save(
        self,
        trial: Trial,
        storage: CheckpointStorage = CheckpointStorage.PERSISTENT,
        result: Optional[Dict] = None,
    ) -> Optional[_TrackedCheckpoint]:
        return self.trial_executor.save(trial, storage=storage, result=result)

    def _schedule_trial_export(self, trial: Trial):
        return self.trial_executor.export_trial_if_needed(trial)

    def __getstate__(self):
        state = super().__getstate__()
        state.pop("trial_executor")
        return state
class _TrialExecutorWrapper:
    """Wraps around TrialExecutor class, intercepts API calls and warns users
    of restricted API access.

    This is meant to facilitate restricting
    the current API exposure of TrialExecutor by TrialScheduler.
    """

    def __init__(
        self, trial_executor: RayTrialExecutor, whitelist_attr: Optional[set] = None
    ):
        self._trial_executor = trial_executor
        self._whitelist_attr = whitelist_attr or set()

    def __getattr__(self, attr):
        if attr not in self._whitelist_attr:
            if log_once("restrict_accessing_trial_executor"):
                logger.warning(
                    f"You are trying to access {attr} interface of "
                    f"TrialExecutor in TrialScheduler, which is being "
                    f"restricted. If you believe it is reasonable for "
                    f"your scheduler to access this TrialExecutor API, "
                    f"please reach out to Ray team on GitHub. A more "
                    f"strict API access pattern would be enforced "
                    f"starting 1.12.0"
                )
        return getattr(self._trial_executor, attr)


@DeveloperAPI
class TrialRunnerWrapper:
    """Wraps around TrialRunner class, intercepts API calls and warns users
    of restricted API access.

    This is meant to facilitate restricting
    the current API exposure of TrialRunner by TrialScheduler.
    """

    _EXECUTOR_ATTR = "trial_executor"

    def __init__(
        self,
        trial_runner: TrialRunner,
        trial_executor: Any,
        runner_whitelist_attr: Optional[set] = None,
        executor_whitelist_attr: Optional[set] = None,
    ):
        self._trial_runner = trial_runner
        self._trial_executor = _TrialExecutorWrapper(
            trial_executor, executor_whitelist_attr
        )
        self._runner_whitelist_attr = runner_whitelist_attr or set()

    def __getattr__(self, attr):
        if attr == self._EXECUTOR_ATTR:
            return self._trial_executor
        if attr not in self._runner_whitelist_attr:
            if log_once("restrict_accessing_trial_runner"):
                logger.warning(
                    f"You are trying to access {attr} interface of "
                    f"TrialRunner in TrialScheduler, which is being "
                    f"restricted. If you believe it is reasonable for "
                    f"your scheduler to access this TrialRunner API, "
                    f"please reach out to Ray team on GitHub. A more "
                    f"strict API access pattern would be enforced "
                    f"starting 1.12.0"
                )
        return getattr(self._trial_runner, attr)


def _get_max_pending_trials(search_alg: SearchAlgorithm) -> int:
    max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto")

    if max_pending_trials != "auto":
        return int(max_pending_trials)

    # Else, auto detect.

    # Only BasicVariantGenerator supports > 1 pending trials.
    # This is because we don't want to generate too many trials
    # before we fit the searcher model.
    if not isinstance(search_alg, BasicVariantGenerator):
        return 1

    # Use a minimum of 16 to trigger fast autoscaling
    # Scale up to at most the number of available cluster CPUs
    cluster_cpus = ray.cluster_resources().get("CPU", 1.0)
    max_pending_trials = max(16, int(cluster_cpus * 1.1))

    if max_pending_trials > 128:
        logger.warning(
            f"The maximum number of pending trials has been "
            f"automatically set to the number of available "
            f"cluster CPUs, which is high "
            f"({max_pending_trials} CPUs/pending trials). "
            f"If you're running an experiment with a large number "
            f"of trials, this could lead to scheduling overhead. "
            f"In this case, consider setting the "
            f"`TUNE_MAX_PENDING_TRIALS_PG` environment variable "
            f"to the desired maximum number of concurrent trials."
        )

    return max_pending_trials
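

# Example (illustrative sketch): the auto-detection above can be bypassed by
# setting the environment variable before the runner is constructed.
#
#     os.environ["TUNE_MAX_PENDING_TRIALS_PG"] = "4"
#     assert _get_max_pending_trials(BasicVariantGenerator()) == 4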