import asyncio
import inspect
import os
import time
from pathlib import Path
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Optional,
TypeVar,
Union,
)
from pydantic import Field, field_validator
from ray.llm._internal.common.base_pydantic import BaseModelExtended
from ray.llm._internal.common.observability.logging import get_logger
from ray.llm._internal.common.utils.cloud_filesystem import (
AzureFileSystem,
GCSFileSystem,
PyArrowFileSystem,
S3FileSystem,
)
T = TypeVar("T")
logger = get_logger(__name__)
def is_remote_path(path: str) -> bool:
"""Check if the path is a remote path.
Args:
path: The path to check.
Returns:
True if the path is a remote path, False otherwise.
"""
return (
path.startswith("s3://")
or path.startswith("gs://")
or path.startswith("abfss://")
or path.startswith("azure://")
or path.startswith("pyarrow-")
)
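# For example (illustrative paths): is_remote_path("s3://bucket/model") and
# is_remote_path("azure://container@account/model") return True, while
# is_remote_path("/tmp/models/llama") returns False.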
class ExtraFiles(BaseModelExtended):
bucket_uri: str
destination_path: str
class CloudMirrorConfig(BaseModelExtended):
"""Unified mirror config for cloud storage (S3, GCS, or Azure).
Args:
bucket_uri: URI of the bucket (s3://, gs://, abfss://, or azure://)
extra_files: Additional files to download
"""
bucket_uri: Optional[str] = None
extra_files: List[ExtraFiles] = Field(default_factory=list)
@property
    def storage_type(self) -> Optional[str]:
        """Return the storage type ("s3", "gcs", "abfss", or "azure") based on
        the URI prefix, or None if the URI is unset or unrecognized.
        """
if self.bucket_uri is None:
return None
elif self.bucket_uri.startswith("s3://"):
return "s3"
elif self.bucket_uri.startswith("gs://"):
return "gcs"
elif self.bucket_uri.startswith("abfss://"):
return "abfss"
elif self.bucket_uri.startswith("azure://"):
return "azure"
return None
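# Illustrative sketch (hypothetical bucket URI): constructing a mirror config
# and reading its derived storage type.
#
#     config = CloudMirrorConfig(bucket_uri="gs://example-bucket/llama-2-7b")
#     config.storage_type  # -> "gcs"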
class LoraMirrorConfig(BaseModelExtended):
lora_model_id: str
bucket_uri: str
max_total_tokens: Optional[int]
sync_args: Optional[List[str]] = None
@field_validator("bucket_uri")
@classmethod
def check_uri_format(cls, value):
if value is None:
return value
if not is_remote_path(value):
raise ValueError(
f'Got invalid value "{value}" for bucket_uri. '
'Expected a URI that starts with "s3://", "gs://", "abfss://", or "azure://".'
)
return value
@property
def _bucket_name_and_path(self) -> str:
for prefix in ["s3://", "gs://", "abfss://", "azure://"]:
if self.bucket_uri.startswith(prefix):
return self.bucket_uri[len(prefix) :]
return self.bucket_uri
@property
def bucket_name(self) -> str:
bucket_part = self._bucket_name_and_path.split("/")[0]
# For ABFSS and Azure URIs, extract container name from container@account format
if self.bucket_uri.startswith(("abfss://", "azure://")) and "@" in bucket_part:
return bucket_part.split("@")[0]
return bucket_part
@property
def bucket_path(self) -> str:
return "/".join(self._bucket_name_and_path.split("/")[1:])
class CloudFileSystem:
"""A unified interface for cloud file system operations.
This class provides a simple interface for common operations on cloud storage
systems (S3, GCS, Azure) by delegating to provider-specific implementations
for optimal performance.
"""
@staticmethod
def _get_provider_fs(bucket_uri: str):
"""Get the appropriate provider-specific filesystem class based on URI.
        Args:
            bucket_uri: URI of the cloud storage (s3://, gs://, abfss://,
                azure://, or a pyarrow-prefixed URI)
        Returns:
            The matching filesystem class (S3FileSystem, GCSFileSystem,
            AzureFileSystem, or PyArrowFileSystem)
Raises:
ValueError: If the URI scheme is not supported
"""
if bucket_uri.startswith("pyarrow-"):
return PyArrowFileSystem
elif bucket_uri.startswith("s3://"):
return S3FileSystem
elif bucket_uri.startswith("gs://"):
return GCSFileSystem
elif bucket_uri.startswith(("abfss://", "azure://")):
return AzureFileSystem
else:
raise ValueError(f"Unsupported URI scheme: {bucket_uri}")
@staticmethod
def get_file(
object_uri: str, decode_as_utf_8: bool = True
) -> Optional[Union[str, bytes]]:
"""Download a file from cloud storage into memory.
Args:
object_uri: URI of the file (s3://, gs://, abfss://, or azure://)
decode_as_utf_8: If True, decode the file as UTF-8
Returns:
File contents as string or bytes, or None if file doesn't exist
"""
fs_class = CloudFileSystem._get_provider_fs(object_uri)
return fs_class.get_file(object_uri, decode_as_utf_8)
@staticmethod
def list_subfolders(folder_uri: str) -> List[str]:
"""List the immediate subfolders in a cloud directory.
Args:
folder_uri: URI of the directory (s3://, gs://, abfss://, or azure://)
Returns:
List of subfolder names (without trailing slashes)
"""
fs_class = CloudFileSystem._get_provider_fs(folder_uri)
return fs_class.list_subfolders(folder_uri)
@staticmethod
def download_files(
path: str,
bucket_uri: str,
substrings_to_include: Optional[List[str]] = None,
suffixes_to_exclude: Optional[List[str]] = None,
) -> None:
"""Download files from cloud storage to a local directory.
Args:
path: Local directory where files will be downloaded
bucket_uri: URI of cloud directory
substrings_to_include: Only include files containing these substrings
            suffixes_to_exclude: Exclude files whose names end with any of these
                suffixes (e.g., ".safetensors")
"""
fs_class = CloudFileSystem._get_provider_fs(bucket_uri)
fs_class.download_files(
path, bucket_uri, substrings_to_include, suffixes_to_exclude
)
@staticmethod
def download_model(
destination_path: str,
bucket_uri: str,
tokenizer_only: bool,
exclude_safetensors: bool = False,
) -> None:
"""Download a model from cloud storage.
This downloads a model in the format expected by the HuggingFace transformers
library.
Args:
destination_path: Path where the model will be stored
bucket_uri: URI of the cloud directory containing the model
tokenizer_only: If True, only download tokenizer-related files
exclude_safetensors: If True, skip download of safetensor files
"""
try:
# Get the provider-specific filesystem
fs_class = CloudFileSystem._get_provider_fs(bucket_uri)
# Construct hash file URI
hash_uri = bucket_uri.rstrip("/") + "/hash"
# Try to download and read hash file
hash_content = fs_class.get_file(hash_uri, decode_as_utf_8=True)
if hash_content is not None:
f_hash = hash_content.strip()
logger.info(
f"Detected hash file in bucket {bucket_uri}. "
f"Using {f_hash} as the hash."
)
else:
f_hash = "0000000000000000000000000000000000000000"
logger.info(
f"Hash file does not exist in bucket {bucket_uri}. "
f"Using {f_hash} as the hash."
)
# Write hash to refs/main
main_dir = os.path.join(destination_path, "refs")
os.makedirs(main_dir, exist_ok=True)
with open(os.path.join(main_dir, "main"), "w") as f:
f.write(f_hash)
# Create destination directory
destination_dir = os.path.join(destination_path, "snapshots", f_hash)
os.makedirs(destination_dir, exist_ok=True)
logger.info(f'Downloading model files to directory "{destination_dir}".')
# Download files
tokenizer_file_substrings = (
["tokenizer", "config.json"] if tokenizer_only else []
)
safetensors_to_exclude = [".safetensors"] if exclude_safetensors else None
CloudFileSystem.download_files(
path=destination_dir,
bucket_uri=bucket_uri,
substrings_to_include=tokenizer_file_substrings,
suffixes_to_exclude=safetensors_to_exclude,
)
except Exception as e:
logger.exception(f"Error downloading model from {bucket_uri}: {e}")
raise
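    # The resulting local layout mirrors the HuggingFace hub cache
    # (illustrative; <hash> is the value read from "<bucket_uri>/hash", or the
    # all-zeros fallback above):
    #
    #     <destination_path>/
    #         refs/main            <- contains <hash>
    #         snapshots/<hash>/    <- downloaded model files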
@staticmethod
def upload_files(
local_path: str,
bucket_uri: str,
) -> None:
"""Upload files to cloud storage.
Args:
local_path: The local path of the files to upload.
bucket_uri: The bucket uri to upload the files to, must start with
`s3://`, `gs://`, `abfss://`, or `azure://`.
"""
fs_class = CloudFileSystem._get_provider_fs(bucket_uri)
fs_class.upload_files(local_path, bucket_uri)
@staticmethod
def upload_model(
local_path: str,
bucket_uri: str,
) -> None:
"""Upload a model to cloud storage.
Args:
local_path: The local path of the model.
            bucket_uri: The bucket uri to upload the model to, must start with
                `s3://`, `gs://`, `abfss://`, or `azure://`.
"""
try:
            # If refs/main exists, upload its contents as the bucket's "hash" file
            # and treat snapshots/<hash> as the model directory.
            # Otherwise this is a custom model, and we assume no particular folder hierarchy.
refs_main = Path(local_path, "refs", "main")
if refs_main.exists():
model_path = os.path.join(
local_path, "snapshots", refs_main.read_text().strip()
)
CloudFileSystem.upload_files(
local_path=model_path, bucket_uri=bucket_uri
)
CloudFileSystem.upload_files(
local_path=str(refs_main),
bucket_uri=os.path.join(bucket_uri, "hash"),
)
else:
CloudFileSystem.upload_files(
local_path=local_path, bucket_uri=bucket_uri
)
logger.info(f"Uploaded model files to {bucket_uri}.")
except Exception as e:
logger.exception(f"Error uploading model to {bucket_uri}: {e}")
raise
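# Illustrative usage sketch (hypothetical bucket URI and local path):
#
#     CloudFileSystem.download_model(
#         destination_path="/tmp/models--my-org--my-model",
#         bucket_uri="s3://example-bucket/my-model",
#         tokenizer_only=False,
#     )
#     readme = CloudFileSystem.get_file("s3://example-bucket/my-model/README.md")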
class _CacheEntry(NamedTuple):
value: Any
expire_time: Optional[float]
class CloudObjectCache:
"""A cache that works with both sync and async fetch functions.
The purpose of this data structure is to cache the result of a function call
usually used to fetch a value from a cloud object store.
The idea is this:
- Cloud operations are expensive
    - In LoRA specifically, we would otherwise hit remote storage to download the
      model weights on every request.
    - If the same model is requested many times, we don't want to inflate the time to first token.
    - We control the cache not only by evicting entries when it is full, but also
      by expiring cache entries after a certain time.
    - If the object is missing, we cache the missing status for a short duration;
      if the object exists, we cache the object for a longer duration.
"""
def __init__(
self,
max_size: int,
fetch_fn: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]],
missing_expire_seconds: Optional[int] = None,
exists_expire_seconds: Optional[int] = None,
missing_object_value: Any = object(),
):
"""Initialize the cache.
Args:
max_size: Maximum number of items to store in cache
fetch_fn: Function to fetch values (can be sync or async)
missing_expire_seconds: How long to cache missing objects (None for no expiration)
            exists_expire_seconds: How long to cache existing objects (None for no expiration)
            missing_object_value: Sentinel value used to represent a missing object
        """
self._cache: Dict[str, _CacheEntry] = {}
self._max_size = max_size
self._fetch_fn = fetch_fn
self._missing_expire_seconds = missing_expire_seconds
self._exists_expire_seconds = exists_expire_seconds
self._is_async = inspect.iscoroutinefunction(fetch_fn) or (
callable(fetch_fn) and inspect.iscoroutinefunction(fetch_fn.__call__)
)
self._missing_object_value = missing_object_value
# Lock for thread-safe cache access
self._lock = asyncio.Lock()
async def aget(self, key: str) -> Any:
"""Async get value from cache or fetch it if needed."""
if not self._is_async:
raise ValueError("Cannot use async get() with sync fetch function")
async with self._lock:
value, should_fetch = self._check_cache(key)
if not should_fetch:
return value
# Fetch new value
value = await self._fetch_fn(key)
self._update_cache(key, value)
return value
def get(self, key: str) -> Any:
"""Sync get value from cache or fetch it if needed."""
if self._is_async:
raise ValueError("Cannot use sync get() with async fetch function")
# For sync access, we use a simple check-then-act pattern
# This is safe because sync functions are not used in async context
value, should_fetch = self._check_cache(key)
if not should_fetch:
return value
# Fetch new value
value = self._fetch_fn(key)
self._update_cache(key, value)
return value
def _check_cache(self, key: str) -> tuple[Any, bool]:
"""Check if key exists in cache and is valid.
Returns:
Tuple of (value, should_fetch)
where should_fetch is True if we need to fetch a new value
"""
now = time.monotonic()
if key in self._cache:
value, expire_time = self._cache[key]
if expire_time is None or now < expire_time:
return value, False
return None, True
def _update_cache(self, key: str, value: Any) -> None:
"""Update cache with new value."""
now = time.monotonic()
# Calculate expiration
expire_time = None
if (
self._missing_expire_seconds is not None
or self._exists_expire_seconds is not None
):
if value is self._missing_object_value:
expire_time = (
now + self._missing_expire_seconds
if self._missing_expire_seconds
else None
)
else:
expire_time = (
now + self._exists_expire_seconds
if self._exists_expire_seconds
else None
)
        # Enforce the size limit by evicting the entry with the earliest
        # expiration time (entries without an expiration are evicted last).
        # This is an O(n) scan, which is fine since the cache size is usually small.
if len(self._cache) >= self._max_size:
oldest_key = min(
self._cache, key=lambda k: self._cache[k].expire_time or float("inf")
)
del self._cache[oldest_key]
self._cache[key] = _CacheEntry(value, expire_time)
def __len__(self) -> int:
return len(self._cache)
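# Illustrative sketch (hypothetical fetch function and URI), run from within an
# async context: a missing object (None here) is cached briefly, while an
# existing object is cached for longer.
#
#     async def fetch_config(uri: str) -> Optional[str]:
#         return CloudFileSystem.get_file(uri)
#
#     cache = CloudObjectCache(
#         max_size=16,
#         fetch_fn=fetch_config,
#         missing_expire_seconds=30,
#         exists_expire_seconds=600,
#         missing_object_value=None,
#     )
#     config_text = await cache.aget("s3://example-bucket/model/config.json")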
class CloudModelAccessor:
"""Unified accessor for models stored in cloud storage (S3 or GCS).
Args:
model_id: The model id to download or upload.
mirror_config: The mirror config for the model.
"""
def __init__(self, model_id: str, mirror_config: CloudMirrorConfig):
self.model_id = model_id
self.mirror_config = mirror_config
def _get_lock_path(self, suffix: str = "") -> Path:
return Path(
"~", f"{self.model_id.replace('/', '--')}{suffix}.lock"
).expanduser()
def _get_model_path(self) -> Path:
if Path(self.model_id).exists():
return Path(self.model_id)
# Delayed import to avoid circular dependencies
from transformers.utils.hub import TRANSFORMERS_CACHE
return Path(
TRANSFORMERS_CACHE, f"models--{self.model_id.replace('/', '--')}"
).expanduser()
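# Illustrative example: for a hypothetical model_id "my-org/my-model", the lock
# file resolves to ~/my-org--my-model.lock and, when the id is not a local
# path, the model path resolves to <TRANSFORMERS_CACHE>/models--my-org--my-model.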
def remote_object_cache(
max_size: int,
missing_expire_seconds: Optional[int] = None,
exists_expire_seconds: Optional[int] = None,
missing_object_value: Any = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""A decorator that provides async caching using CloudObjectCache.
    This is a drop-in replacement for the previous cachetools-based caching,
    using CloudObjectCache internally to maintain cache state.
Args:
max_size: Maximum number of items to store in cache
missing_expire_seconds: How long to cache missing objects
exists_expire_seconds: How long to cache existing objects
missing_object_value: Value to use for missing objects
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
# Create a single cache instance for this function
cache = CloudObjectCache(
max_size=max_size,
fetch_fn=func,
missing_expire_seconds=missing_expire_seconds,
exists_expire_seconds=exists_expire_seconds,
missing_object_value=missing_object_value,
)
async def wrapper(*args, **kwargs):
# Extract the key from either first positional arg or object_uri kwarg
key = args[0] if args else kwargs.get("object_uri")
return await cache.aget(key)
return wrapper
return decorator
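# Illustrative usage sketch (hypothetical function and URI): caching an async
# per-URI probe so repeated lookups within the expiration window reuse the
# cached result instead of hitting cloud storage again.
#
#     @remote_object_cache(
#         max_size=32,
#         missing_expire_seconds=30,
#         exists_expire_seconds=600,
#         missing_object_value=None,
#     )
#     async def get_object_contents(object_uri: str) -> Optional[str]:
#         return CloudFileSystem.get_file(object_uri)
#
#     contents = await get_object_contents("s3://example-bucket/some-key")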