Source code for ray.tune.schedulers.pbt

import copy
import json
import logging
import math
import os
import random
import shutil
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

from ray.air.constants import TRAINING_ITERATION
from ray.train import Checkpoint
from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
from ray.tune.error import TuneError
from ray.tune.experiment import Trial
from ray.tune.result import DEFAULT_METRIC
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.search import SearchGenerator
from ray.tune.search.sample import Domain, Function
from ray.tune.search.variant_generator import format_vars
from ray.tune.utils.util import SafeFallbackEncoder
from ray.util import PublicAPI
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.tune.execution.tune_controller import TuneController

logger = logging.getLogger(__name__)


class _PBTTrialState:
    """Internal PBT state tracked per-trial."""

    def __init__(self, trial: Trial):
        self.orig_tag = trial.experiment_tag
        self.last_score = None
        self.last_checkpoint = None
        self.last_perturbation_time = 0
        self.last_train_time = 0  # Used for synchronous mode.
        self.last_result = None  # Used for synchronous mode.

    def __repr__(self) -> str:
        return str(
            (
                self.last_score,
                self.last_checkpoint,
                self.last_train_time,
                self.last_perturbation_time,
            )
        )


def _explore(
    config: Dict,
    mutations: Dict,
    resample_probability: float,
    perturbation_factors: Tuple[float],
    custom_explore_fn: Optional[Callable],
) -> Tuple[Dict, Dict]:
    """Return a perturbed config and string descriptors of the operations performed
    on the original config to produce the new config.

    Args:
        config: Original hyperparameter configuration.
        mutations: Specification of mutations to perform as documented
            in the PopulationBasedTraining scheduler.
        resample_probability: Probability of allowing resampling of a
            particular variable.
        perturbation_factors: Scaling factors to choose between when mutating
            a continuous hyperparameter.
        custom_explore_fn: Custom explore function applied after built-in
            config perturbations.

    Returns:
        new_config: New hyperparameter configuration (after random mutations).
        operations: Map of hyperparams -> strings describing mutation operations
            performed
    """
    operations = {}
    new_config = copy.deepcopy(config)
    for key, distribution in mutations.items():
        if isinstance(distribution, dict):
            # Handle nested hyperparameter configs by recursively perturbing them
            nested_new_config, nested_ops = _explore(
                config[key],
                mutations[key],
                resample_probability,
                perturbation_factors,
                custom_explore_fn=None,
            )
            new_config.update({key: nested_new_config})
            operations.update({key: nested_ops})
        elif isinstance(distribution, (list, tuple)):
            # Case 1: Hyperparameter resample distribution is a list/tuple
            if (
                random.random() < resample_probability
                or config[key] not in distribution
            ):
                # Resample a value from the list with `resample_probability`
                new_config[key] = random.choice(distribution)
                operations[key] = "resample"
            else:
                # Otherwise, perturb by shifting to the left or right of the list
                shift = random.choice([-1, 1])
                old_idx = distribution.index(config[key])
                new_idx = old_idx + shift
                new_idx = min(max(new_idx, 0), len(distribution) - 1)
                new_config[key] = distribution[new_idx]
                operations[key] = (
                    f"shift {'left' if shift == -1 else 'right'}"
                    f"{' (noop)' if old_idx == new_idx else ''}"
                )
        elif isinstance(distribution, (Domain, Callable)):
            # Case 2: Hyperparameter resample distribution is:
            # 1. a function (ex: lambda: np.random.uniform(0, 1))
            # 2. tune search Domain (ex: tune.uniform(0, 1))
            if random.random() < resample_probability:
                # Resample a value from the function/domain with `resample_probability`
                new_config[key] = (
                    distribution.sample(None)
                    if isinstance(distribution, Domain)
                    else distribution()
                )
                operations[key] = "resample"
            else:
                # Otherwise, perturb by multiplying the hyperparameter by one
                # of the `perturbation_factors`
                perturbation_factor = random.choice(perturbation_factors)
                new_config[key] = config[key] * perturbation_factor
                operations[key] = f"* {perturbation_factor}"
            if isinstance(config[key], int):
                # If this hyperparameter started out as an integer (ex: `batch_size`),
                # convert the new value back
                new_config[key] = int(new_config[key])
        else:
            raise ValueError(
                f"Unsupported hyperparameter distribution type: {type(distribution)}"
            )
    if custom_explore_fn:
        # The user can perform any additional hyperparameter exploration
        # via `custom_explore_fn`
        new_config = custom_explore_fn(new_config)
        assert new_config is not None, "Custom explore fn failed to return new config"
    return new_config, operations


