"""Catalog connectors for Ray Data readers.
A :class:`Catalog` resolves a table name into a readable source (location +
credentials) for a reader such as :func:`ray.data.read_delta`,
:func:`ray.data.read_parquet`, or :func:`ray.data.read_iceberg`.
"""
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import KW_ONLY, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib.parse import urljoin
from ray.util.annotations import DeveloperAPI, PublicAPI
if TYPE_CHECKING:
import pyarrow.fs
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.catalog import (
AwsCredentials,
AzureUserDelegationSas,
GcpOauthToken,
GenerateTemporaryTableCredentialResponse,
TableInfo,
)
from ray.data._internal.datasource.databricks_credentials import (
DatabricksCredentialProvider,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Table-property key whose value is searched (case-insensitively) for
# "iceberg" when inferring a format hint for a Delta table.
_DELTA_UNIFORM_FORMATS_PROPERTY = "delta.universalFormat.enabledFormats"
# Environment-variable names the underlying readers (pyarrow / deltalake's
# object_store) pick up vended credentials from.
# AWS temporary credentials (access key, secret key, session token).
_AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID"
_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"
_AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN"
# Both region variables are set to the same value when a region is configured.
_AWS_REGION = "AWS_REGION"
_AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION"
# Azure SAS token, stored without its leading "?".
_AZURE_STORAGE_SAS_TOKEN = "AZURE_STORAGE_SAS_TOKEN"
def _normalize_host(host: str) -> str:
    """Return ``host`` as a URL with an explicit scheme and no trailing slash.

    Args:
        host: Workspace host, with or without an ``http://`` / ``https://``
            scheme. Surrounding whitespace and trailing slashes are dropped.

    Returns:
        The host unchanged if it already carries a scheme (matched
        case-insensitively), otherwise the host prefixed with ``https://``.
    """
    # Hosts often come from env vars or config files and may carry a stray
    # newline or space; strip that before trimming trailing slashes.
    host = host.strip().rstrip("/")
    # URL schemes are case-insensitive, so "HTTPS://x" must not be re-prefixed.
    if not host.lower().startswith(("http://", "https://")):
        host = f"https://{host}"
    return host
@DeveloperAPI
@dataclass
class ResolvedSource:
    """The output of :meth:`Catalog.resolve` — location/credentials for a reader.

    A reader consumes only the fields it understands:

    * ``read_delta``: ``path`` + (``storage_options`` and/or ``filesystem``)
    * ``read_parquet``: ``path`` + ``filesystem``
    * ``read_iceberg``: ``catalog_kwargs`` + ``table_identifier``

    Unused fields are ``None``. Every field defaults to ``None``, so a catalog
    only fills in what applies to the requested reader.
    """

    # Storage location of the table data.
    path: Optional[str] = None
    # Explicit pyarrow filesystem to read ``path`` with, when one is needed.
    filesystem: Optional["pyarrow.fs.FileSystem"] = None
    # Reader storage options (consumed by ``read_delta``, per the list above).
    storage_options: Optional[Dict[str, Any]] = None
    # Catalog configuration (consumed by ``read_iceberg``, per the list above).
    catalog_kwargs: Optional[Dict[str, Any]] = None
    # Identifier the reader should address the table by, if the catalog rewrites
    # it (e.g. Iceberg REST scopes the warehouse to the catalog, so the table is
    # addressed as ``schema.table`` rather than ``catalog.schema.table``).
    table_identifier: Optional[str] = None
    data_format: Optional[ReaderFormat] = None  # hint, e.g. ReaderFormat.DELTA
@PublicAPI(stability="alpha")
class Catalog(ABC):
    """A directory service that resolves a table name to a readable source.

    Subclasses implement :meth:`resolve`; the class cannot be instantiated
    until they do, because ``resolve`` is abstract.
    """

    @abstractmethod
    def resolve(self, table: str, *, reader: ReaderFormat) -> ResolvedSource:
        """Resolve ``table`` for the given ``reader``.

        Args:
            table: Name of the table to resolve.
            reader: Format of the reader that will consume the result
                (keyword-only).

        Returns:
            A :class:`ResolvedSource` carrying the location and/or
            credentials the reader needs.
        """
        ...
@PublicAPI(stability="alpha")
@dataclass(frozen=True)
class DatabricksUnityCatalog(Catalog):
    """Databricks Unity Catalog connector.

    For Delta and Parquet tables this performs Unity Catalog credential vending
    (temporary, least-privilege cloud credentials). For Iceberg tables it
    returns configuration pointing PyIceberg at Unity Catalog's Iceberg REST
    catalog endpoint.

    Instances are frozen dataclasses and every constructor argument is
    keyword-only.

    Args:
        url: Databricks workspace URL (e.g.
            ``"https://dbc-XXXX.cloud.databricks.com"``). Required unless
            ``credential_provider`` is given.
        token: Databricks Personal Access Token with ``EXTERNAL USE SCHEMA``
            permission. Required unless ``credential_provider`` is given.
        credential_provider: A custom
            :class:`~ray.data._internal.datasource.databricks_credentials.DatabricksCredentialProvider`.
            If provided, ``url``/``token`` are ignored.
        region: AWS region for S3 access (e.g. ``"us-west-2"``). Required for
            AWS-backed tables; not needed for Azure/GCP.

    Example:
        >>> import ray
        >>> catalog = ray.data.DatabricksUnityCatalog( # doctest: +SKIP
        ...     url="https://dbc-XXXX.cloud.databricks.com",
        ...     token="dapi...",
        ...     region="us-west-2",
        ... )
        >>> ds = ray.data.read_delta( # doctest: +SKIP
        ...     "main.sales.transactions", catalog=catalog
        ... )
    """

    # Sentinel: every field declared after it is a keyword-only init argument.
    _: KW_ONLY
    url: Optional[str] = None
    # `repr=False` keeps the token/provider out of the auto-generated repr.
    token: Optional[str] = field(default=None, repr=False)
    credential_provider: Optional["DatabricksCredentialProvider"] = field(
        default=None, repr=False
    )
    region: Optional[str] = None
    # Derived in __post_init__; declared (init=False) so type checkers know the
    # attributes exist, and excluded from repr/eq.
    _provider: "DatabricksCredentialProvider" = field(
        init=False, repr=False, compare=False
    )
    _base_url: str = field(init=False, repr=False, compare=False)
def __post_init__(self):
    """Derive ``_provider`` and ``_base_url`` from the constructor arguments.

    The credential provider is resolved from ``credential_provider``, ``url``
    and ``token``; the base URL is the provider's host after normalization.
    """
    from ray.data._internal.datasource.databricks_credentials import (
        UnityCatalogCredentialConfig,
        resolve_credential_provider,
    )

    credential_config = UnityCatalogCredentialConfig(
        credential_provider=self.credential_provider,
        url=self.url,
        token=self.token,
    )
    resolved_provider = resolve_credential_provider(credential_config)

    # The dataclass is frozen, so regular attribute assignment is blocked;
    # `object.__setattr__` is the way to populate derived fields here.
    assign = object.__setattr__
    assign(self, "_provider", resolved_provider)
    assign(self, "_base_url", _normalize_host(resolved_provider.get_host()))
# ---- Catalog interface -------------------------------------------------
def resolve(self, table: str, *, reader: ReaderFormat) -> ResolvedSource:
    """Resolve ``table`` into a source the given ``reader`` can consume.

    Args:
        table: Table name to resolve.
        reader: Reader format requesting the table (keyword-only).

    Returns:
        Iceberg REST catalog configuration for ``ReaderFormat.ICEBERG``, or a
        storage location with vended credentials for ``ReaderFormat.DELTA``
        and ``ReaderFormat.PARQUET``.

    Raises:
        TypeError: If ``reader`` is not a ``ReaderFormat``.
        ValueError: If ``reader`` is a ``ReaderFormat`` this catalog does not
            handle.
    """
    # Validate with an explicit raise rather than `assert`: assertions are
    # stripped under `python -O`, which would let a bad value fall through.
    if not isinstance(reader, ReaderFormat):
        raise TypeError(
            "reader must be a ReaderFormat, got "
            f"{type(reader).__name__}: {reader!r}"
        )
    if reader is ReaderFormat.ICEBERG:
        return self._resolve_iceberg(table)
    if reader in (ReaderFormat.DELTA, ReaderFormat.PARQUET):
        return self._resolve_storage(table, reader)
    # Reached only if a new ReaderFormat is added without handling here.
    raise ValueError(f"DatabricksUnityCatalog does not support format={reader!r}")
# ---- storage-credential vending (delta / parquet) ----------------------
def _resolve_storage(self, table: str, reader: ReaderFormat) -> ResolvedSource:
    """Vend temporary storage credentials for ``table`` and describe the read.

    Looks up the table, requests temporary credentials for it, optionally
    builds an explicit pyarrow filesystem, exports the credentials as
    environment variables, and returns the storage location.

    Args:
        table: Table name passed to the Unity Catalog lookup.
        reader: Requested reader format; decides whether an explicit
            filesystem is built.

    Returns:
        A :class:`ResolvedSource` with ``path``, ``filesystem`` (possibly
        ``None``) and a best-effort ``data_format`` hint.

    Raises:
        RuntimeError: For a Delta read when the vended credential is a GCP
            OAuth token.
    """
    info = self._get_table_info(table)
    vended, location = self._get_creds(info.table_id)

    aws_creds = vended.aws_temp_credentials
    gcp_token = vended.gcp_oauth_token
    is_delta_read = reader is ReaderFormat.DELTA

    # Some reads need an explicit pyarrow filesystem:
    # - AWS Delta: the vended session token isn't reliably propagated through
    #   `DeltaTable.to_pyarrow_dataset`'s auto-built filesystem.
    # - GCP Parquet: a bare OAuth token has no env var pyarrow auto-reads,
    #   so the data scan needs an explicit GcsFileSystem.
    if aws_creds is not None and is_delta_read:
        explicit_fs = self._build_s3_filesystem(aws_creds)
    elif gcp_token is not None and is_delta_read:
        # Unity Catalog vends a GCP OAuth token, but deltalake's
        # object_store (<=0.13.x, bundled in deltalake<=1.6.1) only
        # accepts service-account-key auth for GCS -- it has no
        # bearer/OAuth-token config key -- so the Delta transaction-log
        # read can't use the vended token and silently falls back to GCE
        # metadata-server auth.
        raise RuntimeError(
            "Reading a GCP-backed Delta table via Unity Catalog "
            "credential vending is not supported as deltalake "
            "does not have the required object_store version."
        )
    elif gcp_token is not None:
        explicit_fs = self._build_gcs_filesystem(gcp_token, vended.expiration_time)
    else:
        explicit_fs = None

    # Deliver vended credentials via environment variables. This is the
    # mechanism the underlying libraries read uniformly: pyarrow (Parquet
    # data, and S3/Azure/GCS auto-filesystems) and deltalake's object_store
    # (the Delta transaction *log* read in `DeltaTable(...)`, which neither
    # a pyarrow `filesystem` nor `storage_options` keyed for pyarrow would
    # satisfy). See `_apply_env` for the worker-propagation note.
    #
    # TODO: remove the env-var + ray.init mechanism once credential vending
    # is performed inside the read tasks themselves (worker-side).
    credential_env = self._creds_to_env(vended)
    self._apply_env(credential_env)

    return ResolvedSource(
        path=location,
        filesystem=explicit_fs,
        data_format=self._infer_format(info, location),
    )
# ---- iceberg REST catalog ---------------------------------------------
def _resolve_iceberg(self, table: str) -> ResolvedSource:
    """Build PyIceberg REST-catalog configuration for ``table``.

    Args:
        table: Catalog-qualified table name, e.g. ``catalog.schema.table``.
            The part before the first ``.`` becomes the REST ``warehouse``;
            the remainder becomes the table identifier.

    Returns:
        A :class:`ResolvedSource` with ``catalog_kwargs`` and
        ``table_identifier`` set and ``data_format`` set to
        ``ReaderFormat.ICEBERG``.

    Raises:
        ValueError: If ``table`` has no catalog prefix or nothing after it.
    """
    # PyIceberg speaks the Iceberg REST protocol; Unity Catalog implements
    # it and vends data-file credentials via the access-delegation header.
    # No manual S3/ADLS/GCS keys are needed here.
    #
    # The REST catalog is scoped to a single UC catalog via `warehouse`, so
    # the table is addressed by `schema.table` (the catalog prefix would
    # otherwise be double-applied, e.g. `tmp.tmp.schema.table`).
    catalog_name, _, namespace_table = table.partition(".")
    # Without a "." the whole name lands in `catalog_name` and the table
    # identifier is empty; fail here instead of handing the reader an empty
    # identifier.
    if not catalog_name or not namespace_table:
        raise ValueError(
            "Expected a catalog-qualified table name such as "
            f"'catalog.schema.table', got {table!r}."
        )
    return ResolvedSource(
        table_identifier=namespace_table,
        catalog_kwargs={
            "type": "rest",
            "uri": urljoin(self._base_url, "/api/2.1/unity-catalog/iceberg-rest"),
            "warehouse": catalog_name,
            "token": self._provider.get_token(),
            "header.X-Iceberg-Access-Delegation": "vended-credentials",
        },
        data_format=ReaderFormat.ICEBERG,
    )
# ---- Unity Catalog SDK helpers ----------------------------------------
def _workspace_client(self) -> "WorkspaceClient":
    """Create a new Databricks SDK client from the provider's current token.

    A new client is constructed on every call, so a token refreshed by the
    credential provider is picked up by the next client.
    """
    from databricks.sdk import WorkspaceClient

    current_token = self._provider.get_token()
    return WorkspaceClient(host=self._base_url, token=current_token)
def _call_with_token_refresh(self, call: Callable) -> Any:
    """Invoke ``call(workspace_client)``, retrying exactly once on a 401.

    If the first attempt raises ``Unauthenticated``, the credential provider
    is invalidated (so the next ``get_token()`` can return a fresh token) and
    ``call`` is run again against a newly built client. A failure of the
    second attempt propagates to the caller. This matters for refreshable
    providers; for a static PAT the retry uses the same token.

    Args:
        call: Callable taking a ``WorkspaceClient`` and returning a result.

    Returns:
        Whatever ``call`` returns.
    """
    from databricks.sdk.errors import Unauthenticated

    try:
        outcome = call(self._workspace_client())
    except Unauthenticated:
        logger.info("Received 401 from Unity Catalog; refreshing credentials.")
        self._provider.invalidate()
        # Build a brand-new client so it picks up the refreshed token.
        outcome = call(self._workspace_client())
    return outcome
def _get_table_info(self, table: str) -> "TableInfo":
    """Fetch Unity Catalog metadata for ``table`` (with one 401 retry).

    Args:
        table: Full table name handed to ``tables.get(full_name=...)``.

    Returns:
        The value returned by the SDK's ``tables.get`` call.
    """

    def fetch(client):
        return client.tables.get(full_name=table)

    return self._call_with_token_refresh(fetch)
def _get_creds(
    self, table_id: Optional[str]
) -> Tuple["GenerateTemporaryTableCredentialResponse", str]:
    """Request temporary READ credentials for a table.

    Args:
        table_id: Unity Catalog table id. Typed optional because the table
            metadata field is optional, but a value is required here.

    Returns:
        A ``(credentials, url)`` tuple, where ``url`` is the ``url`` attribute
        of the credential response.

    Raises:
        ValueError: If ``table_id`` is missing, or the credential response
            carries no storage URL.
    """
    # Explicit raise instead of `assert`: assertions vanish under `python -O`,
    # which would send `table_id=None` to the API.
    if not table_id:
        raise ValueError(
            "Cannot vend Unity Catalog credentials: the table has no table_id."
        )

    from databricks.sdk.service.catalog import TableOperation

    creds = self._call_with_token_refresh(
        lambda w: w.temporary_table_credentials.generate_temporary_table_credentials(
            table_id=table_id, operation=TableOperation.READ
        )
    )
    # The return type promises a `str`; fail here rather than letting a
    # missing URL surface later as `path=None` in the reader.
    if not creds.url:
        raise ValueError(
            "Unity Catalog credential response did not include a storage URL "
            f"for table_id={table_id!r}."
        )
    return creds, creds.url
@staticmethod
def _infer_format(
    table_info: "TableInfo", table_url: str
) -> Optional[ReaderFormat]:
    """Guess the table's format from its metadata, else its file extension.

    Args:
        table_info: Table metadata; ``data_source_format``, ``properties``
            and ``storage_location`` are consulted.
        table_url: Fallback location used when ``storage_location`` is empty.

    Returns:
        ``ReaderFormat.ICEBERG`` for a Delta table whose universal-format
        property mentions "iceberg", ``ReaderFormat.DELTA`` for other Delta
        tables, ``ReaderFormat.PARQUET`` for Parquet tables, a format matching
        the location's file extension when it is ``delta``/``parquet``, and
        ``None`` otherwise.
    """
    from databricks.sdk.service.catalog import DataSourceFormat

    declared = table_info.data_source_format

    if declared == DataSourceFormat.DELTA:
        properties = table_info.properties or {}
        enabled_formats = properties.get(_DELTA_UNIFORM_FORMATS_PROPERTY, "")
        if "iceberg" in enabled_formats.lower():
            return ReaderFormat.ICEBERG
        return ReaderFormat.DELTA
    if declared == DataSourceFormat.PARQUET:
        return ReaderFormat.PARQUET

    # Metadata was inconclusive: fall back to the location's extension.
    location = table_info.storage_location or table_url
    if not location:
        return None
    suffix = os.path.splitext(location)[-1].replace(".", "").lower()
    recognized = (ReaderFormat.DELTA.value, ReaderFormat.PARQUET.value)
    return ReaderFormat(suffix) if suffix in recognized else None
def _creds_to_env(
    self, creds: "GenerateTemporaryTableCredentialResponse"
) -> Dict[str, Optional[str]]:
    """Map a vended credential response to environment-variable settings.

    The first credential type present wins, checked in the order AWS, Azure,
    GCP.

    Args:
        creds: Credential response to translate.

    Returns:
        A mapping of environment-variable name to value; empty for GCP.

    Raises:
        ValueError: If the response carries none of the three credential
            types.
    """
    aws = creds.aws_temp_credentials
    if aws is not None:
        env = {
            _AWS_ACCESS_KEY_ID: aws.access_key_id,
            _AWS_SECRET_ACCESS_KEY: aws.secret_access_key,
            _AWS_SESSION_TOKEN: aws.session_token,
        }
        # The region is exported under both names, and only when configured.
        if self.region:
            env[_AWS_REGION] = self.region
            env[_AWS_DEFAULT_REGION] = self.region
        return env

    azure_sas = creds.azure_user_delegation_sas
    if azure_sas is not None:
        return self._parse_azure_creds(azure_sas)

    if creds.gcp_oauth_token is not None:
        # A bare GCP OAuth token has no env var pyarrow/deltalake auto-read;
        # `_resolve_storage` hands it to an explicit GcsFileSystem instead,
        # so there is nothing to export here.
        return {}

    raise ValueError("No known credential type found in Databricks UC response.")
@staticmethod
def _apply_env(env_vars: Dict[str, Optional[str]]) -> None:
    """Set vended credentials in the environment and propagate to workers.

    Credentials are set on the driver's ``os.environ`` and, if Ray has not
    been initialized yet, into the cluster ``runtime_env`` so read tasks on
    workers inherit them. If Ray is already running we cannot retroactively
    amend its ``runtime_env``; driver-side env still covers driver reads
    (e.g. the Delta log) and single-node execution.

    Entries whose value is ``None`` or empty are skipped for both the driver
    environment and the ``runtime_env``. If nothing remains, the call is a
    no-op and Ray is neither imported nor initialized.

    Args:
        env_vars: Environment-variable name to value; values may be ``None``.

    TODO: remove once credential vending happens inside the read tasks.
    """
    # Filter once and use the same mapping for both destinations. Passing
    # the raw mapping to `runtime_env` would forward `None` values, which
    # are not valid environment-variable values.
    effective = {name: value for name, value in env_vars.items() if value}
    if not effective:
        return

    for name, value in effective.items():
        os.environ[name] = value

    import ray

    if not ray.is_initialized():
        ray.init(runtime_env={"env_vars": effective})
def _build_s3_filesystem(self, aws: "AwsCredentials") -> "pyarrow.fs.FileSystem":
    """Create a pyarrow ``S3FileSystem`` from vended AWS credentials.

    Args:
        aws: Temporary AWS credentials (access key, secret key, session
            token).

    Returns:
        An ``S3FileSystem`` bound to those credentials and ``self.region``.

    Raises:
        ValueError: If no region was configured on this catalog.
    """
    configured_region = self.region
    if not configured_region:
        raise ValueError(
            "The 'region' parameter is required for AWS S3 access. "
            "Please specify the AWS region (e.g., region='us-west-2')."
        )

    import pyarrow.fs as pafs

    s3_options = {
        "access_key": aws.access_key_id,
        "secret_key": aws.secret_access_key,
        "session_token": aws.session_token,
        "region": configured_region,
    }
    return pafs.S3FileSystem(**s3_options)
@staticmethod
def _build_gcs_filesystem(
    gcp: "GcpOauthToken", expiration_time: Optional[int]
) -> "pyarrow.fs.FileSystem":
    """Create a pyarrow ``GcsFileSystem`` from a vended GCP OAuth token.

    Args:
        gcp: Object whose ``oauth_token`` is used as the access token.
        expiration_time: Token expiry; it is divided by 1000 before being
            interpreted as a UTC POSIX timestamp, i.e. it is treated as
            epoch milliseconds.

    Returns:
        A ``GcsFileSystem`` configured with the token and its expiry.

    Raises:
        ValueError: If ``expiration_time`` is ``None``.
    """
    import pyarrow.fs as pafs

    # pyarrow requires an expiration alongside an access token.
    if expiration_time is None:
        raise ValueError(
            "GCP credential vending did not return an expiration_time."
        )

    epoch_seconds = expiration_time / 1000
    expires_at = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
    return pafs.GcsFileSystem(
        access_token=gcp.oauth_token,
        credential_token_expiration=expires_at,
    )
@staticmethod
def _parse_azure_creds(sas: "AzureUserDelegationSas") -> Dict[str, Optional[str]]:
    """Turn a vended Azure SAS into its environment-variable setting.

    Args:
        sas: Object whose ``sas_token`` holds the SAS, optionally prefixed
            with ``?``.

    Returns:
        A single-entry mapping from the Azure SAS environment-variable name
        to the token with one leading ``?`` removed.

    Raises:
        ValueError: If the token is missing or empty once the ``?`` is
            removed.
    """
    raw = sas.sas_token
    # Drop a single leading "?" (query-string form) if present.
    token = raw[1:] if raw and raw.startswith("?") else raw
    if not token:
        raise ValueError("Azure UC credentials missing a SAS token.")
    return {_AZURE_STORAGE_SAS_TOKEN: token}