Source code for ray.data.datasource.datasink

import itertools
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Iterable, List, Optional, TypeVar

import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    import pyarrow as pa

logger = logging.getLogger(__name__)


WriteReturnType = TypeVar("WriteReturnType")
"""Generic type for the return value of `Datasink.write`."""


@dataclass
@DeveloperAPI
class WriteResult(Generic[WriteReturnType]):
    """Aggregated result of the Datasink write operations."""

    # Total number of written rows.
    num_rows: int
    # Total size in bytes of written data.
    size_bytes: int
    # All returned values of `Datasink.write`.
    write_returns: List[WriteReturnType]

    @classmethod
    def combine(cls, *wrs: "WriteResult") -> "WriteResult":
        num_rows = sum(wr.num_rows for wr in wrs)
        size_bytes = sum(wr.size_bytes for wr in wrs)
        write_returns = list(itertools.chain(*[wr.write_returns for wr in wrs]))
        return WriteResult(
            num_rows=num_rows,
            size_bytes=size_bytes,
            write_returns=write_returns,
        )
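
# Illustrative sketch (not part of the original module): how per-task results
# can be merged with `WriteResult.combine`. The values below are invented for
# demonstration only.
def _example_combine_write_results() -> "WriteResult[str]":
    per_task_results = [
        WriteResult(num_rows=10, size_bytes=100, write_returns=["part-0"]),
        WriteResult(num_rows=5, size_bytes=50, write_returns=["part-1"]),
    ]
    # Equivalent to WriteResult(num_rows=15, size_bytes=150,
    # write_returns=["part-0", "part-1"]).
    return WriteResult.combine(*per_task_results)
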
@DeveloperAPI
class Datasink(Generic[WriteReturnType]):
    """Interface for defining write-related logic.

    If you want to write data to something that isn't built-in, subclass this
    class and call :meth:`~ray.data.Dataset.write_datasink`.
    """
    def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
        """Callback for when a write job starts.

        Use this method to perform setup for write tasks. For example, creating a
        staging bucket in S3.

        This is called on the driver when the first input bundle is ready, just
        before write tasks are submitted. The schema is extracted from the first
        input bundle, enabling schema-dependent initialization.

        Args:
            schema: The PyArrow schema of the data being written. This is
                automatically extracted from the first input bundle. May be
                ``None`` if the input data has no schema.
        """
        pass
    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write blocks. This is used by a single write task.

        Args:
            blocks: Generator of data blocks.
            ctx: ``TaskContext`` for the write task.

        Returns:
            Result of this write task. When the entire write operator finishes,
            all returned values are passed as ``WriteResult.write_returns`` to
            ``Datasink.on_write_complete``.
        """
        raise NotImplementedError
    def on_write_complete(self, write_result: WriteResult[WriteReturnType]) -> None:
        """Callback for when a write job completes.

        This can be used to "commit" a write output. This method must succeed
        before ``write_datasink()`` returns to the user. If this method fails,
        then ``on_write_failed()`` is called.

        Args:
            write_result: Aggregated result of the Write operator, containing
                write results and stats.
        """
        pass
    def on_write_failed(self, error: Exception) -> None:
        """Callback for when a write job fails.

        This is called on a best-effort basis on write failures.

        Args:
            error: The first error encountered.
        """
        pass
    def get_name(self) -> str:
        """Return a human-readable name for this datasink.

        This is used as the name of the write tasks.
        """
        name = type(self).__name__
        datasink_suffix = "Datasink"
        if name.startswith("_"):
            name = name[1:]
        if name.endswith(datasink_suffix):
            name = name[: -len(datasink_suffix)]
        return name
    @property
    def supports_distributed_writes(self) -> bool:
        """If ``False``, only launch write tasks on the driver's node."""
        return True

    @property
    def min_rows_per_write(self) -> Optional[int]:
        """The target number of rows to pass to each
        :meth:`~ray.data.Datasink.write` call.

        If ``None``, Ray Data passes a system-chosen number of rows.
        """
        return None
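
# Illustrative sketch (not part of the original module): a minimal Datasink
# subclass that counts rows per write task. The name `_RowCountDatasink` is
# hypothetical; a real sink would send blocks to an external system instead.
class _RowCountDatasink(Datasink[int]):
    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> int:
        # Return the number of rows written by this task. Ray Data collects the
        # per-task return values into WriteResult.write_returns.
        return sum(BlockAccessor.for_block(b).num_rows() for b in blocks)

    def on_write_complete(self, write_result: WriteResult[int]) -> None:
        # write_result.write_returns holds one int per write task.
        logger.info("Wrote %d rows in total", sum(write_result.write_returns))
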
@DeveloperAPI
class DummyOutputDatasink(Datasink[None]):
    """An example implementation of a writable datasource for testing.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        ctx = ray.data.DataContext.get_current()

        # Setup a dummy actor to send the data. In a real datasource, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote(scheduling_strategy=ctx.scheduling_strategy)
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> None:
                block = BlockAccessor.for_block(block)
                self.rows_written += block.num_rows()

            def get_rows_written(self):
                return self.rows_written

        self.data_sink = DataSink.remote()
        self.num_ok = 0
        self.num_failed = 0
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        tasks = []
        if not self.enabled:
            raise ValueError("disabled")
        for b in blocks:
            tasks.append(self.data_sink.write.remote(b))
        ray.get(tasks)

    def on_write_complete(self, write_result: WriteResult[None]):
        self.num_ok += 1

    def on_write_failed(self, error: Exception) -> None:
        self.num_failed += 1


def _gen_datasink_write_result(
    write_result_blocks: List[Block],
) -> WriteResult:
    import pandas as pd

    assert all(
        isinstance(block, pd.DataFrame) and len(block) == 1
        for block in write_result_blocks
    )
    total_num_rows = sum(result["num_rows"].sum() for result in write_result_blocks)
    total_size_bytes = sum(result["size_bytes"].sum() for result in write_result_blocks)

    write_returns = [result["write_return"][0] for result in write_result_blocks]
    return WriteResult(total_num_rows, total_size_bytes, write_returns)
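
# Illustrative sketch (not part of the original module): the shape of the
# one-row pandas blocks that `_gen_datasink_write_result` expects. The column
# values below are invented for demonstration only.
def _example_gen_datasink_write_result() -> WriteResult:
    import pandas as pd

    write_result_blocks = [
        pd.DataFrame(
            {"num_rows": [10], "size_bytes": [100], "write_return": ["part-0"]}
        ),
        pd.DataFrame(
            {"num_rows": [5], "size_bytes": [50], "write_return": ["part-1"]}
        ),
    ]
    # Produces WriteResult(num_rows=15, size_bytes=150,
    # write_returns=["part-0", "part-1"]).
    return _gen_datasink_write_result(write_result_blocks)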