import inspect
import logging
import os
import tempfile
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union
from pytorch_lightning import Callback, Trainer, LightningModule
from ray import train
from ray.util import log_once
from ray.util.annotations import PublicAPI, Deprecated
from ray.train import Checkpoint
logger = logging.getLogger(__name__)

# Names of all ``on_*`` function members found on the installed PyTorch
# Lightning ``Callback`` class. Discovering them by introspection (instead of
# hard-coding a list) keeps this set in sync with whichever PTL version is
# being used.
_allowed_hooks = {
    member[0]
    for member in inspect.getmembers(Callback, predicate=inspect.isfunction)
    if member[0].startswith("on_")
}
def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]:
    """Overrides all allowed PTL Callback hooks with our custom handle logic.

    Every hook name in ``_allowed_hooks`` is replaced on ``callback_cls`` by a
    thin wrapper. The wrapper calls
    ``self._handle(trainer=..., pl_module=...)`` when the hook's name is
    listed in ``self._on`` and does nothing otherwise.

    Args:
        callback_cls: The callback class whose hooks are replaced in place.

    Returns:
        The same ``callback_cls`` object, so this function can be used as a
        class decorator.
    """

    def generate_overridden_hook(fn_name):
        # A factory function is used so that every wrapper closes over its
        # own ``fn_name`` rather than the loop variable below.
        def overridden_hook(
            self,
            trainer: Trainer,
            *args,
            pl_module: Optional[LightningModule] = None,
            **kwargs,
        ):
            if fn_name not in self._on:
                return
            # ``pl_module`` is keyword-only in this signature, so a module
            # passed positionally right after ``trainer`` is captured by
            # ``*args``. Recover it so ``_handle`` does not receive ``None``.
            # The isinstance guard leaves hooks whose first extra positional
            # argument is something else untouched.
            if pl_module is None and args and isinstance(args[0], LightningModule):
                pl_module = args[0]
            self._handle(trainer=trainer, pl_module=pl_module)

        return overridden_hook

    # Set the overridden hook to all the allowed hooks in TuneCallback.
    for fn_name in _allowed_hooks:
        setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))
    return callback_cls
@_override_ptl_hooks
class TuneCallback(Callback):
    """Base class for Tune's PyTorch Lightning callbacks.

    Subclasses implement ``_handle``, which is invoked whenever one of the
    selected PyTorch Lightning hooks fires.

    Args:
        on: When to trigger checkpoint creations. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". Defaults to "validation_end".
            Either a single hook name or a list of hook names.

    Raises:
        ValueError: If any entry of ``on`` is not a known hook name.
    """

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        # Accept a single hook name as well as a list/tuple of names.
        if isinstance(on, (list, tuple)):
            hooks = list(on)
        else:
            hooks = [on]

        for hook in hooks:
            if f"on_{hook}" not in _allowed_hooks:
                # Show the names in the form the caller has to pass them
                # (without the "on_" prefix) and in a stable order.
                valid_hooks = sorted(name[len("on_"):] for name in _allowed_hooks)
                raise ValueError(
                    f"Invalid hook selected: {hook}. Must be one of "
                    f"{valid_hooks}"
                )

        # Add back the "on_" prefix for internal consistency.
        self._on = [f"on_{hook}" for hook in hooks]

    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        """Reacts to a selected hook firing; must be implemented by subclasses."""
        raise NotImplementedError