def _make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str:
    """Appends perturbed params to the trial name to show in the console."""

    resolved_vars = {}
    for k in mutations.keys():
        resolved_vars[("config", k)] = config[k]
    return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))


def _fill_config(
    config: Dict, attr: str, search_space: Union[dict, list, tuple, Callable, Domain]
):
    """Add attr to config by sampling from search_space.

    This is a helper used to set initial hyperparameter values if the user doesn't
    specify them in the Tuner `param_space`.
    """
    if isinstance(search_space, Callable):
        config[attr] = search_space()
    elif isinstance(search_space, Domain):
        config[attr] = search_space.sample(None)
    elif isinstance(search_space, (list, tuple)):
        config[attr] = random.choice(search_space)
    elif isinstance(search_space, dict):
        config[attr] = {}
        for k, v in search_space.items():
            _fill_config(config[attr], k, v)


def _filter_mutated_params_from_config(
    config: Dict, hyperparam_mutations: Dict
) -> Dict:
    """Filter out hyperparameters from a config so that only parameters specified
    within hyperparam_mutations remain. This recursively filters nested configs.

    Example:
    >>> config = {
    ...     "a": {"b": 2, "c": 0, "d": {"e": 0.1}},
    ...     "f": {"g": 0.5},
    ... }
    >>> hyperparam_mutations = {
    ...     "a": {"b": [1, 2], "c": [-1, 0]},
    ... }
    >>> _filter_mutated_params_from_config(config, hyperparam_mutations) == {
    ...     "a": {"b": 2, "c": 0}
    ... }
    True

    Args:
        config: The config dict that we want to filter.
        hyperparam_mutations: A dict containing a subset of hyperparameters from
            config, used to filter the config.

    Returns:
        mutated_params: A copy of config containing only params specified in
            hyperparam_mutations
    """
    mutated_params = {}
    for param_name in config:
        if param_name not in hyperparam_mutations:
            continue

        if isinstance(config[param_name], dict):
            nested_params = _filter_mutated_params_from_config(
                config[param_name], hyperparam_mutations[param_name]
            )
            mutated_params[param_name] = nested_params
        else:
            mutated_params[param_name] = config[param_name]
    return mutated_params


