Source code for ray.tune.integration.pytorch_lightning

import inspect
import logging
import os
import tempfile
import warnings
from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union

from pytorch_lightning import Callback, Trainer, LightningModule
from ray import train
from ray.util import log_once
from ray.util.annotations import PublicAPI, Deprecated
from ray.train import Checkpoint

logger = logging.getLogger(__name__)

# Get all PyTorch Lightning Callback hooks based on whatever PTL version is being used.
_allowed_hooks = {
    name
    for name, fn in inspect.getmembers(Callback, predicate=inspect.isfunction)
    if name.startswith("on_")
}

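# Note: the exact contents of `_allowed_hooks` depend on the installed PTL
# version; typical entries include "on_fit_start", "on_train_batch_end",
# and "on_validation_end".
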
def _override_ptl_hooks(callback_cls: Type["TuneCallback"]) -> Type["TuneCallback"]:
    """Overrides all allowed PTL Callback hooks with our custom handle logic."""

    def generate_overridden_hook(fn_name):
        def overridden_hook(
            self,
            trainer: Trainer,
            pl_module: Optional[LightningModule] = None,
            *args,
            **kwargs,
        ):
            if fn_name in self._on:
                self._handle(trainer=trainer, pl_module=pl_module)

        return overridden_hook

    # Set the overridden hook to all the allowed hooks in TuneCallback.
    for fn_name in _allowed_hooks:
        setattr(callback_cls, fn_name, generate_overridden_hook(fn_name))

    return callback_cls
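
# Illustrative sketch (not part of this module): once a class is decorated,
# every hook in `_allowed_hooks` is overridden, but only the hooks the user
# selected via `on=` dispatch to `_handle`. Roughly, for a hypothetical
# subclass `MyTuneCallback`:
#
#     callback = MyTuneCallback(on="validation_end")
#     callback.on_validation_end(trainer, pl_module)  # calls `_handle`
#     callback.on_train_end(trainer, pl_module)       # no-op, not in `_on`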

@_override_ptl_hooks
class TuneCallback(Callback):
    """Base class for Tune's PyTorch Lightning callbacks.

    Args:
        on: When to trigger checkpoint creations. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". Defaults to "validation_end".
    """

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        if not isinstance(on, list):
            on = [on]

        for hook in on:
            if f"on_{hook}" not in _allowed_hooks:
                raise ValueError(
                    f"Invalid hook selected: {hook}. Must be one of "
                    f"{_allowed_hooks}"
                )

        # Add back the "on_" prefix for internal consistency.
        on = [f"on_{hook}" for hook in on]

        self._on = on

    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        raise NotImplementedError
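
# Minimal sketch of a custom callback (hypothetical, for illustration only):
# since `_override_ptl_hooks` is applied to the base class, a subclass only
# needs to implement `_handle`, which fires for each hook selected via `on=`.
class _ExampleLoggingCallback(TuneCallback):
    def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
        # E.g. with `on="validation_end"`, this runs after every validation epoch.
        logger.info(
            "Hook fired; %d callback metrics logged.", len(trainer.callback_metrics)
        )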

@PublicAPI
class TuneReportCheckpointCallback(TuneCallback):
    """PyTorch Lightning report and checkpoint callback

    Saves checkpoints after each validation step. Also reports metrics to Tune,
    which is needed for checkpoint registration.

    Args:
        metrics: Metrics to report to Tune. If this is a list,
            each item describes the metric key reported to PyTorch Lightning,
            and it will be reported under the same name to Tune. If this is a
            dict, each key will be the name reported to Tune and the respective
            value will be the metric key reported to PyTorch Lightning.
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        save_checkpoints: If True (default), checkpoints will be saved and
            reported to Ray. If False, only metrics will be reported.
        on: When to trigger checkpoint creations and metric reports. Must be
            one of the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "train_batch_start", or "train_end". Defaults to "validation_end".

    Example:

    .. code-block:: python

        import pytorch_lightning as pl
        from ray.tune.integration.pytorch_lightning import (
            TuneReportCheckpointCallback)

        # Report loss and accuracy to Tune and save a checkpoint named
        # "trainer.ckpt" after each validation epoch.
        trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
            metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
            filename="trainer.ckpt", on="validation_end")])

    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = "checkpoint",
        save_checkpoints: bool = True,
        on: Union[str, List[str]] = "validation_end",
    ):
        super(TuneReportCheckpointCallback, self).__init__(on=on)
        if isinstance(metrics, str):
            metrics = [metrics]
        self._save_checkpoints = save_checkpoints
        self._filename = filename
        self._metrics = metrics

    def _get_report_dict(self, trainer: Trainer, pl_module: LightningModule):
        # Don't report if just doing initial validation sanity checks.
        if trainer.sanity_checking:
            return
        if not self._metrics:
            report_dict = {k: v.item() for k, v in trainer.callback_metrics.items()}
        else:
            report_dict = {}
            for key in self._metrics:
                if isinstance(self._metrics, dict):
                    metric = self._metrics[key]
                else:
                    metric = key
                if metric in trainer.callback_metrics:
                    report_dict[key] = trainer.callback_metrics[metric].item()
                else:
                    logger.warning(
                        f"Metric {metric} does not exist in "
                        "`trainer.callback_metrics`."
                    )

        return report_dict

    @contextmanager
    def _get_checkpoint(self, trainer: Trainer) -> Optional[Checkpoint]:
        if not self._save_checkpoints:
            yield None
            return

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename))
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            yield checkpoint

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        if trainer.sanity_checking:
            return

        report_dict = self._get_report_dict(trainer, pl_module)
        if not report_dict:
            return

        with self._get_checkpoint(trainer) as checkpoint:
            train.report(report_dict, checkpoint=checkpoint)
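
# End-to-end sketch (hypothetical, not part of this module): using the
# callback inside a Ray Tune trainable. `MyLightningModule` is a stand-in
# for a user-defined LightningModule.
#
#     import pytorch_lightning as pl
#     from ray import tune
#
#     def train_fn(config):
#         model = MyLightningModule(lr=config["lr"])
#         trainer = pl.Trainer(
#             max_epochs=3,
#             enable_progress_bar=False,
#             callbacks=[
#                 TuneReportCheckpointCallback(
#                     metrics={"loss": "val_loss"}, on="validation_end"
#                 )
#             ],
#         )
#         trainer.fit(model)
#
#     tuner = tune.Tuner(train_fn, param_space={"lr": tune.loguniform(1e-4, 1e-1)})
#     results = tuner.fit()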

class _TuneCheckpointCallback(TuneCallback):
    def __init__(self, *args, **kwargs):
        raise DeprecationWarning(
            "`ray.tune.integration.pytorch_lightning._TuneCheckpointCallback` "
            "is deprecated."
        )

@Deprecated
class TuneReportCallback(TuneReportCheckpointCallback):
    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "validation_end",
    ):
        if log_once("tune_ptl_report_deprecated"):
            warnings.warn(
                "`ray.tune.integration.pytorch_lightning.TuneReportCallback` "
                "is deprecated. Use "
                "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`"
                " instead."
            )
        super(TuneReportCallback, self).__init__(
            metrics=metrics, save_checkpoints=False, on=on
        )
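
# Migration sketch: per the constructor above, a deprecated
# `TuneReportCallback(metrics=..., on=...)` is equivalent to
#
#     TuneReportCheckpointCallback(metrics=..., save_checkpoints=False, on=...)
#
# which reports metrics to Tune without saving checkpoints.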