@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
    """PyTorch Lightning report and checkpoint callback

    Saves checkpoints after each validation step. Also reports metrics to Tune,
    which is needed for checkpoint registration.

    Args:
        metrics: Metrics to report to Tune. If this is a list,
            each item describes the metric key reported to PyTorch Lightning,
            and it will be reported under the same name to Tune. If this is a
            dict, each key will be the name reported to Tune and the respective
            value will be the metric key reported to PyTorch Lightning.
            If empty or None, every entry of ``trainer.callback_metrics`` is
            reported.
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        save_checkpoints: If True (default), checkpoints will be saved and
            reported to Ray. If False, only metrics will be reported.
        on: When to trigger checkpoint creations and metric reports. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". Defaults to "validation_end".

    Example:

    .. code-block:: python

        import pytorch_lightning as pl
        from ray.tune.integration.pytorch_lightning import (
            TuneReportCheckpointCallback)

        # Report metrics and save a checkpoint after each validation epoch.
        trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
            metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
            filename="trainer.ckpt", on="validation_end")])

    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = "checkpoint",
        save_checkpoints: bool = True,
        on: Union[str, List[str]] = "validation_end",
    ):
        super().__init__(on=on)
        # Normalize a single metric name to a one-element list.
        if isinstance(metrics, str):
            metrics = [metrics]
        self._save_checkpoints = save_checkpoints
        self._filename = filename
        self._metrics = metrics

    def _get_report_dict(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> Optional[Dict]:
        """Builds the metrics dict to report from ``trainer.callback_metrics``.

        Returns:
            ``None`` while the trainer is sanity checking. Otherwise a dict
            mapping the reported name to ``.item()`` of the metric value.
            Configured metrics that are missing from
            ``trainer.callback_metrics`` are skipped with a warning.
        """
        # Don't report if just doing initial validation sanity checks.
        if trainer.sanity_checking:
            return None

        callback_metrics = trainer.callback_metrics
        if not self._metrics:
            # No explicit selection: report everything that is available.
            return {k: v.item() for k, v in callback_metrics.items()}

        # Normalize to (name reported to Tune, key in PyTorch Lightning) pairs
        # once, instead of re-checking the container type per metric.
        if isinstance(self._metrics, dict):
            name_pairs = list(self._metrics.items())
        else:
            name_pairs = [(name, name) for name in self._metrics]

        report_dict = {}
        for tune_key, ptl_key in name_pairs:
            if ptl_key in callback_metrics:
                report_dict[tune_key] = callback_metrics[ptl_key].item()
            else:
                logger.warning(
                    "Metric %s does not exist in `trainer.callback_metrics`.",
                    ptl_key,
                )
        return report_dict

    @contextmanager
    def _get_checkpoint(self, trainer: Trainer) -> Optional[Checkpoint]:
        """Context manager yielding the checkpoint to report, or ``None``.

        When checkpointing is enabled, the trainer state is saved under
        ``self._filename`` in a temporary directory. That directory only
        exists while the caller's ``with`` block is running and is removed
        afterwards.
        """
        if not self._save_checkpoints:
            yield None
            return

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            yield checkpoint

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        """Reports metrics (and optionally a checkpoint) via ``train.report``.

        Nothing is reported during sanity checking or when there are no
        metrics to report.
        """
        if trainer.sanity_checking:
            return

        report_dict = self._get_report_dict(trainer, pl_module)
        if not report_dict:
            return

        # Report inside the ``with`` block so the temporary checkpoint
        # directory still exists while ``train.report`` runs.
        with self._get_checkpoint(trainer) as checkpoint:
            train.report(report_dict, checkpoint=checkpoint)
class _TuneCheckpointCallback(TuneCallback):
    """Deprecated stub: constructing it always raises ``DeprecationWarning``."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; instantiation is rejected
        # unconditionally.
        message = (
            "`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
            "is deprecated."
        )
        raise DeprecationWarning(message)
@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
    """Deprecated metrics-only variant of ``TuneReportCheckpointCallback``.

    Behaves like ``TuneReportCheckpointCallback`` constructed with
    ``save_checkpoints=False``, and emits a deprecation warning (guarded by
    ``log_once``) on construction.

    Args:
        metrics: Metrics to report to Tune; forwarded unchanged to
            ``TuneReportCheckpointCallback``.
        on: Hook name(s), without the ``on_`` prefix, at which to report.
            Defaults to "validation_end".
    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "validation_end",
    ):
        if log_once("tune_ptl_report_deprecated"):
            warnings.warn(
                "`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
                "is deprecated. Use "
                "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
                " instead.",
                # Attribute the warning to the code that instantiates the
                # callback rather than to this module.
                stacklevel=2,
            )
        super().__init__(metrics=metrics, save_checkpoints=False, on=on)