import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict
import torch
from packaging.version import Version
import ray
from ray import train
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray.train import Checkpoint
from ray.util import PublicAPI
def import_lightning():  # noqa: F402
    """Return Lightning's ``pytorch`` module from whichever package is installed.

    The unified ``lightning`` distribution (``lightning.pytorch``) is tried
    first; if it is not installed, the standalone ``pytorch_lightning``
    package is used instead. If neither is importable, the second
    ``ModuleNotFoundError`` propagates to the caller.
    """
    try:
        import lightning.pytorch as lightning_module
    except ModuleNotFoundError:
        import pytorch_lightning as lightning_module
    return lightning_module
# Resolve the Lightning package once; the flags below gate the
# version-specific code paths in this module.
pl = import_lightning()

_LIGHTNING_VERSION = Version(pl.__version__)
_TORCH_VERSION = Version(torch.__version__)

# Lightning version window checks (>= 2.0.0 and < 2.1.0 respectively).
_LIGHTNING_GREATER_EQUAL_2_0 = _LIGHTNING_VERSION >= Version("2.0.0")
_LIGHTNING_LESS_THAN_2_1 = _LIGHTNING_VERSION < Version("2.1.0")

# Native FSDP helpers are used only with torch >= 1.12 built with
# distributed support.
_TORCH_GREATER_EQUAL_1_12 = _TORCH_VERSION >= Version("1.12.0")
_TORCH_FSDP_AVAILABLE = _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available()
# Import `LightningEnvironment` from the unified `lightning` package when it is
# installed, otherwise from the standalone `pytorch_lightning` package.
try:
    from lightning.pytorch.plugins.environments import LightningEnvironment
except ModuleNotFoundError:
    from pytorch_lightning.plugins.environments import LightningEnvironment

# Lightning >= 2.0 exposes the FSDP strategy as `FSDPStrategy`; older
# releases expose `DDPFullyShardedStrategy`. Alias whichever applies.
FSDPStrategy = (
    pl.strategies.FSDPStrategy
    if _LIGHTNING_GREATER_EQUAL_2_0
    else pl.strategies.DDPFullyShardedStrategy
)

# The full-state-dict helpers used by `RayFSDPStrategy` are imported only when
# torch's native FSDP is usable.
if _TORCH_FSDP_AVAILABLE:
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel,
        StateDictType,
    )
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

