from typing import (
List,
Optional,
Tuple,
Union,
Dict,
Any,
Callable,
Awaitable,
TypeVar,
NamedTuple,
)
import asyncio
import functools
import inspect
import os
import time
from pydantic import Field, field_validator
# Use pyarrow for cloud storage access
import pyarrow.fs as pa_fs
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.common.base_pydantic import BaseModelExtended
T = TypeVar("T")
logger = get_logger(__name__)
class ExtraFiles(BaseModelExtended):
bucket_uri: str
destination_path: str
class CloudMirrorConfig(BaseModelExtended):
"""Unified mirror config for cloud storage (S3 or GCS).
Args:
bucket_uri: URI of the bucket (s3:// or gs://)
extra_files: Additional files to download
"""
bucket_uri: Optional[str] = None
extra_files: List[ExtraFiles] = Field(default_factory=list)
@property
    def storage_type(self) -> Optional[str]:
"""Returns the storage type ('s3' or 'gcs') based on the URI prefix."""
if self.bucket_uri is None:
return None
elif self.bucket_uri.startswith("s3://"):
return "s3"
elif self.bucket_uri.startswith("gs://"):
return "gcs"
return None
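# Usage sketch (illustrative only; bucket URIs and file names below are hypothetical):
#
#     mirror_config = CloudMirrorConfig(
#         bucket_uri="s3://my-model-bucket/llama-7b",
#         extra_files=[
#             ExtraFiles(
#                 bucket_uri="s3://my-model-bucket/extras/chat_template.jinja",
#                 destination_path="/tmp/chat_template.jinja",
#             )
#         ],
#     )
#     mirror_config.storage_type  # -> "s3"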
class LoraMirrorConfig(BaseModelExtended):
lora_model_id: str
bucket_uri: str
max_total_tokens: Optional[int]
sync_args: Optional[List[str]] = None
@field_validator("bucket_uri")
@classmethod
def check_uri_format(cls, value):
if value is None:
return value
if not value.startswith("s3://") and not value.startswith("gs://"):
raise ValueError(
f'Got invalid value "{value}" for bucket_uri. '
'Expected a URI that starts with "s3://" or "gs://".'
)
return value
@property
def _bucket_name_and_path(self) -> str:
for prefix in ["s3://", "gs://"]:
if self.bucket_uri.startswith(prefix):
return self.bucket_uri[len(prefix) :]
return self.bucket_uri
@property
def bucket_name(self) -> str:
return self._bucket_name_and_path.split("/")[0]
@property
def bucket_path(self) -> str:
return "/".join(self._bucket_name_and_path.split("/")[1:])
class CloudFileSystem:
"""A unified interface for cloud file system operations using PyArrow.
This class provides a simple interface for common operations on cloud storage
systems (S3, GCS) using PyArrow's filesystem interface.
"""
@staticmethod
def get_fs_and_path(object_uri: str) -> Tuple[pa_fs.FileSystem, str]:
"""Get the appropriate filesystem and path from a URI.
Args:
object_uri: URI of the file (s3:// or gs://)
If URI contains 'anonymous@', anonymous access is used.
Example: s3://anonymous@bucket/path
Returns:
Tuple of (filesystem, path)
"""
anonymous = False
# Check for anonymous access pattern
# e.g. s3://anonymous@bucket/path
if "@" in object_uri:
parts = object_uri.split("@", 1)
# Check if the first part ends with "anonymous"
if parts[0].endswith("anonymous"):
anonymous = True
# Remove the anonymous@ part, keeping the scheme
scheme = parts[0].split("://")[0]
object_uri = f"{scheme}://{parts[1]}"
if object_uri.startswith("s3://"):
fs = pa_fs.S3FileSystem(anonymous=anonymous)
path = object_uri[5:] # Remove "s3://"
elif object_uri.startswith("gs://"):
fs = pa_fs.GcsFileSystem(anonymous=anonymous)
path = object_uri[5:] # Remove "gs://"
else:
raise ValueError(f"Unsupported URI scheme: {object_uri}")
return fs, path
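    # Usage sketch (illustrative only; the bucket below is hypothetical):
    #
    #     fs, path = CloudFileSystem.get_fs_and_path("s3://anonymous@my-bucket/model")
    #     # fs is an S3FileSystem created with anonymous=True; path == "my-bucket/model"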
@staticmethod
def get_file(
object_uri: str, decode_as_utf_8: bool = True
) -> Optional[Union[str, bytes]]:
"""Download a file from cloud storage into memory.
Args:
object_uri: URI of the file (s3:// or gs://)
decode_as_utf_8: If True, decode the file as UTF-8
Returns:
File contents as string or bytes, or None if file doesn't exist
"""
try:
fs, path = CloudFileSystem.get_fs_and_path(object_uri)
# Check if file exists
            if fs.get_file_info(path).type != pa_fs.FileType.File:
logger.info(f"URI {object_uri} does not exist.")
return None
# Read file
with fs.open_input_file(path) as f:
body = f.read()
if decode_as_utf_8:
body = body.decode("utf-8")
return body
except Exception as e:
logger.info(f"Error reading {object_uri}: {e}")
return None
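    # Usage sketch (illustrative only; URIs below are hypothetical):
    #
    #     config_text = CloudFileSystem.get_file("gs://my-bucket/llama-7b/config.json")
    #     raw_bytes = CloudFileSystem.get_file(
    #         "gs://my-bucket/llama-7b/tokenizer.model", decode_as_utf_8=False
    #     )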
@staticmethod
def list_subfolders(folder_uri: str) -> List[str]:
"""List the immediate subfolders in a cloud directory.
Args:
folder_uri: URI of the directory (s3:// or gs://)
Returns:
List of subfolder names (without trailing slashes)
"""
# Ensure that the folder_uri has a trailing slash.
folder_uri = f"{folder_uri.rstrip('/')}/"
try:
fs, path = CloudFileSystem.get_fs_and_path(folder_uri)
# List directory contents
file_infos = fs.get_file_info(pa_fs.FileSelector(path, recursive=False))
# Filter for directories and extract subfolder names
subfolders = []
for file_info in file_infos:
if file_info.type == pa_fs.FileType.Directory:
# Extract just the subfolder name without the full path
subfolder = os.path.basename(file_info.path.rstrip("/"))
subfolders.append(subfolder)
return subfolders
except Exception as e:
logger.info(f"Error listing subfolders in {folder_uri}: {e}")
return []
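    # Usage sketch (illustrative only; the layout below is hypothetical). For a
    # bucket organized as s3://my-bucket/loras/<adapter-name>/..., this returns
    # the adapter names:
    #
    #     adapter_names = CloudFileSystem.list_subfolders("s3://my-bucket/loras")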
@staticmethod
def download_files(
path: str,
bucket_uri: str,
substrings_to_include: Optional[List[str]] = None,
) -> None:
"""Download files from cloud storage to a local directory.
Args:
path: Local directory where files will be downloaded
bucket_uri: URI of cloud directory
substrings_to_include: Only include files containing these substrings
"""
try:
fs, source_path = CloudFileSystem.get_fs_and_path(bucket_uri)
# Ensure the destination directory exists
os.makedirs(path, exist_ok=True)
# List all files in the bucket
file_selector = pa_fs.FileSelector(source_path, recursive=True)
file_infos = fs.get_file_info(file_selector)
# Download each file
for file_info in file_infos:
if file_info.type != pa_fs.FileType.File:
continue
# Get relative path from source prefix
rel_path = file_info.path[len(source_path) :].lstrip("/")
# Check if file matches substring filters
if substrings_to_include:
if not any(
substring in rel_path for substring in substrings_to_include
):
continue
# Create destination directory if needed
if "/" in rel_path:
dest_dir = os.path.join(path, os.path.dirname(rel_path))
os.makedirs(dest_dir, exist_ok=True)
# Download the file
dest_path = os.path.join(path, rel_path)
with fs.open_input_file(file_info.path) as source_file:
with open(dest_path, "wb") as dest_file:
dest_file.write(source_file.read())
except Exception as e:
logger.exception(f"Error downloading files from {bucket_uri}: {e}")
raise
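    # Usage sketch (illustrative only; URI and filters below are hypothetical):
    #
    #     CloudFileSystem.download_files(
    #         path="/tmp/llama-7b",
    #         bucket_uri="s3://my-bucket/llama-7b",
    #         substrings_to_include=["safetensors", "config.json"],
    #     )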
@staticmethod
def download_model(
destination_path: str, bucket_uri: str, tokenizer_only: bool
) -> None:
"""Download a model from cloud storage.
This downloads a model in the format expected by the HuggingFace transformers
library.
Args:
destination_path: Path where the model will be stored
bucket_uri: URI of the cloud directory containing the model
tokenizer_only: If True, only download tokenizer-related files
"""
try:
fs, source_path = CloudFileSystem.get_fs_and_path(bucket_uri)
# Check for hash file
hash_path = os.path.join(source_path, "hash")
hash_info = fs.get_file_info(hash_path)
if hash_info.type == pa_fs.FileType.File:
# Download and read hash file
with fs.open_input_file(hash_path) as f:
f_hash = f.read().decode("utf-8").strip()
logger.info(
f"Detected hash file in bucket {bucket_uri}. "
f"Using {f_hash} as the hash."
)
else:
f_hash = "0000000000000000000000000000000000000000"
logger.warning(
f"Hash file does not exist in bucket {bucket_uri}. "
f"Using {f_hash} as the hash."
)
# Write hash to refs/main
main_dir = os.path.join(destination_path, "refs")
os.makedirs(main_dir, exist_ok=True)
with open(os.path.join(main_dir, "main"), "w") as f:
f.write(f_hash)
# Create destination directory
destination_dir = os.path.join(destination_path, "snapshots", f_hash)
os.makedirs(destination_dir, exist_ok=True)
logger.info(f'Downloading model files to directory "{destination_dir}".')
# Download files
tokenizer_file_substrings = (
["tokenizer", "config.json"] if tokenizer_only else []
)
CloudFileSystem.download_files(
path=destination_dir,
bucket_uri=bucket_uri,
substrings_to_include=tokenizer_file_substrings,
)
except Exception as e:
logger.exception(f"Error downloading model from {bucket_uri}: {e}")
raise
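# Usage sketch (illustrative only; paths and URIs below are hypothetical). The
# resulting layout mirrors the HuggingFace cache structure described above:
#
#     CloudFileSystem.download_model(
#         destination_path="/tmp/models--my-org--llama-7b",
#         bucket_uri="s3://my-bucket/llama-7b",
#         tokenizer_only=True,
#     )
#     # /tmp/models--my-org--llama-7b/refs/main             <- contains the hash
#     # /tmp/models--my-org--llama-7b/snapshots/<hash>/...  <- downloaded files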
class _CacheEntry(NamedTuple):
value: Any
expire_time: Optional[float]
class CloudObjectCache:
"""A cache that works with both sync and async fetch functions.
    The purpose of this data structure is to cache the result of a function call,
    typically one that fetches a value from a cloud object store.
    The rationale:
    - Cloud operations are expensive.
    - In LoRA specifically, we would otherwise hit remote storage to download the
      model weights on every request.
    - If the same model is requested many times, we don't want to inflate the time
      to first token.
    - The cache is bounded both by size (evicting the entry that expires soonest)
      and by expiring entries after a configurable time.
    - If an object is missing, we cache the missing status for a short duration;
      if it exists, we cache the object for a longer duration.
"""
def __init__(
self,
max_size: int,
fetch_fn: Union[Callable[[str], Any], Callable[[str], Awaitable[Any]]],
missing_expire_seconds: Optional[int] = None,
exists_expire_seconds: Optional[int] = None,
missing_object_value: Any = object(),
):
"""Initialize the cache.
Args:
max_size: Maximum number of items to store in cache
fetch_fn: Function to fetch values (can be sync or async)
missing_expire_seconds: How long to cache missing objects (None for no expiration)
            exists_expire_seconds: How long to cache existing objects (None for no expiration)
            missing_object_value: Sentinel value used to represent missing objects
"""
self._cache: Dict[str, _CacheEntry] = {}
self._max_size = max_size
self._fetch_fn = fetch_fn
self._missing_expire_seconds = missing_expire_seconds
self._exists_expire_seconds = exists_expire_seconds
self._is_async = inspect.iscoroutinefunction(fetch_fn) or (
callable(fetch_fn) and inspect.iscoroutinefunction(fetch_fn.__call__)
)
self._missing_object_value = missing_object_value
# Lock for thread-safe cache access
self._lock = asyncio.Lock()
async def aget(self, key: str) -> Any:
"""Async get value from cache or fetch it if needed."""
if not self._is_async:
raise ValueError("Cannot use async get() with sync fetch function")
async with self._lock:
value, should_fetch = self._check_cache(key)
if not should_fetch:
return value
# Fetch new value
value = await self._fetch_fn(key)
self._update_cache(key, value)
return value
def get(self, key: str) -> Any:
"""Sync get value from cache or fetch it if needed."""
if self._is_async:
raise ValueError("Cannot use sync get() with async fetch function")
# For sync access, we use a simple check-then-act pattern
# This is safe because sync functions are not used in async context
value, should_fetch = self._check_cache(key)
if not should_fetch:
return value
# Fetch new value
value = self._fetch_fn(key)
self._update_cache(key, value)
return value
    def _check_cache(self, key: str) -> Tuple[Any, bool]:
"""Check if key exists in cache and is valid.
Returns:
Tuple of (value, should_fetch)
where should_fetch is True if we need to fetch a new value
"""
now = time.monotonic()
if key in self._cache:
value, expire_time = self._cache[key]
if expire_time is None or now < expire_time:
return value, False
return None, True
def _update_cache(self, key: str, value: Any) -> None:
"""Update cache with new value."""
now = time.monotonic()
# Calculate expiration
expire_time = None
if (
self._missing_expire_seconds is not None
or self._exists_expire_seconds is not None
):
if value is self._missing_object_value:
expire_time = (
now + self._missing_expire_seconds
if self._missing_expire_seconds
else None
)
else:
expire_time = (
now + self._exists_expire_seconds
if self._exists_expire_seconds
else None
)
        # Enforce size limit by evicting the entry that expires soonest (entries
        # with no expiration are evicted last). This is O(n), which is fine since
        # the cache is usually small.
if len(self._cache) >= self._max_size:
oldest_key = min(
self._cache, key=lambda k: self._cache[k].expire_time or float("inf")
)
del self._cache[oldest_key]
self._cache[key] = _CacheEntry(value, expire_time)
def __len__(self) -> int:
return len(self._cache)
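# Usage sketch (illustrative only; the fetch function and URI are hypothetical):
#
#     MISSING = object()
#
#     def fetch_object(uri: str) -> Any:
#         value = CloudFileSystem.get_file(uri)
#         return MISSING if value is None else value
#
#     cache = CloudObjectCache(
#         max_size=16,
#         fetch_fn=fetch_object,
#         missing_expire_seconds=30,
#         exists_expire_seconds=600,
#         missing_object_value=MISSING,
#     )
#     cache.get("s3://my-bucket/loras/my-adapter/adapter_config.json")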
def remote_object_cache(
max_size: int,
missing_expire_seconds: Optional[int] = None,
exists_expire_seconds: Optional[int] = None,
missing_object_value: Any = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
"""A decorator that provides async caching using CloudObjectCache.
    This is a drop-in replacement for the previous cachetools-based implementation,
    using CloudObjectCache internally to maintain cache state.
Args:
max_size: Maximum number of items to store in cache
missing_expire_seconds: How long to cache missing objects
exists_expire_seconds: How long to cache existing objects
missing_object_value: Value to use for missing objects
"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
# Create a single cache instance for this function
cache = CloudObjectCache(
max_size=max_size,
fetch_fn=func,
missing_expire_seconds=missing_expire_seconds,
exists_expire_seconds=exists_expire_seconds,
missing_object_value=missing_object_value,
)
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
# Extract the key from either first positional arg or object_uri kwarg
key = args[0] if args else kwargs.get("object_uri")
return await cache.aget(key)
return wrapper
return decorator
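# Usage sketch (illustrative only; the decorated function and URI are hypothetical).
# The decorated function must be async, since the wrapper awaits cache.aget():
#
#     @remote_object_cache(
#         max_size=32,
#         missing_expire_seconds=30,
#         exists_expire_seconds=600,
#         missing_object_value=None,
#     )
#     async def get_adapter_config(object_uri: str):
#         return CloudFileSystem.get_file(object_uri)
#
#     # Repeated calls with the same URI within the expiry window hit the cache:
#     # config = await get_adapter_config("s3://my-bucket/loras/my-adapter/config.json")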