ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback
- class ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback(metrics: Optional[Union[str, List[str], Dict[str, str]]] = None, filename: str = 'checkpoint', on: Union[str, List[str]] = 'validation_end')
Bases:
ray.tune.integration.pytorch_lightning.TuneCallback
PyTorch Lightning report and checkpoint callback.
Saves a checkpoint after each validation step. Also reports metrics to Tune, which is needed for checkpoint registration.
- Parameters
metrics – Metrics to report to Tune. If this is a list, each item describes the metric key reported to PyTorch Lightning, and it will be reported under the same name to Tune. If this is a dict, each key will be the name reported to Tune and the respective value will be the metric key reported to PyTorch Lightning.
filename – Filename of the checkpoint within the checkpoint directory. Defaults to “checkpoint”.
on – When to trigger checkpoint creation. Must be one of the PyTorch Lightning event hooks (minus the “on_” prefix), e.g. “train_batch_start” or “train_end”. Defaults to “validation_end”.
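To clarify how the list and dict forms of the metrics argument map names, here is a hedged, purely illustrative sketch (normalize_metrics is a hypothetical helper, not part of Ray's actual implementation): a list reports each Lightning metric key under the same name to Tune, while a dict maps each Tune name (key) to a Lightning metric key (value).

```python
from typing import Dict, List, Optional, Union


def normalize_metrics(
    metrics: Optional[Union[str, List[str], Dict[str, str]]]
) -> Dict[str, str]:
    # Hypothetical helper mirroring the documented semantics:
    # the result maps "name reported to Tune" -> "metric key in Lightning".
    if metrics is None:
        return {}
    if isinstance(metrics, str):
        # A single string reports one metric under the same name.
        return {metrics: metrics}
    if isinstance(metrics, list):
        # List form: same name on both sides.
        return {m: m for m in metrics}
    # Dict form: key is the Tune name, value is the Lightning metric key.
    return dict(metrics)


# Dict form: Tune sees "loss", read from Lightning's "val_loss".
print(normalize_metrics({"loss": "val_loss", "mean_accuracy": "val_acc"}))
# List form: names are passed through unchanged.
print(normalize_metrics(["val_loss", "val_acc"]))
```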
Example:
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import (
    TuneReportCheckpointCallback)

# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
    metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
    filename="trainer.ckpt", on="validation_end")])
PublicAPI: This API is stable across Ray releases.