Source code for ray.train.lightning._lightning_utils

import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict

import torch
from packaging.version import Version

import ray
import ray.train
from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.train import Checkpoint
from ray.util import PublicAPI


def import_lightning():  # noqa: F402
    try:
        import lightning.pytorch as pl
    except ModuleNotFoundError:
        import pytorch_lightning as pl
    return pl


pl = import_lightning()

_LIGHTNING_GREATER_EQUAL_2_0 = Version(pl.__version__) >= Version("2.0.0")
_LIGHTNING_LESS_THAN_2_1 = Version(pl.__version__) < Version("2.1.0")
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()

try:
    from lightning.pytorch.plugins.environments import LightningEnvironment
except ModuleNotFoundError:
    from pytorch_lightning.plugins.environments import LightningEnvironment

if _LIGHTNING_GREATER_EQUAL_2_0:
    FSDPStrategy = pl.strategies.FSDPStrategy
else:
    FSDPStrategy = pl.strategies.DDPFullyShardedStrategy

if _TORCH_FSDP_AVAILABLE:
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel,
        StateDictType,
    )


logger = logging.getLogger(__name__)

LIGHTNING_REPORT_STAGE_KEY = "_report_on"


@PublicAPI(stability="beta")
class RayDDPStrategy(pl.strategies.DDPStrategy):
    """Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html

    Note that `process_group_backend`, `timeout`, and `start_method` are disabled here;
    please specify these arguments in :class:`~ray.train.torch.TorchConfig` instead.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDDPSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )
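
# Usage sketch (illustrative only): since `process_group_backend`, `timeout`, and
# `start_method` are disabled on RayDDPStrategy, the equivalent process-group
# settings go to the TorchTrainer via `ray.train.torch.TorchConfig`. The backend,
# timeout, and worker count below are arbitrary example values, and `train_func`
# stands in for a user-defined training function.
def _example_ddp_torch_config(train_func):
    from ray.train import ScalingConfig
    from ray.train.torch import TorchConfig, TorchTrainer

    return TorchTrainer(
        train_func,
        # Process-group options that would otherwise be passed to DDPStrategy.
        torch_config=TorchConfig(backend="gloo", timeout_s=1800),
        scaling_config=ScalingConfig(num_workers=2),
    )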

@PublicAPI(stability="beta")
class RayFSDPStrategy(FSDPStrategy):  # noqa: F821
    """Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html

    .. note::
        It is recommended to upgrade to `lightning>=2.1` when using FSDP with
        Lightning, since Lightning 2.1 adds native support for `state_dict_type`,
        `sharding_strategy`, `auto_wrap_policy`, and other FSDP configurations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYFSDPSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict to rank 0 on CPU.

        FSDP checkpointing is broken in Lightning 2.0.x. This subclass patches the
        behavior to perform full state dict checkpointing, gathering the checkpoint
        shards on rank 0 CPU.

        Upgrade to `lightning>=2.1` to do sharded state dict checkpointing. See the
        note in the class docstring for more details.
        """
        assert self.model is not None, "Failed to get the state dict for a None model!"

        if (
            _TORCH_FSDP_AVAILABLE
            and _LIGHTNING_GREATER_EQUAL_2_0
            and _LIGHTNING_LESS_THAN_2_1
        ):
            with FullyShardedDataParallel.state_dict_type(
                module=self.model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(
                    offload_to_cpu=True, rank0_only=True
                ),
            ):
                state_dict = self.model.state_dict()
                # Strip the "_forward_module." prefix added by the Lightning wrapper
                # so the keys match the original LightningModule's state dict.
                ckpt_state_dict = {}
                prefix_len = len("_forward_module.")
                for k, v in state_dict.items():
                    if k.startswith("_forward_module."):
                        non_prefixed_key = k[prefix_len:]
                        ckpt_state_dict[non_prefixed_key] = v
                    else:
                        ckpt_state_dict[k] = v
                return ckpt_state_dict
        else:
            # Otherwise Lightning uses Fairscale FSDP; no need to unshard ourselves.
            return super().lightning_module_state_dict()
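
# Usage sketch (illustrative only): with `lightning>=2.1`, FSDP options such as
# `state_dict_type`, `sharding_strategy`, and `auto_wrap_policy` can be passed
# straight through to RayFSDPStrategy, since it forwards all arguments to
# Lightning's FSDPStrategy. The argument values below are example choices, not
# requirements.
def _example_fsdp_strategy():
    import functools

    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    return RayFSDPStrategy(
        sharding_strategy="FULL_SHARD",
        state_dict_type="sharded",
        # Wrap any submodule with at least ~1e6 parameters in its own FSDP unit.
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=int(1e6)
        ),
    )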

@PublicAPI(stability="beta")
class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
    """Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DeepSpeedStrategy.html
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDEEPSPEEDSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )
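
# Usage sketch (illustrative only): RayDeepSpeedStrategy accepts the same arguments
# as Lightning's DeepSpeedStrategy, e.g. a ZeRO stage or a full DeepSpeed config.
# The stage and offload settings below are example choices, not requirements.
def _example_deepspeed_strategy():
    return RayDeepSpeedStrategy(
        stage=3,  # ZeRO stage 3: shard optimizer states, gradients, and parameters
        offload_optimizer=True,
        offload_parameters=True,
    )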

@PublicAPI(stability="beta")
class RayLightningEnvironment(LightningEnvironment):  # noqa: F821
    """Set up the Lightning DDP training environment for a Ray cluster."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYLIGHTNINGENVIRONMENT, "1")

    def world_size(self) -> int:
        return ray.train.get_context().get_world_size()

    def global_rank(self) -> int:
        return ray.train.get_context().get_world_rank()

    def local_rank(self) -> int:
        return ray.train.get_context().get_local_rank()

    def node_rank(self) -> int:
        return ray.train.get_context().get_node_rank()

    def set_world_size(self, size: int) -> None:
        # No-op: `world_size()` reads directly from the Ray Train context.
        pass

    def set_global_rank(self, rank: int) -> None:
        # No-op: `global_rank()` reads directly from the Ray Train context.
        pass

    def teardown(self):
        pass

@PublicAPI(stability="beta")
def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
    """Prepare the PyTorch Lightning Trainer for distributed execution."""

    # Check strategy class
    valid_strategy_class = [RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy]

    if not any(isinstance(trainer.strategy, cls) for cls in valid_strategy_class):
        raise RuntimeError(
            f"Invalid strategy class: {type(trainer.strategy)}. To use "
            "PyTorch Lightning with Ray, the strategy object should be one of "
            f"{[cls.__name__ for cls in valid_strategy_class]} or a subclass."
        )

    # Check cluster environment
    cluster_environment = getattr(trainer.strategy, "cluster_environment", None)
    if cluster_environment and not isinstance(
        cluster_environment, RayLightningEnvironment
    ):
        raise RuntimeError(
            "Invalid cluster environment plugin. The expected class is "
            "`ray.train.lightning.RayLightningEnvironment` "
            f"but got {type(cluster_environment)}!"
        )

    record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_PREPARE_TRAINER, "1")

    return trainer
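
# Usage sketch (illustrative only): a minimal end-to-end wiring of the utilities in
# this module with a TorchTrainer. The toy model, data, and worker count are example
# placeholders chosen only to keep the sketch self-contained. Note that
# RayTrainReportCallback is defined later in this module; the reference here is
# resolved at call time.
def _example_prepare_trainer_usage():
    from torch.utils.data import DataLoader, TensorDataset

    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    class _ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-3)

    def train_func():
        dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
        trainer = pl.Trainer(
            max_epochs=2,
            devices="auto",
            accelerator="auto",
            strategy=RayDDPStrategy(),
            plugins=[RayLightningEnvironment()],
            callbacks=[RayTrainReportCallback()],
            # RayTrainReportCallback handles checkpointing, so disable Lightning's.
            enable_checkpointing=False,
        )
        trainer = prepare_trainer(trainer)
        trainer.fit(_ToyModule(), train_dataloaders=DataLoader(dataset, batch_size=8))

    return TorchTrainer(
        train_func, scaling_config=ScalingConfig(num_workers=2)
    ).fit()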

@PublicAPI(stability="beta")
class RayTrainReportCallback(pl.callbacks.Callback):
    """A simple callback that reports checkpoints to Ray on train epoch end.

    This callback is a subclass of `lightning.pytorch.callbacks.Callback
    <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback>`_.

    It fetches the latest `trainer.callback_metrics` and reports them together with
    the checkpoint at the end of each training epoch.

    Checkpoints will be saved in the following structure::

        checkpoint_00000*/      Ray Train Checkpoint
        └─ checkpoint.ckpt      PyTorch Lightning Checkpoint

    For customized reporting and checkpointing logic, implement your own
    `lightning.pytorch.callbacks.Callback` following this user guide:
    :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
    """

    CHECKPOINT_NAME = "checkpoint.ckpt"

    def __init__(self) -> None:
        super().__init__()
        job_id = ray.get_runtime_context().get_job_id()
        experiment_name = ray.train.get_context().get_experiment_name()
        self.local_rank = ray.train.get_context().get_local_rank()

        self.tmpdir_prefix = Path(
            tempfile.gettempdir(),
            f"lightning_checkpoints-job_id={job_id}-name={experiment_name}",
        ).as_posix()

        if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0:
            shutil.rmtree(self.tmpdir_prefix)

        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1")

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Create a checkpoint dir with a fixed name
        tmpdir = Path(self.tmpdir_prefix, str(trainer.current_epoch)).as_posix()
        os.makedirs(tmpdir, exist_ok=True)

        # Fetch metrics
        metrics = trainer.callback_metrics
        metrics = {k: v.item() for k, v in metrics.items()}

        # (Optional) Add customized metrics
        metrics["epoch"] = trainer.current_epoch
        metrics["step"] = trainer.global_step

        # Save the Lightning checkpoint to the local temporary directory
        ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix()
        trainer.save_checkpoint(ckpt_path, weights_only=False)

        # Report metrics and checkpoint to the Ray Train session
        checkpoint = Checkpoint.from_directory(tmpdir)
        ray.train.report(metrics=metrics, checkpoint=checkpoint)

        # Add a barrier to ensure all workers have finished reporting
        trainer.strategy.barrier()

        if self.local_rank == 0:
            shutil.rmtree(tmpdir)
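
# Usage sketch (illustrative only): restoring a Lightning checkpoint reported by
# RayTrainReportCallback. `result` is the Result object returned by
# `TorchTrainer.fit()`, and `MyLightningModule` is a hypothetical placeholder for
# the user's LightningModule class.
def _example_load_reported_checkpoint(result, MyLightningModule):
    # The Ray Train checkpoint directory contains the Lightning .ckpt file saved
    # under RayTrainReportCallback.CHECKPOINT_NAME.
    with result.checkpoint.as_directory() as checkpoint_dir:
        ckpt_path = os.path.join(checkpoint_dir, RayTrainReportCallback.CHECKPOINT_NAME)
        return MyLightningModule.load_from_checkpoint(ckpt_path)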