Source code for ray.data.datasource.datasink
import logging
from dataclasses import dataclass, fields
from typing import Iterable, List, Optional
import ray
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.util.annotations import DeveloperAPI
# Module-level logger; ``Datasink.on_write_complete`` uses it to report the
# aggregated write stats once a write job succeeds.
logger = logging.getLogger(__name__)
@dataclass
@DeveloperAPI
class WriteResult:
    """Result of a write operation, containing stats/metrics
    on the written data.

    Attributes:
        num_rows: The total number of rows written.
        size_bytes: The total size of the written data in bytes.
    """

    num_rows: int = 0
    size_bytes: int = 0

    @staticmethod
    def aggregate_write_results(write_results: List["WriteResult"]) -> "WriteResult":
        """Aggregate a list of write results.

        Args:
            write_results: A list of write results.

        Returns:
            A single write result whose ``num_rows`` and ``size_bytes`` are
            the sums of the corresponding fields over all inputs. An empty
            input produces a result with both fields equal to 0.
        """
        # Materialize once so both sums see every element even if the caller
        # hands in a one-shot iterable instead of a list.
        results = list(write_results)
        return WriteResult(
            num_rows=sum(result.num_rows for result in results),
            size_bytes=sum(result.size_bytes for result in results),
        )
[docs]
@DeveloperAPI
class Datasink:
    """Interface for defining write-related logic.

    To write data to a destination that isn't built-in, subclass this class
    and pass an instance to :meth:`~ray.data.Dataset.write_datasink`.
    """

    def on_write_start(self) -> None:
        """Callback for when a write job starts.

        Override this to perform setup that the write tasks rely on, for
        example creating a staging bucket in S3. The base implementation
        does nothing.
        """

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Write blocks. This is used by a single write task.

        Subclasses must override this method; the base implementation always
        raises.

        Args:
            blocks: Generator of data blocks.
            ctx: ``TaskContext`` for the write task.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
        """Callback for when a write job completes.

        This can be used to "commit" a write output. This method must
        succeed prior to ``write_datasink()`` returning to the user. If this
        method fails, then ``on_write_failed()`` is called.

        Args:
            write_result_blocks: The blocks resulting from executing
                the Write operator, containing write results and stats.

        Returns:
            A ``WriteResult`` object containing the aggregated stats of all
            the input write results.
        """
        # Pull the first entry (``.iloc[0]``) stored under the "write_result"
        # key of every block.
        per_block_results = []
        for block in write_result_blocks:
            per_block_results.append(block["write_result"].iloc[0])

        summary = WriteResult.aggregate_write_results(per_block_results)

        # One "\t- <field>: <value>" line per dataclass field of the summary.
        stats_text = "".join(
            f"\t- {stat.name}: {getattr(summary, stat.name)}\n"
            for stat in fields(summary.__class__)
        )
        logger.info(
            f"Write operation succeeded. Aggregated write results:\n" f"{stats_text}"
        )
        return summary

    def on_write_failed(self, error: Exception) -> None:
        """Callback for when a write job fails.

        This is called on a best-effort basis on write failures. The base
        implementation does nothing.

        Args:
            error: The first error encountered.
        """

    def get_name(self) -> str:
        """Return a human-readable name for this datasink.

        The name is the class name with one leading underscore (if any) and a
        trailing ``Datasink`` suffix (if any) removed. This is used as the
        names of the write tasks.
        """
        suffix = "Datasink"
        label = type(self).__name__
        # Drop a single leading underscore used to mark private classes.
        label = label[1:] if label.startswith("_") else label
        # Drop the conventional class-name suffix.
        label = label[: -len(suffix)] if label.endswith(suffix) else label
        return label

    @property
    def supports_distributed_writes(self) -> bool:
        """If ``False``, only launch write tasks on the driver's node."""
        return True

    @property
    def num_rows_per_write(self) -> Optional[int]:
        """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call.

        If ``None``, Ray Data passes a system-chosen number of rows.
        """
        return None
@DeveloperAPI
class DummyOutputDatasink(Datasink):
    """An example implementation of a writable datasource for testing.

    Blocks are forwarded to a Ray actor that only counts the rows it
    receives. ``num_ok`` and ``num_failed`` count how many times the
    completion and failure callbacks have been invoked.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import DummyOutputDatasink
        >>> output = DummyOutputDatasink()
        >>> ray.data.range(10).write_datasink(output)
        >>> assert output.num_ok == 1
    """

    def __init__(self):
        data_context = ray.data.DataContext.get_current()

        # Setup a dummy actor to send the data. In a real datasource, write
        # tasks would send data to an external system instead of a Ray actor.
        @ray.remote(scheduling_strategy=data_context.scheduling_strategy)
        class DataSink:
            def __init__(self):
                self.rows_written = 0
                self.enabled = True

            def write(self, block: Block) -> None:
                # Only the row count is recorded; the data itself is dropped.
                accessor = BlockAccessor.for_block(block)
                self.rows_written += accessor.num_rows()

            def get_rows_written(self):
                return self.rows_written

        self.data_sink = DataSink.remote()
        self.num_ok = 0
        self.num_failed = 0
        # Setting this to False makes ``write`` raise ``ValueError``.
        self.enabled = True

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> None:
        """Send every block to the counting actor and wait for all sends.

        Args:
            blocks: The data blocks to write.
            ctx: ``TaskContext`` for the write task (unused here).

        Raises:
            ValueError: If ``self.enabled`` is falsy.
        """
        if not self.enabled:
            raise ValueError("disabled")
        pending = [self.data_sink.write.remote(block) for block in blocks]
        # Block until the actor has processed every submitted block.
        ray.get(pending)

    def on_write_complete(self, write_result_blocks: List[Block]) -> WriteResult:
        """Count the successful write job, then defer to the base class.

        Args:
            write_result_blocks: The blocks holding the per-task write results.

        Returns:
            The aggregated ``WriteResult`` produced by the base implementation.
        """
        # Incremented before delegating, so it is bumped even if aggregation
        # in the base class raises.
        self.num_ok += 1
        return super().on_write_complete(write_result_blocks)

    def on_write_failed(self, error: Exception) -> None:
        """Count the failed write job.

        Args:
            error: The error that caused the failure (not inspected).
        """
        self.num_failed += 1