# coding: utf-8
from abc import abstractmethod
import logging
from typing import Dict, List, Optional, Union

from ray.exceptions import RayTaskError
from ray.tune import TuneError
from ray.util.annotations import DeveloperAPI
from ray.tune.trial import Trial, _TuneCheckpoint

logger = logging.getLogger(__name__)

# Signals when a class directly inherits from TrialExecutor.
# A DeprecationWarning is raised to inform users of TrialExecutor deprecation.
class _WarnOnDirectInheritanceMeta(type):
    def __new__(mcls, name, bases, module, **kwargs):
        if (
            name
            not in (
                "RayTrialExecutor",
                "_MockTrialExecutor",
                "TrialExecutor",
            )
            and "TrialExecutor" in tuple(base.__name__ for base in bases)
        ):
            raise DeprecationWarning(
                f"{name} inherits from TrialExecutor, which is being "
                "deprecated. "
                "RFC: https://github.com/ray-project/ray/issues/17593. "
                "Please reach out on the Ray Github if you have any concerns."
            )
        cls = super().__new__(mcls, name, bases, module, **kwargs)
        return cls
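
# Illustrative sketch (assumption: `MyExecutor` is a hypothetical user-defined
# subclass, not part of this module): directly inheriting from TrialExecutor
# under any name other than the allowlisted ones trips the metaclass check at
# class-definition time.
#
#     class MyExecutor(TrialExecutor):
#         ...
#     # -> raises DeprecationWarning pointing to the RFC linked above.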

@DeveloperAPI
class TrialExecutor(metaclass=_WarnOnDirectInheritanceMeta):
    """Module for interacting with remote trainables.

    Manages platform-specific details such as resource handling
    and starting/stopping trials.
    """

    def __init__(self):
        """Initializes a new TrialExecutor."""
        self._cached_trial_state = {}
        self._trials_to_cache = set()

    def set_status(self, trial: Trial, status: str) -> None:
        """Sets status and checkpoints metadata if needed.

        Only checkpoints metadata if trial status is a terminal condition.
        PENDING, PAUSED, and RUNNING switches have checkpoints taken care of
        in the TrialRunner.

        Args:
            trial: Trial to checkpoint.
            status: Status to set trial to.
        """
        if trial.status == status:
            logger.debug("Trial %s: Status %s unchanged.", trial, trial.status)
        else:
            logger.debug(
                "Trial %s: Changing status from %s to %s.", trial, trial.status, status
            )
        trial.set_status(status)
        if status in [Trial.TERMINATED, Trial.ERROR]:
            self._trials_to_cache.add(trial)

    def mark_trial_to_checkpoint(self, trial: Trial) -> None:
        """Marks the trial so its metadata is cached by `get_checkpoints()`."""
        self._trials_to_cache.add(trial)

    def get_checkpoints(self) -> Dict[str, str]:
        """Returns the mapping of trial ID to pickled metadata."""
        for trial in self._trials_to_cache:
            self._cached_trial_state[trial.trial_id] = trial.get_json_state()
        self._trials_to_cache.clear()
        return self._cached_trial_state
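
    # Illustrative sketch of the checkpoint-metadata caching flow (assumption:
    # `executor` is a concrete executor and `trial` an existing Trial; not
    # part of this module):
    #
    #     executor.set_status(trial, Trial.TERMINATED)  # terminal -> marked for caching
    #     state = executor.get_checkpoints()            # {trial_id: JSON state}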

    @abstractmethod
    def start_trial(self, trial: Trial) -> bool:
        """Starts the trial, restoring from a checkpoint if one is provided.

        Args:
            trial: Trial to be started.

        Returns:
            True if the trial started successfully, False otherwise.
        """
        pass

    @abstractmethod
    def stop_trial(
        self,
        trial: Trial,
        error: bool = False,
        exc: Optional[Union[TuneError, RayTaskError]] = None,
    ) -> None:
        """Stops the trial, releasing all allocated resources.

        If stopping the trial fails, the run will be marked as terminated
        in error, but no exception will be thrown.

        Args:
            trial: Trial to stop.
            error: Whether to mark this trial as terminated in error.
            exc: Optional exception.
        """
        pass

    def continue_training(self, trial: Trial) -> None:
        """Continues the training of this trial."""
        pass

    def pause_trial(self, trial: Trial) -> None:
        """Pauses the trial.

        We want to release resources (specifically GPUs) when pausing an
        experiment. This results in a PAUSED state that is similar to
        TERMINATED.
        """
        assert trial.status == Trial.RUNNING, trial.status
        try:
            self.save(trial, _TuneCheckpoint.MEMORY)
            self.stop_trial(trial)
            self.set_status(trial, Trial.PAUSED)
        except Exception:
            logger.exception("Error pausing runner.")
            self.set_status(trial, Trial.ERROR)

    @abstractmethod
    def reset_trial(
        self, trial: Trial, new_config: Dict, new_experiment_tag: str
    ) -> bool:
        """Tries to invoke `Trainable.reset()` to reset the trial.

        Args:
            trial: Trial to be reset.
            new_config: New configuration for the Trial trainable.
            new_experiment_tag: New experiment name for the trial.

        Returns:
            True if `reset` is successful, else False.
        """
        pass

    def on_step_begin(self, trials: List[Trial]) -> None:
        """A hook called before running one step of the trial event loop.

        Args:
            trials: The list of trials. Note, refrain from
                providing TrialRunner directly here.
        """
        pass

    def on_step_end(self, trials: List[Trial]) -> None:
        """A hook called after running one step of the trial event loop.

        Args:
            trials: The list of trials. Note, refrain from
                providing TrialRunner directly here.
        """
        pass

    def force_reconcilation_on_next_step_end(self) -> None:
        pass

    @abstractmethod
    def debug_string(self) -> str:
        """Returns a human readable message for printing to the console."""
        pass

    @abstractmethod
    def restore(self, trial: Trial) -> None:
        """Restores training state from a checkpoint.

        Training state is restored from `trial.checkpoint`. If restoring
        fails, the trial status will be set to ERROR.

        Args:
            trial: Trial to be restored.
        """
        pass

    @abstractmethod
    def save(
        self,
        trial: Trial,
        storage: str = _TuneCheckpoint.PERSISTENT,
        result: Optional[Dict] = None,
    ) -> _TuneCheckpoint:
        """Saves training state of this trial to a checkpoint.

        If result is None, this trial's last result will be used.

        Args:
            trial: The trial whose state is to be saved.
            storage: Where to store the checkpoint. Defaults to PERSISTENT.
            result: The state of this trial as a dictionary to be saved.

        Returns:
            A Checkpoint object.
        """
        pass

    @abstractmethod
    def export_trial_if_needed(self, trial: Trial) -> Dict:
        """Exports model of this trial based on trial.export_formats.

        Args:
            trial: The trial to export.

        Returns:
            A dict that maps ExportFormats to successfully exported models.
        """
        pass

    def has_gpus(self) -> bool:
        """Returns True if GPUs are detected on the cluster."""
        return False

    def cleanup(self, trials: List[Trial]) -> None:
        """Ensures that trials are cleaned up after stopping.

        Args:
            trials: The list of trials. Note, refrain from
                providing TrialRunner directly here.
        """
        pass

    def set_max_pending_trials(self, max_pending: int) -> None:
        """Set the maximum number of allowed pending trials."""
        pass
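

# Illustrative driver-loop sketch (assumptions: `executor` is a concrete
# executor such as RayTrialExecutor and `trials` is the list of Trial objects
# managed by a TrialRunner; not part of this module):
#
#     executor.set_max_pending_trials(4)
#     executor.on_step_begin(trials)
#     # ... the TrialRunner dispatches start/stop/save/restore calls here ...
#     executor.on_step_end(trials)
#     executor.cleanup(trials)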