import argparse
import collections
import datetime
import logging
import math
import numbers
import os
import sys
import textwrap
import time
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import ray
from ray._private.dict import flatten_dict, unflattened_lookup
from ray._private.thirdparty.tabulate.tabulate import (
DataRow,
Line,
TableFormat,
tabulate,
)
from ray.air._internal.usage import AirEntrypoint
from ray.air.constants import TRAINING_ITERATION
from ray.train import Checkpoint
from ray.tune.callback import Callback
from ray.tune.experiment.trial import Trial
from ray.tune.result import (
AUTO_RESULT_KEYS,
EPISODE_REWARD_MEAN,
MEAN_ACCURACY,
MEAN_LOSS,
TIME_TOTAL_S,
TIMESTEPS_TOTAL,
)
from ray.tune.search.sample import Domain
from ray.tune.utils.log import Verbosity
try:
import rich
import rich.layout
import rich.live
except ImportError:
rich = None
logger = logging.getLogger(__name__)
# Maps result keys to the column names printed in the trial table.
# Note: this mapping is ordered!
DEFAULT_COLUMNS = collections.OrderedDict(
{
MEAN_ACCURACY: "acc",
MEAN_LOSS: "loss",
TRAINING_ITERATION: "iter",
TIME_TOTAL_S: "total time (s)",
TIMESTEPS_TOTAL: "ts",
EPISODE_REWARD_MEAN: "reward",
}
)
# These keys are excluded when printing intermediate/final training/tuning results.
BLACKLISTED_KEYS = {
"config",
"date",
"done",
"hostname",
"iterations_since_restore",
"node_ip",
"pid",
"time_since_restore",
"timestamp",
"trial_id",
"experiment_tag",
"should_checkpoint",
"_report_on", # LIGHTNING_REPORT_STAGE_KEY
}
VALID_SUMMARY_TYPES = {
int,
float,
np.float32,
np.float64,
np.int32,
np.int64,
type(None),
}
# The order in which trial statuses are summarized.
ORDER = [
Trial.RUNNING,
Trial.TERMINATED,
Trial.PAUSED,
Trial.PENDING,
Trial.ERROR,
]
class AirVerbosity(IntEnum):
SILENT = 0
DEFAULT = 1
VERBOSE = 2
def __repr__(self):
return str(self.value)
IS_NOTEBOOK = ray.widgets.util.in_notebook()
def get_air_verbosity(
verbose: Union[int, AirVerbosity, Verbosity]
) -> Optional[AirVerbosity]:
if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
return None
if isinstance(verbose, AirVerbosity):
return verbose
verbose_int = verbose if isinstance(verbose, int) else verbose.value
# Verbosity 2 and 3 both map to AirVerbosity 2
verbose_int = min(2, verbose_int)
return AirVerbosity(verbose_int)
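# Illustrative behavior (assuming the RAY_AIR_NEW_OUTPUT environment variable is
# not set to "0"): get_air_verbosity(0) -> AirVerbosity.SILENT,
# get_air_verbosity(1) -> AirVerbosity.DEFAULT, and both verbosity 2 and 3
# -> AirVerbosity.VERBOSE.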
def _infer_params(config: Dict[str, Any]) -> List[str]:
params = []
flat_config = flatten_dict(config)
for key, val in flat_config.items():
if isinstance(val, Domain):
params.append(key)
        # Grid search is a special named field. Because we flattened
        # the whole config, we detect it by the flattened key's suffix.
if key.endswith("/grid_search"):
# Truncate `/grid_search`
params.append(key[:-12])
return params
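# A rough sketch (``tune`` here refers to ``ray.tune``, which is not imported in
# this module): _infer_params({"lr": tune.loguniform(1e-4, 1e-1),
# "batch_size": {"grid_search": [16, 32]}}) returns ["lr", "batch_size"], since
# `tune.loguniform` produces a `Domain` and the flattened key
# "batch_size/grid_search" carries the grid search suffix.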
def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
"""Get strings representing the current and elapsed time.
Args:
start_time: POSIX timestamp of the start of the tune run
current_time: POSIX timestamp giving the current time
Returns:
Current time and elapsed time for the current run
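
    Example (the elapsed-time string is timezone-independent, assuming no DST
    transition falls inside the interval):
        >>> _get_time_str(1700000000, 1700003723.5)[1]
        '1hr 2min 3s'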
"""
current_time_dt = datetime.datetime.fromtimestamp(current_time)
start_time_dt = datetime.datetime.fromtimestamp(start_time)
delta: datetime.timedelta = current_time_dt - start_time_dt
rest = delta.total_seconds()
days = int(rest // (60 * 60 * 24))
rest -= days * (60 * 60 * 24)
hours = int(rest // (60 * 60))
rest -= hours * (60 * 60)
minutes = int(rest // 60)
seconds = int(rest - minutes * 60)
running_for_str = ""
if days > 0:
running_for_str += f"{days:d}d "
if hours > 0 or running_for_str:
running_for_str += f"{hours:d}hr "
if minutes > 0 or running_for_str:
running_for_str += f"{minutes:d}min "
running_for_str += f"{seconds:d}s"
return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
trials_by_state = collections.defaultdict(list)
for t in trials:
trials_by_state[t.status].append(t)
return trials_by_state
def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
return [t for t in trials if t.error_file]
def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
"""Try to infer the metrics to print out.
    By default, only the first `limit` (4) meaningful metrics found in
    `last_result` are inferred as user-defined metrics.
"""
    # Use an OrderedDict as an ordered set.
result = collections.OrderedDict()
for t in trials:
if not t.last_result:
continue
for metric, value in t.last_result.items():
if metric not in DEFAULT_COLUMNS:
if metric not in AUTO_RESULT_KEYS:
if type(value) in VALID_SUMMARY_TYPES:
result[metric] = "" # not important
if len(result) >= limit:
return list(result.keys())
return list(result.keys())
def _current_best_trial(
trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> Tuple[Optional[Trial], Optional[str]]:
"""
Returns the best trial and the metric key. If anything is empty or None,
returns a trivial result of None, None.
Args:
trials: List of trials.
        metric: Metric by which trials are ranked.
mode: One of "min" or "max".
Returns:
Best trial and the metric key.
"""
if not trials or not metric or not mode:
return None, None
metric_op = 1.0 if mode == "max" else -1.0
best_metric = float("-inf")
best_trial = None
for t in trials:
if not t.last_result:
continue
metric_value = unflattened_lookup(metric, t.last_result, default=None)
if pd.isnull(metric_value):
continue
if not best_trial or metric_value * metric_op > best_metric:
best_metric = metric_value * metric_op
best_trial = t
return best_trial, metric
@dataclass
class _PerStatusTrialTableData:
trial_infos: List[List[str]]
more_info: str
@dataclass
class _TrialTableData:
header: List[str]
data: List[_PerStatusTrialTableData]
def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
"""Abbreviate a string representation of an object to `max_len` characters.
For numbers, booleans and None, the original value will be returned for
correct rendering in the table formatting tool.
    Args:
        value: Object to be represented as a string.
        max_len: Maximum length of the returned string.
        wrap: If True, wrap long strings across up to two rows (at most
            ``2 * max_len`` characters) instead of truncating to a single row.
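
    Example:
        >>> _max_len("a" * 30, max_len=10)
        '...aaaaaaa'
        >>> _max_len(3.14159)
        3.14159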
"""
if value is None or isinstance(value, (int, float, numbers.Number, bool)):
return value
string = str(value)
if len(string) <= max_len:
return string
if wrap:
        # Wrap across a maximum of two rows.
        # TODO: Make this configurable in the refactor
        if len(string) > max_len * 2:
            string = "..." + string[(3 - (max_len * 2)) :]
        wrapped = textwrap.wrap(string, width=max_len)
return "\n".join(wrapped)
result = "..." + string[(3 - max_len) :]
return result
def _get_trial_info(
trial: Trial, param_keys: List[str], metric_keys: List[str]
) -> List[str]:
"""Returns the following information about a trial:
name | status | metrics...
Args:
trial: Trial to get information for.
param_keys: Names of parameters to include.
metric_keys: Names of metrics to include.
"""
result = trial.last_result
trial_info = [str(trial), trial.status]
# params
trial_info.extend(
[
_max_len(
unflattened_lookup(param, trial.config, default=None),
)
for param in param_keys
]
)
# metrics
trial_info.extend(
[
_max_len(
unflattened_lookup(metric, result, default=None),
)
for metric in metric_keys
]
)
return trial_info
def _get_trial_table_data_per_status(
status: str,
trials: List[Trial],
param_keys: List[str],
metric_keys: List[str],
force_max_rows: bool = False,
) -> Optional[_PerStatusTrialTableData]:
"""Gather all information of trials pertained to one `status`.
Args:
status: The trial status of interest.
        trials: All the trials with that status.
param_keys: *Ordered* list of parameters to be displayed in the table.
metric_keys: *Ordered* list of metrics to be displayed in the table.
            Including both default and user-defined metrics.
force_max_rows: Whether or not to enforce a max row number for this status.
If True, only a max of `5` rows will be shown.
Returns:
        All information of trials pertaining to the `status`.
"""
# TODO: configure it.
max_row = 5 if force_max_rows else math.inf
if not trials:
return None
trial_infos = list()
more_info = None
for t in trials:
if len(trial_infos) >= max_row:
remaining = len(trials) - max_row
more_info = f"{remaining} more {status}"
break
trial_infos.append(_get_trial_info(t, param_keys, metric_keys))
return _PerStatusTrialTableData(trial_infos, more_info)
def _get_trial_table_data(
trials: List[Trial],
param_keys: List[str],
metric_keys: List[str],
all_rows: bool = False,
wrap_headers: bool = False,
) -> _TrialTableData:
"""Generate a table showing the current progress of tuning trials.
Args:
trials: List of trials for which progress is to be shown.
param_keys: Ordered list of parameters to be displayed in the table.
metric_keys: Ordered list of metrics to be displayed in the table.
            Including both default and user-defined metrics.
            A metric is only shown if at least one trial reports a value for it.
all_rows: Force to show all rows.
wrap_headers: If True, header columns can be wrapped with ``\n``.
Returns:
Trial table data, including header and trial table per each status.
"""
# TODO: configure
max_trial_num_to_show = 20
max_column_length = 20
trials_by_state = _get_trials_by_state(trials)
# get the right metric to show.
metric_keys = [
k
for k in metric_keys
if any(
unflattened_lookup(k, t.last_result, default=None) is not None
for t in trials
)
]
# get header from metric keys
formatted_metric_columns = [
_max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys
]
formatted_param_columns = [
_max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys
]
metric_header = [
DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted
for metric, formatted in zip(metric_keys, formatted_metric_columns)
]
param_header = formatted_param_columns
# Map to the abbreviated version if necessary.
header = ["Trial name", "status"] + param_header + metric_header
trial_data = list()
for t_status in ORDER:
trial_data_per_status = _get_trial_table_data_per_status(
t_status,
trials_by_state[t_status],
param_keys=param_keys,
metric_keys=metric_keys,
force_max_rows=not all_rows and len(trials) > max_trial_num_to_show,
)
if trial_data_per_status:
trial_data.append(trial_data_per_status)
return _TrialTableData(header, trial_data)
def _best_trial_str(
trial: Trial,
metric: str,
):
"""Returns a readable message stating the current best trial."""
# returns something like
# Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa
val = unflattened_lookup(metric, trial.last_result, default=None)
config = trial.last_result.get("config", {})
parameter_columns = list(config.keys())
params = {p: unflattened_lookup(p, config) for p in parameter_columns}
return (
f"Current best trial: {trial.trial_id} with {metric}={val} and "
f"params={params}"
)
def _render_table_item(
key: str, item: Any, prefix: str = ""
) -> Iterable[Tuple[str, str]]:
key = prefix + key
if isinstance(item, argparse.Namespace):
item = item.__dict__
if isinstance(item, float):
# tabulate does not work well with mixed-type columns, so we format
# numbers ourselves.
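        # e.g. 0.123456 -> "0.12346"; note that 2.0 renders as "2." once
        # trailing zeros are stripped.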
yield key, f"{item:.5f}".rstrip("0")
elif isinstance(item, dict):
flattened = flatten_dict(item)
for k, v in sorted(flattened.items()):
yield key + "/" + str(k), _max_len(v)
else:
yield key, _max_len(item, 20)
def _get_dict_as_table_data(
data: Dict,
include: Optional[Collection] = None,
exclude: Optional[Collection] = None,
upper_keys: Optional[Collection] = None,
):
"""Get ``data`` dict as table rows.
If specified, excluded keys are removed. Excluded keys can either be
fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
(e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
needed, we can revisit the logic at a later point.
The same is true for included keys. If a top-level key is included (e.g. ``foo``)
then all sub keys will be included, too, except if they are excluded.
If keys are both excluded and included, exclusion takes precedence. Thus, if
``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.
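
    Example (illustrative):
        >>> _get_dict_as_table_data(
        ...     {"a": 1, "config": {"lr": 0.1}}, exclude={"config/lr"}
        ... )
        [['a', 1]]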
"""
include = include or set()
exclude = exclude or set()
upper_keys = upper_keys or set()
upper = []
lower = []
for key, value in sorted(data.items()):
# Exclude top-level keys
if key in exclude:
continue
for k, v in _render_table_item(str(key), value):
# k is now the full subkey, e.g. config/nested/key
# We can exclude the full key
if k in exclude:
continue
# If we specify includes, top-level includes should take precedence
# (e.g. if `config` is in include, include config always).
if include and key not in include and k not in include:
continue
if key in upper_keys:
upper.append([k, v])
else:
lower.append([k, v])
if not upper:
return lower
elif not lower:
return upper
else:
return upper + lower
if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
# Copied/adjusted from tabulate
AIR_TABULATE_TABLEFMT = TableFormat(
lineabove=Line("╭", "─", "─", "╮"),
linebelowheader=Line("├", "─", "─", "┤"),
linebetweenrows=None,
linebelow=Line("╰", "─", "─", "╯"),
headerrow=DataRow("│", " ", "│"),
datarow=DataRow("│", " ", "│"),
padding=1,
with_header_hide=None,
)
else:
# For non-utf output, use ascii-compatible characters.
# This prevents errors e.g. when legacy windows encoding is used.
AIR_TABULATE_TABLEFMT = TableFormat(
lineabove=Line("+", "-", "-", "+"),
linebelowheader=Line("+", "-", "-", "+"),
linebetweenrows=None,
linebelow=Line("+", "-", "-", "+"),
headerrow=DataRow("|", " ", "|"),
datarow=DataRow("|", " ", "|"),
padding=1,
with_header_hide=None,
)
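# With the utf-8 table format above, a small two-column table renders roughly as:
#   ╭────────────────────────────╮
#   │ Trial name        status   │
#   ├────────────────────────────┤
#   │ objective_a1b2c3  RUNNING  │
#   ╰────────────────────────────╯
# (column widths are computed by tabulate; the trial name is only an illustration)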
def _print_dict_as_table(
data: Dict,
header: Optional[str] = None,
include: Optional[Collection[str]] = None,
exclude: Optional[Collection[str]] = None,
division: Optional[Collection[str]] = None,
):
table_data = _get_dict_as_table_data(
data=data, include=include, exclude=exclude, upper_keys=division
)
headers = [header, ""] if header else []
if not table_data:
return
print(
tabulate(
table_data,
headers=headers,
colalign=("left", "right"),
tablefmt=AIR_TABULATE_TABLEFMT,
)
)
class ProgressReporter(Callback):
"""Periodically prints out status update."""
# TODO: Make this configurable
_heartbeat_freq = 30 # every 30 sec
# to be updated by subclasses.
_heartbeat_threshold = None
_start_end_verbosity = None
_intermediate_result_verbosity = None
_addressing_tmpl = None
def __init__(
self,
verbosity: AirVerbosity,
progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
"""
        Args:
            verbosity: AirVerbosity level.
            progress_metrics: If given, printed result tables only include
                these metrics.
"""
self._verbosity = verbosity
        self._start_time = time.time()
        self._last_heartbeat_time = float("-inf")
self._progress_metrics = progress_metrics
self._trial_last_printed_results = {}
self._in_block = None
@property
def verbosity(self) -> AirVerbosity:
return self._verbosity
def setup(
self,
start_time: Optional[float] = None,
**kwargs,
):
self._start_time = start_time
def _start_block(self, indicator: Any):
if self._in_block != indicator:
self._end_block()
self._in_block = indicator
def _end_block(self):
if self._in_block:
print("")
self._in_block = None
def on_experiment_end(self, trials: List["Trial"], **info):
self._end_block()
def experiment_started(
self,
experiment_name: str,
experiment_path: str,
searcher_str: str,
scheduler_str: str,
total_num_samples: int,
tensorboard_path: Optional[str] = None,
**kwargs,
):
self._start_block("exp_start")
print(f"\nView detailed results here: {experiment_path}")
if tensorboard_path:
print(
f"To visualize your results with TensorBoard, run: "
f"`tensorboard --logdir {tensorboard_path}`"
)
@property
def _time_heartbeat_str(self):
current_time_str, running_time_str = _get_time_str(
self._start_time, time.time()
)
return (
f"Current time: {current_time_str}. Total running time: " + running_time_str
)
def print_heartbeat(self, trials, *args, force: bool = False):
if self._verbosity < self._heartbeat_threshold:
return
if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
self._print_heartbeat(trials, *args, force=force)
self._last_heartbeat_time = time.time()
def _print_heartbeat(self, trials, *args, force: bool = False):
raise NotImplementedError
def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
"""Only print result if a different result has been reported, or force=True"""
result = result or trial.last_result
last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
this_iter = result.get(TRAINING_ITERATION, 0)
if this_iter != last_result_iter or force:
_print_dict_as_table(
result,
header=f"{self._addressing_tmpl.format(trial)} result",
include=self._progress_metrics,
exclude=BLACKLISTED_KEYS,
division=AUTO_RESULT_KEYS,
)
self._trial_last_printed_results[trial.trial_id] = this_iter
def _print_config(self, trial):
_print_dict_as_table(
trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
)
def on_trial_result(
self,
iteration: int,
trials: List[Trial],
trial: Trial,
result: Dict,
**info,
):
if self.verbosity < self._intermediate_result_verbosity:
return
self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
print(
f"{self._addressing_tmpl.format(trial)} "
f"finished iteration {result[TRAINING_ITERATION]} "
f"at {curr_time_str}. Total running time: " + running_time_str
)
self._print_result(trial, result)
def on_trial_complete(
self, iteration: int, trials: List[Trial], trial: Trial, **info
):
if self.verbosity < self._start_end_verbosity:
return
curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
finished_iter = 0
if trial.last_result and TRAINING_ITERATION in trial.last_result:
finished_iter = trial.last_result[TRAINING_ITERATION]
self._start_block(f"trial_{trial}_complete")
print(
f"{self._addressing_tmpl.format(trial)} "
f"completed after {finished_iter} iterations "
f"at {curr_time_str}. Total running time: " + running_time_str
)
self._print_result(trial)
def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
finished_iter = 0
if trial.last_result and TRAINING_ITERATION in trial.last_result:
finished_iter = trial.last_result[TRAINING_ITERATION]
self._start_block(f"trial_{trial}_error")
print(
f"{self._addressing_tmpl.format(trial)} "
f"errored after {finished_iter} iterations "
f"at {curr_time_str}. Total running time: {running_time_str}\n"
f"Error file: {trial.error_file}"
)
self._print_result(trial)
def on_trial_recover(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)
def on_checkpoint(
self,
iteration: int,
trials: List[Trial],
trial: Trial,
checkpoint: Checkpoint,
**info,
):
if self._verbosity < self._intermediate_result_verbosity:
return
        # This isn't expected to happen, but guard against it just in case.
saved_iter = "?"
if trial.last_result and TRAINING_ITERATION in trial.last_result:
saved_iter = trial.last_result[TRAINING_ITERATION]
self._start_block(f"trial_{trial}_result_{saved_iter}")
loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"
print(
f"{self._addressing_tmpl.format(trial)} "
f"saved a checkpoint for iteration {saved_iter} "
f"at: {loc}"
)
def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
if self.verbosity < self._start_end_verbosity:
return
has_config = bool(trial.config)
self._start_block(f"trial_{trial}_start")
if has_config:
print(
f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
)
self._print_config(trial)
else:
print(
f"{self._addressing_tmpl.format(trial)} "
f"started without custom configuration."
)
def _detect_reporter(
verbosity: AirVerbosity,
num_samples: int,
entrypoint: Optional[AirEntrypoint] = None,
metric: Optional[str] = None,
mode: Optional[str] = None,
config: Optional[Dict] = None,
progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
if entrypoint in {
AirEntrypoint.TUNE_RUN,
AirEntrypoint.TUNE_RUN_EXPERIMENTS,
AirEntrypoint.TUNER,
}:
reporter = TuneTerminalReporter(
verbosity,
num_samples=num_samples,
metric=metric,
mode=mode,
config=config,
progress_metrics=progress_metrics,
)
else:
reporter = TrainReporter(verbosity, progress_metrics=progress_metrics)
return reporter
class TuneReporterBase(ProgressReporter):
_heartbeat_threshold = AirVerbosity.DEFAULT
_wrap_headers = False
_intermediate_result_verbosity = AirVerbosity.VERBOSE
_start_end_verbosity = AirVerbosity.DEFAULT
_addressing_tmpl = "Trial {}"
def __init__(
self,
verbosity: AirVerbosity,
num_samples: int = 0,
metric: Optional[str] = None,
mode: Optional[str] = None,
config: Optional[Dict] = None,
progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
self._num_samples = num_samples
self._metric = metric
self._mode = mode
# will be populated when first result comes in.
self._inferred_metric = None
self._inferred_params = _infer_params(config or {})
super(TuneReporterBase, self).__init__(
verbosity=verbosity, progress_metrics=progress_metrics
)
def setup(
self,
start_time: Optional[float] = None,
total_samples: Optional[int] = None,
**kwargs,
):
super().setup(start_time=start_time)
self._num_samples = total_samples
def _get_overall_trial_progress_str(self, trials):
result = " | ".join(
[
f"{len(trials)} {status}"
for status, trials in _get_trials_by_state(trials).items()
]
)
return f"Trial status: {result}"
# TODO: Return a more structured type to share code with Jupyter flow.
def _get_heartbeat(
self, trials, *sys_args, force_full_output: bool = False
) -> Tuple[List[str], _TrialTableData]:
result = list()
# Trial status: 1 RUNNING | 7 PENDING
result.append(self._get_overall_trial_progress_str(trials))
# Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
result.append(self._time_heartbeat_str)
# Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
result.extend(sys_args)
# Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
current_best_trial, metric = _current_best_trial(
trials, self._metric, self._mode
)
if current_best_trial:
result.append(_best_trial_str(current_best_trial, metric))
# Now populating the trial table data.
if not self._inferred_metric:
# try inferring again.
self._inferred_metric = _infer_user_metrics(trials)
all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric
trial_table_data = _get_trial_table_data(
trials,
param_keys=self._inferred_params,
metric_keys=all_metrics,
all_rows=force_full_output,
wrap_headers=self._wrap_headers,
)
return result, trial_table_data
def _print_heartbeat(self, trials, *sys_args, force: bool = False):
raise NotImplementedError
class TuneTerminalReporter(TuneReporterBase):
def experiment_started(
self,
experiment_name: str,
experiment_path: str,
searcher_str: str,
scheduler_str: str,
total_num_samples: int,
tensorboard_path: Optional[str] = None,
**kwargs,
):
if total_num_samples > sys.maxsize:
total_num_samples_str = "infinite"
else:
total_num_samples_str = str(total_num_samples)
print(
tabulate(
[
["Search algorithm", searcher_str],
["Scheduler", scheduler_str],
["Number of trials", total_num_samples_str],
],
headers=["Configuration for experiment", experiment_name],
tablefmt=AIR_TABULATE_TABLEFMT,
)
)
super().experiment_started(
experiment_name=experiment_name,
experiment_path=experiment_path,
searcher_str=searcher_str,
scheduler_str=scheduler_str,
total_num_samples=total_num_samples,
tensorboard_path=tensorboard_path,
**kwargs,
)
def _print_heartbeat(self, trials, *sys_args, force: bool = False):
if self._verbosity < self._heartbeat_threshold and not force:
return
heartbeat_strs, table_data = self._get_heartbeat(
trials, *sys_args, force_full_output=force
)
self._start_block("heartbeat")
for s in heartbeat_strs:
print(s)
# now print the table using Tabulate
more_infos = []
all_data = []
        table_header = table_data.header
for sub_table in table_data.data:
all_data.extend(sub_table.trial_infos)
if sub_table.more_info:
more_infos.append(sub_table.more_info)
print(
tabulate(
all_data,
                headers=table_header,
tablefmt=AIR_TABULATE_TABLEFMT,
showindex=False,
)
)
if more_infos:
print(", ".join(more_infos))
if not force:
# Only print error table at end of training
return
trials_with_error = _get_trials_with_error(trials)
if not trials_with_error:
return
self._start_block("status_errored")
print(f"Number of errored trials: {len(trials_with_error)}")
fail_header = ["Trial name", "# failures", "error file"]
fail_table_data = [
[
str(trial),
str(trial.run_metadata.num_failures)
+ ("" if trial.status == Trial.ERROR else "*"),
trial.error_file,
]
for trial in trials_with_error
]
print(
tabulate(
fail_table_data,
headers=fail_header,
tablefmt=AIR_TABULATE_TABLEFMT,
showindex=False,
colalign=("left", "right", "left"),
)
)
if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
print("* The trial terminated successfully after retrying.")
class TrainReporter(ProgressReporter):
# the minimal verbosity threshold at which heartbeat starts getting printed.
_heartbeat_threshold = AirVerbosity.VERBOSE
_intermediate_result_verbosity = AirVerbosity.DEFAULT
_start_end_verbosity = AirVerbosity.DEFAULT
_addressing_tmpl = "Training"
def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False):
# Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24) # noqa
if len(trials) == 0:
return
trial = trials[0]
if trial.status != Trial.RUNNING:
return " ".join(
[f"Training is in {trial.status} status.", self._time_heartbeat_str]
)
if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
iter_num = 1
else:
iter_num = trial.last_result[TRAINING_ITERATION] + 1
return " ".join(
[f"Training on iteration {iter_num}.", self._time_heartbeat_str]
)
def _print_heartbeat(self, trials, *args, force: bool = False):
print(self._get_heartbeat(trials, force_full_output=force))
def on_trial_result(
self,
iteration: int,
trials: List[Trial],
trial: Trial,
result: Dict,
**info,
):
self._last_heartbeat_time = time.time()
super().on_trial_result(
iteration=iteration, trials=trials, trial=trial, result=result, **info
)