import logging
from typing import TYPE_CHECKING, Dict
import numpy as np
from ray.air.constants import TRAINING_ITERATION
from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
from ray.tune.utils import flatten_dict
from ray.util.annotations import Deprecated, PublicAPI
from ray.util.debug import log_once
if TYPE_CHECKING:
from ray.tune.experiment.trial import Trial # noqa: F401
logger = logging.getLogger(__name__)
# Scalar types that can be written with ``SummaryWriter.add_scalar`` (native
# Python numbers plus the common numpy scalar types).
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
@Deprecated(
    message=_LOGGER_DEPRECATION_WARNING.format(
        old="TBXLogger", new="ray.tune.tensorboardx.TBXLoggerCallback"
    ),
    warning=True,
)
@PublicAPI
class TBXLogger(Logger):
    """TensorBoardX Logger.

    Writes scalars, images, videos, and histograms from trial results to
    TensorBoard event files via ``tensorboardX.SummaryWriter``.

    Note that hparams will be written only after a trial has terminated.
    This logger automatically flattens nested dicts to show on TensorBoard:
    {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
    """

    # Builtin types accepted as hparam values by tensorboardX.
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    # Numpy scalar types accepted as hparam values (converted via ``.tolist()``).
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def _init(self):
        """Create the ``SummaryWriter`` for this trial's logdir.

        Raises:
            ImportError: If tensorboardX is not installed (after logging an
                install hint once).
        """
        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                logger.info('pip install "ray[tune]" to see TensorBoard files.')
            raise
        self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
        # Most recent scrubbed result; used to attach metrics to the hparams
        # view in close().
        self.last_result = None

    def on_result(self, result: Dict):
        """Log one trial result dict to TensorBoard.

        Scalars are logged directly; non-empty lists/arrays are dispatched to
        image/video/histogram logging (see ``_log_array``).
        """
        # NOTE(review): ``or`` falls back to the iteration count when
        # timesteps_total is missing *or zero* — kept for backward
        # compatibility with existing event files.
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        # Drop bookkeeping keys that are not useful to plot.
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp.pop(k, None)

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        valid_result = {}

        # Hoisted: avoid rebuilding the isinstance tuple per value.
        scalar_types = tuple(VALID_SUMMARY_TYPES)
        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if isinstance(value, scalar_types) and not np.isnan(value):
                valid_result[full_attr] = value
                self._file_writer.add_scalar(full_attr, value, global_step=step)
            elif (isinstance(value, list) and len(value) > 0) or (
                isinstance(value, np.ndarray) and value.size > 0
            ):
                valid_result[full_attr] = value
                self._log_array(full_attr, value, step)

        self.last_result = valid_result
        self._file_writer.flush()

    def _log_array(self, full_attr: str, value, step):
        """Log a non-empty list/ndarray as image(s), video, or histogram.

        Dispatches ndarrays on rank: 3 -> single image, 4 -> batch of
        images, 5 -> video; everything else (incl. lists) falls through to
        a histogram.
        """
        if isinstance(value, np.ndarray):
            if value.ndim == 3:
                # Must be a single image.
                self._file_writer.add_image(full_attr, value, global_step=step)
                return
            if value.ndim == 4:
                # Must be a batch of images.
                self._file_writer.add_images(full_attr, value, global_step=step)
                return
            if value.ndim == 5:
                # Must be video.
                self._file_writer.add_video(
                    full_attr, value, global_step=step, fps=20
                )
                return
        try:
            self._file_writer.add_histogram(full_attr, value, global_step=step)
        # In case TensorboardX still doesn't think it's a valid value
        # (e.g. `[[]]`), warn and move on.
        except (ValueError, TypeError):
            if log_once("invalid_tbx_value"):
                logger.warning(
                    "You are trying to log an invalid value ({}={}) "
                    "via {}!".format(full_attr, value, type(self).__name__)
                )

    def flush(self):
        """Flush buffered events to disk."""
        if self._file_writer is not None:
            self._file_writer.flush()

    def close(self):
        """Log hparams for the finished trial (if possible) and close the writer."""
        if self._file_writer is not None:
            if self.trial and self.trial.evaluated_params and self.last_result:
                flat_result = flatten_dict(self.last_result, delimiter="/")
                # Only scalar metrics can be attached to the hparams view.
                scrubbed_result = {
                    k: value
                    for k, value in flat_result.items()
                    if isinstance(value, tuple(VALID_SUMMARY_TYPES))
                }
                self._try_log_hparams(scrubbed_result)
            self._file_writer.close()

    def _try_log_hparams(self, result):
        """Best-effort write of the trial's hyperparameters to TensorBoard.

        TBX currently errors if an hparam value is None (among other
        unsupported types), so such values are scrubbed — and reported —
        before logging.
        """
        flat_params = flatten_dict(self.trial.evaluated_params)
        scrubbed_params = {
            k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
        }

        # Numpy scalars are converted to native Python values first.
        np_params = {
            k: v.tolist()
            for k, v in flat_params.items()
            if isinstance(v, self.VALID_NP_HPARAMS)
        }

        scrubbed_params.update(np_params)

        removed = {
            k: v
            for k, v in flat_params.items()
            if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
        }
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s",
                str(removed),
            )

        from tensorboardX.summary import hparams

        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result
            )
            self._file_writer.file_writer.add_summary(experiment_tag)
            self._file_writer.file_writer.add_summary(session_start_tag)
            self._file_writer.file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception(
                "TensorboardX failed to log hparams. "
                "This may be due to an unsupported type "
                "in the hyperparameter values."
            )