[docs]@PublicAPI class PopulationBasedTraining(FIFOScheduler): """Implements the Population Based Training (PBT) algorithm. https://www.deepmind.com/blog/population-based-training-of-neural-networks PBT trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers, and a random mutation is applied to their hyperparameters in the hopes of outperforming the current top models. Unlike other hyperparameter search algorithms, PBT mutates hyperparameters during training time. This enables very fast hyperparameter discovery and also automatically discovers good annealing schedules. This Tune PBT implementation considers all trials added as part of the PBT population. If the number of trials exceeds the cluster capacity, they will be time-multiplexed as to balance training progress across the population. To run multiple trials, use `tune.TuneConfig(num_samples=<int>)`. In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in `pbt_global.txt` and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config] on each perturbation step. Args: time_attr: The training result attr to use for comparing time. Note that you can pass in something non-temporal such as `training_iteration` as a measure of progress, the only requirement is that the attribute should increase monotonically. metric: The training result objective value attribute. Stopping procedures will use this attribute. If None but a mode was passed, the `ray.tune.result.DEFAULT_METRIC` will be used per default. mode: One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute. perturbation_interval: Models will be considered for perturbation at this interval of `time_attr`. Note that perturbation incurs checkpoint overhead, so you shouldn't set this to be too frequent. burn_in_period: Models will not be considered for perturbation before this interval of `time_attr` has passed. This guarantees that models are trained for at least a certain amount of time or timesteps before being perturbed. hyperparam_mutations: Hyperparams to mutate. The format is as follows: for each key, either a list, function, or a tune search space object (tune.loguniform, tune.uniform, etc.) can be provided. A list specifies an allowed set of categorical values. A function or tune search space object specifies the distribution of a continuous parameter. You must use tune.choice, tune.uniform, tune.loguniform, etc.. Arbitrary tune.sample_from objects are not supported. A key can also hold a dict for nested hyperparameters. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. Tune will sample the search space provided by `hyperparam_mutations` for the initial hyperparameter values if the corresponding hyperparameters are not present in a trial's initial `config`. quantile_fraction: Parameters are transferred from the top `quantile_fraction` fraction of trials to the bottom `quantile_fraction` fraction. Needs to be between 0 and 0.5. Setting it to 0 essentially implies doing no exploitation at all. resample_probability: The probability of resampling from the original distribution when applying `hyperparam_mutations`. If not resampled, the value will be perturbed by a factor chosen from `perturbation_factors` if continuous, or changed to an adjacent value if discrete. perturbation_factors: Scaling factors to choose between when mutating a continuous hyperparameter. custom_explore_fn: You can also specify a custom exploration function. This function is invoked as `f(config)` after built-in perturbations from `hyperparam_mutations` are applied, and should return `config` updated as needed. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. log_config: Whether to log the ray config of each model to local_dir at each exploit. Allows config schedule to be reconstructed. require_attrs: Whether to require time_attr and metric to appear in result for every iteration. If True, error will be raised if these values are not present in trial result. synch: If False, will use asynchronous implementation of PBT. Trial perturbations occur every perturbation_interval for each trial independently. If True, will use synchronous implementation of PBT. Perturbations will occur only after all trials are synced at the same time_attr every perturbation_interval. Defaults to False. See Appendix A.1 here https://arxiv.org/pdf/1711.09846.pdf. .. code-block:: python import random from ray import tune from ray.tune.schedulers import PopulationBasedTraining pbt = PopulationBasedTraining( time_attr="training_iteration", metric="episode_reward_mean", mode="max", perturbation_interval=10, # every 10 `time_attr` units # (training_iterations in this case) hyperparam_mutations={ # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling # resets it to a value sampled from the lambda function. "factor_1": lambda: random.uniform(0.0, 20.0), # Alternatively, use tune search space primitives. # The search space for factor_1 is equivalent to factor_2. "factor_2": tune.uniform(0.0, 20.0), # Perturb factor3 by changing it to an adjacent value, e.g. # 10 -> 1 or 10 -> 100. Resampling will choose at random. "factor_3": [1, 10, 100, 1000, 10000], # Using tune.choice is NOT equivalent to the above. # factor_4 is treated as a continuous hyperparameter. "factor_4": tune.choice([1, 10, 100, 1000, 10000]), }) tuner = tune.Tuner( trainable, tune_config=tune.TuneConfig( scheduler=pbt, num_samples=8, ), ) tuner.fit() """ def __init__( self, time_attr: str = "time_total_s", metric: Optional[str] = None, mode: Optional[str] = None, perturbation_interval: float = 60.0, burn_in_period: float = 0.0, hyperparam_mutations: Dict[ str, Union[dict, list, tuple, Callable, Domain] ] = None, quantile_fraction: float = 0.25, resample_probability: float = 0.25, perturbation_factors: Tuple[float, float] = (1.2, 0.8), custom_explore_fn: Optional[Callable] = None, log_config: bool = True, require_attrs: bool = True, synch: bool = False, ): hyperparam_mutations = hyperparam_mutations or {} for value in hyperparam_mutations.values(): if not isinstance(value, (dict, list, tuple, Domain, Callable)): raise TypeError( "`hyperparam_mutation` values must be either " "a List, Tuple, Dict, a tune search space object, or " "a callable." ) if isinstance(value, Function): raise ValueError( "arbitrary tune.sample_from objects are not " "supported for `hyperparam_mutation` values." "You must use other built in primitives like" "tune.uniform, tune.loguniform, etc." ) if not hyperparam_mutations and not custom_explore_fn: raise TuneError( "You must specify at least one of `hyperparam_mutations` " "or `custom_explore_fn` to use PBT." ) if quantile_fraction > 0.5 or quantile_fraction < 0: raise ValueError( "You must set `quantile_fraction` to a value between 0 and" "0.5. Current value: '{}'".format(quantile_fraction) ) if perturbation_interval <= 0: raise ValueError( "perturbation_interval must be a positive number greater " "than 0. Current value: '{}'".format(perturbation_interval) ) if mode: assert mode in ["min", "max"], "`mode` must be 'min' or 'max'." super().__init__() self._metric = metric self._mode = mode self._metric_op = None if self._mode == "max": self._metric_op = 1.0 elif self._mode == "min": self._metric_op = -1.0 self._time_attr = time_attr self._perturbation_interval = perturbation_interval self._burn_in_period = burn_in_period self._hyperparam_mutations = hyperparam_mutations self._quantile_fraction = quantile_fraction self._resample_probability = resample_probability self._perturbation_factors = perturbation_factors self._trial_state = {} self._custom_explore_fn = custom_explore_fn self._log_config = log_config self._require_attrs = require_attrs self._synch = synch self._next_perturbation_sync = max( self._perturbation_interval, self._burn_in_period, ) # Metrics self._num_checkpoints = 0 self._num_perturbations = 0 def set_search_properties( self, metric: Optional[str], mode: Optional[str], **spec ) -> bool: if self._metric and metric: return False if self._mode and mode: return False if metric: self._metric = metric if mode: self._mode = mode if self._mode == "max": self._metric_op = 1.0 elif self._mode == "min": self._metric_op = -1.0 if self._metric is None and self._mode: # If only a mode was passed, use anonymous metric self._metric = DEFAULT_METRIC return True def on_trial_add(self, tune_controller: "TuneController", trial: Trial): if tune_controller.search_alg is not None and isinstance( tune_controller.search_alg, SearchGenerator ): raise ValueError( "Search algorithms cannot be used with {} " "schedulers. Please remove {}.".format( self.__class__.__name__, tune_controller.search_alg ) ) if not self._metric or not self._metric_op: raise ValueError( "{} has been instantiated without a valid `metric` ({}) or " "`mode` ({}) parameter. Either pass these parameters when " "instantiating the scheduler, or pass them as parameters " "to `tune.TuneConfig()`".format( self.__class__.__name__, self._metric, self._mode ) ) checkpoint_config = trial.run_metadata.checkpoint_manager.checkpoint_config if ( checkpoint_config.num_to_keep and checkpoint_config.num_to_keep <= 2 and log_once("pbt_num_to_keep") ): warnings.warn( "Using `CheckpointConfig.num_to_keep <= 2` with PBT can lead to " "restoration problems when checkpoint are deleted too early for " "other trials to exploit them. If this happens, increase the value " "of `num_to_keep`." ) self._trial_state[trial] = _PBTTrialState(trial) for attr in self._hyperparam_mutations.keys(): if attr not in trial.config: if log_once(attr + "-missing"): logger.debug( "Cannot find {} in config. Using search " "space provided by hyperparam_mutations." ) # Add attr to trial's config by sampling search space from # hyperparam_mutations. _fill_config(trial.config, attr, self._hyperparam_mutations[attr]) # Make sure this attribute is added to CLI output. trial.evaluated_params[attr] = trial.config[attr] def on_trial_result( self, tune_controller: "TuneController", trial: Trial, result: Dict ) -> str: if self._time_attr not in result: time_missing_msg = ( "Cannot find time_attr {} " "in trial result {}. Make sure that this " "attribute is returned in the " "results of your Trainable.".format(self._time_attr, result) ) if self._require_attrs: raise RuntimeError( time_missing_msg + "If this error is expected, you can change this to " "a warning message by " "setting PBT(require_attrs=False)" ) else: if log_once("pbt-time_attr-error"): logger.warning(time_missing_msg) if self._metric not in result: metric_missing_msg = ( "Cannot find metric {} in trial result {}. " "Make sure that this attribute is returned " "in the " "results of your Trainable.".format(self._metric, result) ) if self._require_attrs: raise RuntimeError( metric_missing_msg + "If this error is expected, " "you can change this to a warning message by " "setting PBT(require_attrs=False)" ) else: if log_once("pbt-metric-error"): logger.warning(metric_missing_msg) if self._metric not in result or self._time_attr not in result: return TrialScheduler.CONTINUE time = result[self._time_attr] state = self._trial_state[trial] # Continue training if burn-in period has not been reached, yet. if time < self._burn_in_period: logger.debug(f"Still in burn-in period: {time} < {self._burn_in_period}") return TrialScheduler.CONTINUE # Continue training if perturbation interval has not been reached, yet. time_since_perturb = time - state.last_perturbation_time if time_since_perturb < self._perturbation_interval: logger.debug( f"Perturbation interval not reached: " f"{time_since_perturb} < {self._perturbation_interval}" ) return TrialScheduler.CONTINUE # avoid checkpoint overhead logger.debug(f"Updating trial state for trial {trial} at time {time}") self._save_trial_state(state, time, result, trial) if not self._synch: state.last_perturbation_time = time lower_quantile, upper_quantile = self._quantiles() decision = TrialScheduler.CONTINUE for other_trial in tune_controller.get_trials(): if other_trial.status in [Trial.PENDING, Trial.PAUSED]: decision = TrialScheduler.PAUSE break self._checkpoint_or_exploit( trial, tune_controller, upper_quantile, lower_quantile ) return TrialScheduler.NOOP if trial.status == Trial.PAUSED else decision else: # Synchronous mode. if any( self._trial_state[t].last_train_time < self._next_perturbation_sync and t != trial for t in tune_controller.get_live_trials() ): logger.debug( f"Sync: Other trials are not at perturb time, yet. " f"Pausing trial {trial} to wait." ) else: # All trials are synced at the same timestep. logger.debug("Sync: All trials are at perturb time.") lower_quantile, upper_quantile = self._quantiles() all_trials = tune_controller.get_trials() not_in_quantile = [] for t in all_trials: if t not in lower_quantile and t not in upper_quantile: not_in_quantile.append(t) logger.debug( "Trial statistics\n" f"Upper quantile: {upper_quantile}\n" f"Lower quantile: {lower_quantile}\n" f"Not in quantile: {not_in_quantile}" ) # Move upper quantile trials to beginning and lower quantile # to end. This ensures that checkpointing of strong trials # occurs before exploiting of weaker ones. all_trials = upper_quantile + not_in_quantile + lower_quantile for t in all_trials: logger.debug(f"Perturbing trial {t}") self._trial_state[t].last_perturbation_time = time self._checkpoint_or_exploit( t, tune_controller, upper_quantile, lower_quantile ) all_train_times = [ self._trial_state[t].last_train_time for t in tune_controller.get_trials() ] max_last_train_time = max(all_train_times) self._next_perturbation_sync = max( self._next_perturbation_sync + self._perturbation_interval, max_last_train_time, ) logger.debug(f"Next perturb at time {self._next_perturbation_sync}") # In sync mode we should pause all trials once result comes in. # Once a perturbation step happens for all trials, they should # still all be paused. # choose_trial_to_run will then pick the next trial to run out of # the paused trials. return ( TrialScheduler.NOOP if trial.status == Trial.PAUSED else TrialScheduler.PAUSE ) def _save_trial_state( self, state: _PBTTrialState, time: int, result: Dict, trial: Trial ): """Saves necessary trial information when result is received. Args: state: The state object for the trial. time: The current timestep of the trial. result: The trial's result dictionary. trial: The trial object. """ # This trial has reached its perturbation interval. # Record new state in the state object. score = self._metric_op * result[self._metric] state.last_score = score state.last_train_time = time state.last_result = result return score def _checkpoint_or_exploit( self, trial: Trial, tune_controller: "TuneController", upper_quantile: List[Trial], lower_quantile: List[Trial], ): """Checkpoint if in upper quantile, exploits if in lower.""" state = self._trial_state[trial] if trial in upper_quantile: # The trial last result is only updated after the scheduler # callback. So, we override with the current result. logger.debug(f"Trial {trial} is in upper quantile. Saving checkpoint.") if trial.status == Trial.PAUSED: if trial.temporary_state.saving_to and isinstance( trial.temporary_state.saving_to, _FutureTrainingResult ): logger.debug(f"Trial {trial} is still saving.") state.last_checkpoint = trial.temporary_state.saving_to else: # Paused trial will always have an in-memory checkpoint. logger.debug( f"Trial {trial} is paused. Use last available " f"checkpoint {trial.checkpoint}." ) state.last_checkpoint = trial.checkpoint else: logger.debug(f"Instructing {trial} to save.") state.last_checkpoint = tune_controller._schedule_trial_save( trial, result=state.last_result ) self._num_checkpoints += 1 else: state.last_checkpoint = None # not a top trial if trial in lower_quantile: trial_to_clone = random.choice(upper_quantile) assert trial is not trial_to_clone clone_state = self._trial_state[trial_to_clone] last_checkpoint = clone_state.last_checkpoint logger.debug( f"Trial {trial} is in lower quantile. " f"Exploiting trial {trial_to_clone}." ) if isinstance(last_checkpoint, _FutureTrainingResult): training_result = last_checkpoint.resolve() if training_result: clone_state.last_result = training_result.metrics clone_state.last_checkpoint = training_result.checkpoint last_checkpoint = clone_state.last_checkpoint else: logger.debug( "PBT-scheduled checkpoint save resolved to None. Trial " f"{trial_to_clone} didn't save any checkpoint before " f"and can't be exploited." ) last_checkpoint = None if not last_checkpoint: logger.info( f"[pbt]: no checkpoint for trial {trial_to_clone}." f" Skip exploit for Trial {trial}" ) return self._exploit(tune_controller, trial, trial_to_clone) def _log_config_on_step( self, trial_state: _PBTTrialState, new_state: _PBTTrialState, trial: Trial, trial_to_clone: Trial, new_config: Dict, ): """Logs transition during exploit/exploit step. For each step, logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config]. """ trial_name, trial_to_clone_name = (trial_state.orig_tag, new_state.orig_tag) trial_id = trial.trial_id trial_to_clone_id = trial_to_clone.trial_id trial_path = os.path.join( trial.local_experiment_path, "pbt_policy_" + trial_id + ".txt" ) trial_to_clone_path = os.path.join( trial_to_clone.local_dir, "pbt_policy_" + trial_to_clone_id + ".txt" ) policy = [ trial_name, trial_to_clone_name, trial.last_result.get(TRAINING_ITERATION, 0), trial_to_clone.last_result.get(TRAINING_ITERATION, 0), trial_to_clone.config, new_config, ] # Log to global file. with open( os.path.join(trial.local_experiment_path, "pbt_global.txt"), "a+" ) as f: print(json.dumps(policy, cls=SafeFallbackEncoder), file=f) # Overwrite state in target trial from trial_to_clone. if os.path.exists(trial_to_clone_path): shutil.copyfile(trial_to_clone_path, trial_path) # Log new exploit in target trial log. with open(trial_path, "a+") as f: f.write(json.dumps(policy, cls=SafeFallbackEncoder) + "\n") def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]: """Gets new config for trial by exploring trial_to_clone's config. Args: trial: The current trial that decided to exploit trial_to_clone. trial_to_clone: The top-performing trial with a hyperparameter config that the current trial will explore by perturbing. Returns: new_config: New hyperparameter configuration (after random mutations). operations: Map of hyperparams -> strings describing mutation operations performed """ return _explore( trial_to_clone.config, self._hyperparam_mutations, self._resample_probability, self._perturbation_factors, self._custom_explore_fn, ) def _summarize_hyperparam_changes( self, old_params: Dict, new_params: Dict, operations: Optional[Dict] = None, prefix: str = "", ) -> str: """Generates a summary of hyperparameter changes from a PBT "explore" step. Example: Given the following hyperparam_mutations: hyperparam_mutations = { "a": tune.uniform(0, 1), "b": list(range(5)), "c": { "d": tune.uniform(2, 3), "e": {"f": [-1, 0, 1]}, }, } This is an example summary output of the operations performed on old_params to get new_params: a : 0.5 --- (* 0.8) --> 0.4 b : 2 --- (resample) --> 4 c : d : 2.5 --- (* 1.2) --> 3.0 e : f : 0 --- (shift right) --> 1 The summary shows the old and new hyperparameter values, with the operation used to perturb labeled in between. If the operation for a certain hyperparameter is not provided, then the summary will just contain arrows without a label. (ex: a : 0.5 -----> 0.4) Args: old_params: Old values of hyperparameters that are perturbed to generate the new config new_params: The newly generated hyperparameter config from PBT exploration operations: Map of hyperparams -> string descriptors the operations performed to generate the values in `new_params` prefix: Helper argument to format nested dict hyperparam configs Returns: summary_str: The hyperparameter change summary to print/log. """ summary_str = "" if not old_params: return summary_str for param_name in old_params: old_val = old_params[param_name] assert param_name in new_params, ( "`old_params` and `new_params` " f"must both contain the key: '{param_name}'\n" f"old_params.keys() = {old_params.keys()}\n" f"new_params.keys() = {new_params.keys()}" ) new_val = new_params[param_name] summary_str += f"{prefix}{param_name} : " if isinstance(old_val, Dict): # Handle nested hyperparameters by recursively summarizing summary_str += "\n" nested_operations = operations.get(param_name, {}) summary_str += self._summarize_hyperparam_changes( old_val, new_val, operations=nested_operations, prefix=prefix + " " * 4, ) else: op = operations.get(param_name, None) if not op: arrow = "----->" else: arrow = f"--- ({op}) -->" summary_str += f"{old_val} {arrow} {new_val}\n" return summary_str def _exploit( self, tune_controller: "TuneController", trial: Trial, trial_to_clone: Trial, ): """Transfers perturbed state from trial_to_clone -> trial. If specified, also logs the updated hyperparam state. """ trial_state = self._trial_state[trial] new_state = self._trial_state[trial_to_clone] class_name = self.__class__.__name__ logger.info( f"\n\n[{class_name}] [Exploit] Cloning trial " "{} (score = {:4f}) into trial {} (score = {:4f})\n".format( trial_to_clone.trial_id, new_state.last_score, trial.trial_id, trial_state.last_score, ) ) new_config, operations = self._get_new_config(trial, trial_to_clone) # Only log mutated hyperparameters and not entire config. old_params = _filter_mutated_params_from_config( trial_to_clone.config, self._hyperparam_mutations ) new_params = _filter_mutated_params_from_config( new_config, self._hyperparam_mutations ) explore_info_str = ( f"\n\n[{class_name}] [Explore] Perturbed the hyperparameter config of trial" f"{trial.trial_id}:\n" ) explore_info_str += ( self._summarize_hyperparam_changes(old_params, new_params, operations) or "No hyperparameters mutated." ) logger.info(explore_info_str) if self._log_config: self._log_config_on_step( trial_state, new_state, trial, trial_to_clone, new_config ) new_tag = _make_experiment_tag( trial_state.orig_tag, new_config, self._hyperparam_mutations ) if trial.status == Trial.PAUSED: # If trial is paused we update it with a new checkpoint. # When the trial is started again, the new checkpoint is used. if not self._synch: raise TuneError( "Trials should be paused here only if in " "synchronous mode. If you encounter this error" " please raise an issue on Ray Github." ) else: tune_controller.pause_trial(trial, should_checkpoint=False) trial.set_experiment_tag(new_tag) # Clone hyperparameters from the `trial_to_clone` trial.set_config(new_config) # Resume training from a shallow copy of `trial_to_clone`'s latest # checkpoint checkpoint_to_exploit: Checkpoint = copy.copy(new_state.last_checkpoint) trial.run_metadata.checkpoint_manager._latest_checkpoint_result = ( _TrainingResult( checkpoint=checkpoint_to_exploit, metrics=new_state.last_result ) ) self._num_perturbations += 1 # Transfer over the last perturbation time as well trial_state.last_perturbation_time = new_state.last_perturbation_time trial_state.last_train_time = new_state.last_train_time def _quantiles(self) -> Tuple[List[Trial], List[Trial]]: """Returns trials in the lower and upper `quantile` of the population. If there is not enough data to compute this, returns empty lists. """ trials = [] for trial, state in self._trial_state.items(): logger.debug("Trial {}, state {}".format(trial, state)) if trial.is_finished(): logger.debug("Trial {} is finished".format(trial)) if state.last_score is not None and not trial.is_finished(): trials.append(trial) trials.sort(key=lambda t: self._trial_state[t].last_score) if len(trials) <= 1: return [], [] else: num_trials_in_quantile = int( math.ceil(len(trials) * self._quantile_fraction) ) if num_trials_in_quantile > len(trials) / 2: num_trials_in_quantile = int(math.floor(len(trials) / 2)) return (trials[:num_trials_in_quantile], trials[-num_trials_in_quantile:])
[docs] def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]: """Ensures all trials get fair share of time (as defined by time_attr). This enables the PBT scheduler to support a greater number of concurrent trials than can fit in the cluster at any given time. """ candidates = [] for trial in tune_controller.get_trials(): if trial.status in [ Trial.PENDING, Trial.PAUSED, ]: if not self._synch: candidates.append(trial) elif ( self._trial_state[trial].last_train_time < self._next_perturbation_sync ): candidates.append(trial) candidates.sort(key=lambda trial: self._trial_state[trial].last_train_time) return candidates[0] if candidates else None
# Unit test only. TODO(xwjiang): Remove test-specific APIs. def reset_stats(self): self._num_perturbations = 0 self._num_checkpoints = 0 # Unit test only. TODO(xwjiang): Remove test-specific APIs. def last_scores(self, trials: List[Trial]) -> List[float]: scores = [] for trial in trials: state = self._trial_state[trial] if state.last_score is not None and not trial.is_finished(): scores.append(state.last_score) return scores def debug_string(self) -> str: return "PopulationBasedTraining: {} checkpoints, {} perturbs".format( self._num_checkpoints, self._num_perturbations )
[docs]@PublicAPI class PopulationBasedTrainingReplay(FIFOScheduler): """Replays a Population Based Training run. Population Based Training does not return a single hyperparameter configuration, but rather a schedule of configurations. For instance, PBT might discover that a larger learning rate leads to good results in the first training iterations, but that a smaller learning rate is preferable later. This scheduler enables replaying these parameter schedules from a finished PBT run. This requires that population based training has been run with ``log_config=True``, which is the default setting. The scheduler will only accept and train a single trial. It will start with the initial config of the existing trial and update the config according to the schedule. Args: policy_file: The PBT policy file. Usually this is stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt`` where ``xxx`` is the trial ID. Example: .. code-block:: python # Replaying a result from ray.tune.examples.pbt_convnet_example from ray import train, tune from ray.tune.examples.pbt_convnet_example import PytorchTrainable from ray.tune.schedulers import PopulationBasedTrainingReplay replay = PopulationBasedTrainingReplay( "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt") tuner = tune.Tuner( PytorchTrainable, run_config=train.RunConfig( stop={"training_iteration": 100} ), tune_config=tune.TuneConfig( scheduler=replay, ), ) tuner.fit() """ def __init__(self, policy_file: str): policy_file = Path(policy_file).expanduser() if not policy_file.exists(): raise ValueError("Policy file not found: {}".format(policy_file.as_posix())) self.policy_file = policy_file.as_posix() # Find and read pbt policy file, potentially raise error initial_config, self._policy = self._load_policy(self.policy_file) self.experiment_tag = "replay_{}".format(os.path.basename(self.policy_file)) self.config = initial_config self.current_config = self.config self._trial = None self._current_step = 0 self._num_perturbations = 0 self._policy_iter = iter(self._policy) self._next_policy = next(self._policy_iter, None) def _load_policy(self, policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]: raw_policy = [] with open(policy_file, "rt") as fp: for row in fp.readlines(): try: parsed_row = json.loads(row) except json.JSONDecodeError: raise ValueError( "Could not read PBT policy file: {}.".format(policy_file) ) from None raw_policy.append(tuple(parsed_row)) # Loop through policy from end to start to obtain changepoints policy = [] last_new_tag = None last_old_conf = None for old_tag, new_tag, old_step, new_step, old_conf, new_conf in reversed( raw_policy ): if last_new_tag and old_tag != last_new_tag: # Tag chain ended. This means that previous changes were # overwritten by the last change and should be ignored. break last_new_tag = new_tag last_old_conf = old_conf policy.append((new_step, new_conf)) return last_old_conf, list(reversed(policy)) def on_trial_add(self, tune_controller: "TuneController", trial: Trial): if self._trial: raise ValueError( "More than one trial added to PBT replay run. This " "means the same schedule will be trained multiple " "times. Do you want to set `n_samples=1`?" ) self._trial = trial if self._trial.config and self._policy: logger.warning( "Trial was initialized with a config, which was overwritten. " "Did you start the PBT replay with a `config` parameter?" ) elif self._trial.config and not self._policy: # Only train with initial policy self.config = self._trial.config elif not self._trial.config and not self._policy: raise ValueError( "No replay policy found and trial initialized without a " "valid config. Either pass a `config` argument to `tune.Tuner()`" "or consider not using PBT replay for this run." ) self._trial.set_config(self.config) def on_trial_result( self, tune_controller: "TuneController", trial: Trial, result: Dict ) -> str: if TRAINING_ITERATION not in result: # No time reported return TrialScheduler.CONTINUE if not self._next_policy: # No more changes in the config return TrialScheduler.CONTINUE step = result[TRAINING_ITERATION] self._current_step = step change_at, new_config = self._next_policy if step < change_at: # Don't change the policy just yet return TrialScheduler.CONTINUE logger.info( "Population Based Training replay is now at step {}. " "Configuration will be changed to {}.".format(step, new_config) ) result = tune_controller._schedule_trial_save(trial, result=result) training_result = result.resolve() trial.run_metadata.checkpoint_manager._latest_checkpoint_result = ( training_result ) new_tag = _make_experiment_tag(self.experiment_tag, new_config, new_config) tune_controller.pause_trial(trial, should_checkpoint=False) trial.set_experiment_tag(new_tag) trial.set_config(new_config) self.current_config = new_config self._num_perturbations += 1 self._next_policy = next(self._policy_iter, None) return TrialScheduler.NOOP def debug_string(self) -> str: return "PopulationBasedTraining replay: Step {}, perturb {}".format( self._current_step, self._num_perturbations )