LIGHTNING_REPORT_STAGE_KEY = "_report_on"
@PublicAPI(stability="beta")
class RayDDPStrategy(pl.strategies.DDPStrategy):
    """Subclass of DDPStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DDPStrategy.html

    Note that `process_group_backend`, `timeout`, and `start_method` are disabled
    here; please specify these arguments in :class:`~ray.train.torch.TorchConfig`
    instead.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to Lightning's DDPStrategy.
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDDPSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        """Return the device reported by ``ray.train.torch.get_device()``."""
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        """Return sampler kwargs built from this strategy's world size and rank.

        Returns:
            A dict with ``num_replicas`` set to ``self.world_size`` and ``rank``
            set to ``self.global_rank``.
        """
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )
@PublicAPI(stability="beta")
class RayFSDPStrategy(FSDPStrategy):  # noqa: F821
    """Subclass of FSDPStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.FSDPStrategy.html

    .. note::
        It is recommended to upgrade `lightning>=2.1` or above when using FSDP
        with Lightning, since Lightning starts to natively support `state_dict_type`,
        `sharding_strategy`, `auto_wrap_policy` and other FSDP configurations from 2.1.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to Lightning's FSDP strategy.
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYFSDPSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        """Return the device reported by ``ray.train.torch.get_device()``."""
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        """Return sampler kwargs built from this strategy's world size and rank.

        Returns:
            A dict with ``num_replicas`` set to ``self.world_size`` and ``rank``
            set to ``self.global_rank``.
        """
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )

    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """Gathers the full state dict to rank 0 on CPU.

        FSDP checkpointing is broken in Lightning 2.0.x. This subclass patches the
        behavior to perform a full state dict checkpointing, gathering the checkpoint
        shards on rank 0 CPU. Upgrade to `lightning>=2.1` to do sharded state dict
        checkpointing.

        For any other combination (Lightning < 2.0, Lightning >= 2.1, or torch
        without native FSDP support), the parent implementation is used unchanged.

        See the note in the class docstring for more details.

        Returns:
            On the patched path, the full model state dict with a leading
            ``_forward_module.`` removed from each key. Otherwise, whatever the
            parent's ``lightning_module_state_dict`` returns.
        """
        assert self.model is not None, "Failed to get the state dict for a None model!"

        needs_full_state_dict_patch = (
            _TORCH_FSDP_AVAILABLE
            and _LIGHTNING_GREATER_EQUAL_2_0
            and _LIGHTNING_LESS_THAN_2_1
        )
        if not needs_full_state_dict_patch:
            # Either native torch FSDP is unavailable, Lightning < 2.0 is in use
            # (strategy aliased to `DDPFullyShardedStrategy`), or Lightning >= 2.1
            # handles FSDP state dicts itself. No manual unsharding needed.
            return super().lightning_module_state_dict()

        # Materialize the full (unsharded) state dict, offloaded to CPU and
        # gathered on rank 0 only.
        with FullyShardedDataParallel.state_dict_type(
            module=self.model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(
                offload_to_cpu=True, rank0_only=True
            ),
        ):
            state_dict = self.model.state_dict()

            # Keys of the wrapped model may carry a `_forward_module.` prefix
            # (presumably added by Lightning's module wrapper). Strip it so the
            # checkpoint keys match the LightningModule's own parameter names.
            prefix = "_forward_module."
            prefix_len = len(prefix)
            return {
                (key[prefix_len:] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
@PublicAPI(stability="beta")
class RayDeepSpeedStrategy(pl.strategies.DeepSpeedStrategy):
    """Subclass of DeepSpeedStrategy to ensure compatibility with Ray orchestration.

    For a full list of initialization arguments, please refer to:
    https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.strategies.DeepSpeedStrategy.html
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded unchanged to Lightning's DeepSpeedStrategy.
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDEEPSPEEDSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        """Return the device reported by ``ray.train.torch.get_device()``."""
        return ray.train.torch.get_device()

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        """Return sampler kwargs built from this strategy's world size and rank.

        Returns:
            A dict with ``num_replicas`` set to ``self.world_size`` and ``rank``
            set to ``self.global_rank``.
        """
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )
@PublicAPI(stability="beta")
class RayLightningEnvironment(LightningEnvironment):  # noqa: F821
    """Setup Lightning DDP training environment for Ray cluster.

    World size and the global, local, and node ranks are read on every call
    from the Ray Train context returned by ``train.get_context()``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYLIGHTNINGENVIRONMENT, "1")

    def world_size(self) -> int:
        """Return the world size from the Ray Train context."""
        return train.get_context().get_world_size()

    def global_rank(self) -> int:
        """Return this worker's world rank from the Ray Train context."""
        return train.get_context().get_world_rank()

    def local_rank(self) -> int:
        """Return this worker's local rank from the Ray Train context."""
        return train.get_context().get_local_rank()

    def node_rank(self) -> int:
        """Return this worker's node rank from the Ray Train context."""
        return train.get_context().get_node_rank()

    def set_world_size(self, size: int) -> None:
        # Disabled: `world_size()` reads directly from the Train context, so
        # there is nothing to store.
        pass

    def set_global_rank(self, rank: int) -> None:
        # Disabled: `global_rank()` reads directly from the Train context, so
        # there is nothing to store.
        pass

    def teardown(self):
        """Do nothing on teardown."""
        pass
@PublicAPI(stability="beta")
def prepare_trainer(trainer: pl.Trainer) -> pl.Trainer:
    """Prepare the PyTorch Lightning Trainer for distributed execution.

    Validates that the trainer is configured for Ray. Its strategy must be one
    of the Ray strategies (or a subclass). If the strategy exposes a
    ``cluster_environment``, that environment must be a
    ``RayLightningEnvironment``.

    Args:
        trainer: The Lightning ``Trainer`` to validate.

    Returns:
        The same ``trainer`` object, unmodified.

    Raises:
        RuntimeError: If the strategy or the cluster environment plugin is not
            one of the Ray-compatible classes.
    """
    # Check strategy class
    valid_strategy_class = [RayDDPStrategy, RayFSDPStrategy, RayDeepSpeedStrategy]
    if not isinstance(trainer.strategy, tuple(valid_strategy_class)):
        raise RuntimeError(
            f"Invalid strategy class: {type(trainer.strategy)}. To use "
            "PyTorch Lightning with Ray, the strategy object should be one of "
            f"{[cls.__name__ for cls in valid_strategy_class]} class "
            "or its subclass."
        )

    # Check cluster environment (strategies without one are accepted).
    cluster_environment = getattr(trainer.strategy, "cluster_environment", None)
    if cluster_environment is not None and not isinstance(
        cluster_environment, RayLightningEnvironment
    ):
        raise RuntimeError(
            "Invalid cluster environment plugin. The expected class is "
            "`ray.train.lightning.RayLightningEnvironment` "
            f"but got {type(cluster_environment)}!"
        )

    record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_PREPARE_TRAINER, "1")
    return trainer