@PublicAPI
class TBXLoggerCallback(LoggerCallback):
    """TensorBoardX Logger.

    Writes scalars, images, videos, and histograms from trial results to
    per-trial TensorBoard event files via ``tensorboardX.SummaryWriter``.

    Note that hparams will be written only after a trial has terminated.
    This logger automatically flattens nested dicts to show on TensorBoard:
    {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
    """

    _SAVED_FILE_TEMPLATES = ["events.out.tfevents.*"]

    # Builtin types accepted as hparam values by tensorboardX.
    VALID_HPARAMS = (str, bool, int, float, list, type(None))
    # Numpy scalar types accepted as hparam values (converted via ``.tolist()``).
    VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)

    def __init__(self):
        try:
            from tensorboardX import SummaryWriter

            self._summary_writer_cls = SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                logger.info('pip install "ray[tune]" to see TensorBoard files.')
            raise
        # Per-trial writer, and the last scrubbed result per trial (used to
        # attach metrics to the hparams view on trial end).
        self._trial_writer: Dict["Trial", SummaryWriter] = {}
        self._trial_result: Dict["Trial", Dict] = {}

    def log_trial_start(self, trial: "Trial"):
        """(Re)create a ``SummaryWriter`` rooted at the trial's local path."""
        if trial in self._trial_writer:
            # Restarted trial: close the stale writer first.
            self._trial_writer[trial].close()
        trial.init_local_path()
        self._trial_writer[trial] = self._summary_writer_cls(
            trial.local_path, flush_secs=30
        )
        self._trial_result[trial] = {}

    def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
        """Log one result dict for ``trial`` to its TensorBoard writer.

        Scalars are logged directly; non-empty lists/arrays are dispatched to
        image/video/histogram logging (see ``_log_array``).
        """
        if trial not in self._trial_writer:
            self.log_trial_start(trial)

        # NOTE(review): ``or`` falls back to the iteration count when
        # timesteps_total is missing *or zero* — kept for backward
        # compatibility with existing event files.
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]

        tmp = result.copy()
        # Drop bookkeeping keys that are not useful to plot.
        for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
            tmp.pop(k, None)

        flat_result = flatten_dict(tmp, delimiter="/")
        path = ["ray", "tune"]
        valid_result = {}

        writer = self._trial_writer[trial]
        # Hoisted: avoid rebuilding the isinstance tuple per value.
        scalar_types = tuple(VALID_SUMMARY_TYPES)
        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if isinstance(value, scalar_types) and not np.isnan(value):
                valid_result[full_attr] = value
                writer.add_scalar(full_attr, value, global_step=step)
            elif (isinstance(value, list) and len(value) > 0) or (
                isinstance(value, np.ndarray) and value.size > 0
            ):
                valid_result[full_attr] = value
                self._log_array(writer, full_attr, value, step)

        self._trial_result[trial] = valid_result
        writer.flush()

    def _log_array(self, writer, full_attr: str, value, step):
        """Log a non-empty list/ndarray as image(s), video, or histogram.

        Dispatches ndarrays on rank: 3 -> single image, 4 -> batch of
        images, 5 -> video; everything else (incl. lists) falls through to
        a histogram.
        """
        if isinstance(value, np.ndarray):
            if value.ndim == 3:
                # Must be a single image.
                writer.add_image(full_attr, value, global_step=step)
                return
            if value.ndim == 4:
                # Must be a batch of images.
                writer.add_images(full_attr, value, global_step=step)
                return
            if value.ndim == 5:
                # Must be video.
                writer.add_video(full_attr, value, global_step=step, fps=20)
                return
        try:
            writer.add_histogram(full_attr, value, global_step=step)
        # In case TensorboardX still doesn't think it's a valid value
        # (e.g. `[[]]`), warn and move on.
        except (ValueError, TypeError):
            if log_once("invalid_tbx_value"):
                logger.warning(
                    "You are trying to log an invalid value ({}={}) "
                    "via {}!".format(full_attr, value, type(self).__name__)
                )

    def log_trial_end(self, trial: "Trial", failed: bool = False):
        """Log hparams for the finished trial (if possible) and drop its writer."""
        if trial in self._trial_writer:
            if trial and trial.evaluated_params and self._trial_result[trial]:
                flat_result = flatten_dict(self._trial_result[trial], delimiter="/")
                # Only scalar metrics can be attached to the hparams view.
                scrubbed_result = {
                    k: value
                    for k, value in flat_result.items()
                    if isinstance(value, tuple(VALID_SUMMARY_TYPES))
                }
                self._try_log_hparams(trial, scrubbed_result)
            self._trial_writer[trial].close()
            del self._trial_writer[trial]
            del self._trial_result[trial]

    def _try_log_hparams(self, trial: "Trial", result: Dict):
        """Best-effort write of the trial's hyperparameters to TensorBoard.

        TBX currently errors if an hparam value is None (among other
        unsupported types), so such values are scrubbed — and reported —
        before logging.
        """
        flat_params = flatten_dict(trial.evaluated_params)
        scrubbed_params = {
            k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
        }

        # Numpy scalars are converted to native Python values first.
        np_params = {
            k: v.tolist()
            for k, v in flat_params.items()
            if isinstance(v, self.VALID_NP_HPARAMS)
        }

        scrubbed_params.update(np_params)

        removed = {
            k: v
            for k, v in flat_params.items()
            if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
        }
        if removed:
            logger.info(
                "Removed the following hyperparameter values when "
                "logging to tensorboard: %s",
                str(removed),
            )

        from tensorboardX.summary import hparams

        try:
            experiment_tag, session_start_tag, session_end_tag = hparams(
                hparam_dict=scrubbed_params, metric_dict=result
            )
            self._trial_writer[trial].file_writer.add_summary(experiment_tag)
            self._trial_writer[trial].file_writer.add_summary(session_start_tag)
            self._trial_writer[trial].file_writer.add_summary(session_end_tag)
        except Exception:
            logger.exception(
                "TensorboardX failed to log hparams. "
                "This may be due to an unsupported type "
                "in the hyperparameter values."
            )