import io
import json
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import pyarrow.fs
import ray
from ray.air.constants import (
EXPR_ERROR_PICKLE_FILE,
EXPR_PROGRESS_FILE,
EXPR_RESULT_FILE,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.train import Checkpoint

logger = logging.getLogger(__name__)


@PublicAPI(stability="stable")
@dataclass
class Result:
"""The final result of a ML training run or a Tune trial.
This is the output produced by ``Trainer.fit``.
``Tuner.fit`` outputs a :class:`~ray.tune.ResultGrid` that is a collection
of ``Result`` objects.
This API is the recommended way to access the outputs such as:
- checkpoints (``Result.checkpoint``)
- the history of reported metrics (``Result.metrics_dataframe``, ``Result.metrics``)
- errors encountered during a training run (``Result.error``)
The constructor is a private API -- use ``Result.from_path`` to create a result
object from a directory.
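
    Example (a minimal sketch, not a prescribed setup -- the training function
    and trainer below are illustrative placeholders):

        .. code-block:: python

            import ray.train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_fn(config):
                # A real loop would report per-iteration metrics
                # (and, optionally, checkpoints).
                ray.train.report({"loss": 0.1})

            trainer = TorchTrainer(
                train_fn, scaling_config=ScalingConfig(num_workers=1)
            )
            result = trainer.fit()

            result.metrics     # latest reported metrics, e.g. {"loss": 0.1, ...}
            result.checkpoint  # latest checkpoint (None if none was reported)
            result.path        # trial directory on persistent storage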

    Attributes:
        metrics: The latest set of reported metrics.
        checkpoint: The latest checkpoint.
        error: The execution error of the Trainable run, if the trial finishes
            in error.
        path: Path pointing to the result directory on persistent storage. This
            can point to a remote storage location (e.g. S3) or to a local
            location (a path on the head node). The path is accessible via the
            result's associated `filesystem`. For instance, for a result stored
            in S3 at ``s3://bucket/location``, ``path`` will have the value
            ``bucket/location``.
        metrics_dataframe: The full result dataframe of the Trainable.
            The dataframe is indexed by iterations and contains reported
            metrics. Note that the dataframe columns are indexed with the
            *flattened* keys of reported metrics (e.g. a reported
            ``{"a": {"b": 1}}`` becomes a column named ``a/b``), so the format
            of this dataframe may differ slightly from ``Result.metrics``,
            which is an unflattened dict of the latest set of reported metrics.
        best_checkpoints: A list of tuples of the best checkpoints and their
            associated metrics. The number of saved checkpoints is determined
            by :class:`~ray.train.CheckpointConfig` (by default, all
            checkpoints are saved).
    """
metrics: Optional[Dict[str, Any]]
checkpoint: Optional["Checkpoint"]
error: Optional[Exception]
path: str
metrics_dataframe: Optional["pd.DataFrame"] = None
best_checkpoints: Optional[List[Tuple["Checkpoint", Dict[str, Any]]]] = None
_storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
    _items_to_repr = ["error", "metrics", "path", "filesystem", "checkpoint"]

    @property
def config(self) -> Optional[Dict[str, Any]]:
"""The config associated with the result."""
if not self.metrics:
return None
        return self.metrics.get("config", None)

    @property
def filesystem(self) -> pyarrow.fs.FileSystem:
"""Return the filesystem that can be used to access the result path.
Returns:
pyarrow.fs.FileSystem implementation.
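
        Example (a sketch; assumes ``result.path`` points at a trial directory
        containing a ``result.json`` file):

            .. code-block:: python

                with result.filesystem.open_input_stream(
                    f"{result.path}/result.json"
                ) as f:
                    content = f.readall().decode()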
"""
        return self._storage_filesystem or pyarrow.fs.LocalFileSystem()

    def _repr(self, indent: int = 0) -> str:
"""Construct the representation with specified number of space indent."""
from ray.tune.experimental.output import BLACKLISTED_KEYS
from ray.tune.result import AUTO_RESULT_KEYS
shown_attributes = {k: getattr(self, k) for k in self._items_to_repr}
if self.error:
shown_attributes["error"] = type(self.error).__name__
else:
shown_attributes.pop("error")
shown_attributes["filesystem"] = shown_attributes["filesystem"].type_name
if self.metrics:
exclude = set(AUTO_RESULT_KEYS)
exclude.update(BLACKLISTED_KEYS)
shown_attributes["metrics"] = {
k: v for k, v in self.metrics.items() if k not in exclude
}
cls_indent = " " * indent
kws_indent = " " * (indent + 2)
kws = [
f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items()
]
kws_repr = ",\n".join(kws)
return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr)
def __repr__(self) -> str:
return self._repr(indent=0)
@staticmethod
def _read_file_as_str(
storage_filesystem: pyarrow.fs.FileSystem,
storage_path: str,
) -> str:
"""Opens a file as an input stream reading all byte content sequentially and
decoding read bytes as utf-8 string.
Args:
storage_filesystem: The filesystem to use.
storage_path: The source to open for reading.
"""
with storage_filesystem.open_input_stream(storage_path) as f:
return f.readall().decode()

    @classmethod
def from_path(
cls,
path: Union[str, os.PathLike],
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
) -> "Result":
"""Restore a Result object from local or remote trial directory.
Args:
path: A path of a trial directory on local or remote storage
(ex: s3://bucket/path or /tmp/ray_results).
storage_filesystem: A custom filesystem to use. If not provided,
this will be auto-resolved by pyarrow. If provided, the path
is assumed to be prefix-stripped already, and must be a valid path
                on the filesystem.

        Returns:
A :py:class:`Result` object of that trial.
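
        Example (a sketch; the path below is a hypothetical local trial
        directory -- an ``s3://bucket/path`` URI works the same way):

            .. code-block:: python

                from ray.train import Result

                result = Result.from_path("/tmp/ray_results/experiment/trial_0")
                print(result.metrics)
                print(result.checkpoint)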
"""
# TODO(justinvyu): Fix circular dependency.
from ray.train import Checkpoint
from ray.train._internal.storage import (
_exists_at_fs_path,
_list_at_fs_path,
get_fs_and_path,
)
        from ray.train.constants import CHECKPOINT_DIR_NAME

        fs, fs_path = get_fs_and_path(path, storage_filesystem)
        if not _exists_at_fs_path(fs, fs_path):
            raise RuntimeError(f"Trial folder {fs_path} doesn't exist!")

        # Restore metrics from result.json
result_json_file = Path(fs_path, EXPR_RESULT_FILE).as_posix()
progress_csv_file = Path(fs_path, EXPR_PROGRESS_FILE).as_posix()
if _exists_at_fs_path(fs, result_json_file):
lines = cls._read_file_as_str(fs, result_json_file).split("\n")
json_list = [json.loads(line) for line in lines if line]
metrics_df = pd.json_normalize(json_list, sep="/")
latest_metrics = json_list[-1] if json_list else {}
# Fallback to restore from progress.csv
elif _exists_at_fs_path(fs, progress_csv_file):
metrics_df = pd.read_csv(
io.StringIO(cls._read_file_as_str(fs, progress_csv_file))
)
latest_metrics = (
metrics_df.iloc[-1].to_dict() if not metrics_df.empty else {}
)
else:
raise RuntimeError(
f"Failed to restore the Result object: Neither {EXPR_RESULT_FILE}"
f" nor {EXPR_PROGRESS_FILE} exists in the trial folder!"
)
# Restore all checkpoints from the checkpoint folders
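        # Note: checkpoint directory names are assumed to be zero-padded
        # (e.g. "checkpoint_000000"), so a lexicographic sort orders them
        # by checkpoint index.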
checkpoint_dir_names = sorted(
_list_at_fs_path(
fs,
fs_path,
file_filter=lambda file_info: file_info.type
== pyarrow.fs.FileType.Directory
and file_info.base_name.startswith("checkpoint_"),
)
)
if checkpoint_dir_names:
checkpoints = [
Checkpoint(
path=Path(fs_path, checkpoint_dir_name).as_posix(), filesystem=fs
)
for checkpoint_dir_name in checkpoint_dir_names
]
metrics = []
for checkpoint_dir_name in checkpoint_dir_names:
metrics_corresponding_to_checkpoint = metrics_df[
metrics_df[CHECKPOINT_DIR_NAME] == checkpoint_dir_name
]
if metrics_corresponding_to_checkpoint.empty:
logger.warning(
"Could not find metrics corresponding to "
f"{checkpoint_dir_name}. These will default to an empty dict."
)
metrics.append(
{}
if metrics_corresponding_to_checkpoint.empty
else metrics_corresponding_to_checkpoint.iloc[-1].to_dict()
)
latest_checkpoint = checkpoints[-1]
# TODO(justinvyu): These are ordered by checkpoint index, since we don't
# know the metric to order these with.
best_checkpoints = list(zip(checkpoints, metrics))
else:
            best_checkpoints = latest_checkpoint = None

        # Restore the trial error if it exists
error = None
error_file_path = Path(fs_path, EXPR_ERROR_PICKLE_FILE).as_posix()
if _exists_at_fs_path(fs, error_file_path):
with fs.open_input_stream(error_file_path) as f:
error = ray.cloudpickle.load(f)
return Result(
metrics=latest_metrics,
checkpoint=latest_checkpoint,
path=fs_path,
_storage_filesystem=fs,
metrics_dataframe=metrics_df,
best_checkpoints=best_checkpoints,
error=error,
)

    @PublicAPI(stability="alpha")
    def get_best_checkpoint(self, metric: str, mode: str) -> "Checkpoint":
        """Get the best checkpoint from this trial based on a specific metric.

        Any checkpoints without an associated metric value will be filtered out.

        Args:
metric: The key for checkpoints to order on.
mode: One of ["min", "max"].
Returns:
:class:`Checkpoint <ray.train.Checkpoint>` object, or None if there is
no valid checkpoint associated with the metric.
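
        Example (a sketch; assumes the trial reported a ``loss`` metric
        alongside its checkpoints):

            .. code-block:: python

                best_ckpt = result.get_best_checkpoint(metric="loss", mode="min")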
"""
if not self.best_checkpoints:
raise RuntimeError("No checkpoint exists in the trial directory!")
if mode not in ["max", "min"]:
raise ValueError(
f'Unsupported mode: {mode}. Please choose from ["min", "max"]!'
)
op = max if mode == "max" else min
valid_checkpoints = [
ckpt_info for ckpt_info in self.best_checkpoints if metric in ckpt_info[1]
]
if not valid_checkpoints:
raise RuntimeError(
f"Invalid metric name {metric}! "
f"You may choose from the following metrics: {self.metrics.keys()}."
)
return op(valid_checkpoints, key=lambda x: x[1][metric])[0]