@PublicAPI(stability="beta")
class RayTrainReportCallback(pl.callbacks.Callback):
    """A simple callback that reports checkpoints to Ray on train epoch end.

    This callback is a subclass of `lightning.pytorch.callbacks.Callback
    <https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Callback.html#lightning.pytorch.callbacks.Callback>`_.

    It fetches the latest `trainer.callback_metrics` and reports together with
    the checkpoint on each training epoch end.

    Checkpoints will be saved in the following structure::

        checkpoint_00000*/      Ray Train Checkpoint
        └─ checkpoint.ckpt      PyTorch Lightning Checkpoint

    For customized reporting and checkpointing logic, implement your own
    `lightning.pytorch.callbacks.Callback` following this user
    guide: :ref:`Saving and Loading Checkpoints <train-dl-saving-checkpoints>`.
    """

    CHECKPOINT_NAME = "checkpoint.ckpt"

    def __init__(self) -> None:
        """Set up a per-trial scratch directory under the system temp dir.

        If that directory already exists, the worker with local rank 0
        removes it.
        """
        super().__init__()
        self.trial_name = train.get_context().get_trial_name()
        self.local_rank = train.get_context().get_local_rank()
        self.tmpdir_prefix = Path(tempfile.gettempdir(), self.trial_name).as_posix()
        if os.path.isdir(self.tmpdir_prefix) and self.local_rank == 0:
            shutil.rmtree(self.tmpdir_prefix)

        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYTRAINREPORTCALLBACK, "1")

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        """Save a Lightning checkpoint and report it with this epoch's metrics.

        Steps performed on every worker:

        1. Write ``<tmpdir_prefix>/<epoch>/checkpoint.ckpt``.
        2. Call ``train.report`` with the scalar callback metrics plus
           ``epoch`` and ``step``.
        3. Wait at a strategy barrier.

        After the barrier, the worker with local rank 0 deletes the epoch
        directory.
        """
        # Creates a checkpoint dir with fixed name
        tmpdir = Path(self.tmpdir_prefix, str(trainer.current_epoch)).as_posix()
        os.makedirs(tmpdir, exist_ok=True)

        # Fetch metrics; `.item()` assumes each value is a scalar tensor.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}

        # (Optional) Add customized metrics
        metrics["epoch"] = trainer.current_epoch
        metrics["step"] = trainer.global_step

        # Save checkpoint to local
        ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix()
        trainer.save_checkpoint(ckpt_path, weights_only=False)

        # Report to train session
        checkpoint = Checkpoint.from_directory(tmpdir)
        train.report(metrics=metrics, checkpoint=checkpoint)

        # Add a barrier to ensure all workers finished reporting here, so the
        # directory is not deleted while another local worker still uses it.
        trainer.strategy.barrier()

        if self.local_rank == 0:
            shutil.rmtree(tmpdir)