Source code for ray.train.v2.lightning.lightning_utils

import os
import shutil
import tempfile
from pathlib import Path

import ray.train
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.train.lightning._lightning_utils import (
    RayTrainReportCallback as RayTrainReportCallbackV1,
)
from ray.train.lightning._lightning_utils import import_lightning
from ray.util import PublicAPI

pl = import_lightning()
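
# ``import_lightning`` resolves the Lightning package at runtime. A minimal
# sketch of the assumed behavior (the actual helper lives in
# ray.train.lightning._lightning_utils): prefer the unified ``lightning``
# package and fall back to the legacy ``pytorch_lightning`` distribution.
#
#     try:
#         import lightning.pytorch as pl
#     except ModuleNotFoundError:
#         import pytorch_lightning as pl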


@PublicAPI(stability="beta")
class RayTrainReportCallback(RayTrainReportCallbackV1):
    """A simple callback that reports checkpoints to Ray on train epoch end.

    This callback is a subclass of `lightning.pytorch.callbacks.Callback
    <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback>`_.

    It fetches the latest ``trainer.callback_metrics`` and reports them
    together with the checkpoint on each training epoch end.

    Checkpoints will be saved in the following structure::

        checkpoint_{timestamp}/      Ray Train's checkpoint folder
        └─ checkpoint.ckpt           Lightning's checkpoint format

    For customized reporting and checkpointing logic, implement your own
    ``lightning.pytorch.callbacks.Callback`` following this user guide:
    :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
    """

    def __init__(self) -> None:
        # TODO: Upstream this change into ray.train.lightning.
        # The difference in this version is removing the trial directory usage.
        job_id = ray.get_runtime_context().get_job_id()
        experiment_name = ray.train.get_context().get_experiment_name()
        self.local_rank = ray.train.get_context().get_local_rank()

        # Create a root temporary directory for storing local checkpoints
        # before persisting them to storage.
        # Lightning's checkpointing implementation requires that this
        # directory is a common path across all workers.
        # Construct the path prefix with the job id and experiment name,
        # which are shared across workers for a Ray Train run.
        # This path should not be shared across different Ray Train runs.
        self.tmpdir_prefix = Path(
            tempfile.gettempdir(),
            f"lightning_checkpoints-job_id={job_id}-name={experiment_name}",
        ).as_posix()
        if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0:
            shutil.rmtree(self.tmpdir_prefix)

        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1")
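
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): one way to attach
# the callback above to a Lightning ``Trainer`` inside a Ray Train training
# function. ``MyLightningModule`` is a hypothetical placeholder for a
# user-defined LightningModule; the surrounding helpers (``prepare_trainer``,
# ``RayDDPStrategy``, ``RayLightningEnvironment``) are Ray Train's public
# Lightning integration utilities.


def _example_train_func():
    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        prepare_trainer,
    )

    model = MyLightningModule()  # hypothetical user-defined LightningModule

    trainer = pl.Trainer(
        max_epochs=2,
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        # Report metrics and a checkpoint to Ray Train at each epoch end.
        callbacks=[RayTrainReportCallback()],
        # Disable Lightning's default ModelCheckpoint so the callback above
        # owns checkpointing.
        enable_checkpointing=False,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model)


# The training function would then be launched across workers with, e.g.:
#
#     from ray.train import ScalingConfig
#     from ray.train.torch import TorchTrainer
#
#     result = TorchTrainer(
#         _example_train_func,
#         scaling_config=ScalingConfig(num_workers=2),
#     ).fit()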