Source code for ray.data.random_access_dataset

import bisect
import logging
import random
import time
from collections import defaultdict
import numpy as np
from typing import List, Any, Optional, TYPE_CHECKING

import ray
from ray.types import ObjectRef
from ray.data.block import BlockAccessor
from ray.data.context import DataContext, DEFAULT_SCHEDULING_STRATEGY
from ray.data._internal.remote_fn import cached_remote_fn
from ray.util.annotations import PublicAPI

try:
    import pyarrow as pa
except ImportError:
    pa = None

if TYPE_CHECKING:
    from ray.data import Dataset

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class RandomAccessDataset:
    """A class that provides distributed, random access to a Dataset.

    See: ``Dataset.to_random_access_dataset()``.
    """

    def __init__(
        self,
        ds: "Dataset",
        key: str,
        num_workers: int,
    ):
        """Construct a RandomAccessDataset (internal API).

        The constructor is a private API. Use ``ds.to_random_access_dataset()``
        to construct a RandomAccessDataset.

        Args:
            ds: The dataset to index. It is sorted by ``key`` during setup.
            key: Name of the field used as the sort/lookup key.
            num_workers: Number of ``_RandomAccessWorker`` actors to create.

        Raises:
            ValueError: If the dataset schema is missing or is a plain Python
                type (reported as "only supports Arrow-format blocks").
        """
        schema = ds.schema(fetch_if_missing=True)
        if schema is None or isinstance(schema, type):
            raise ValueError("RandomAccessDataset only supports Arrow-format blocks.")

        start = time.perf_counter()
        logger.info("[setup] Indexing dataset by sort key.")
        sorted_ds = ds.sort(key)
        get_bounds = cached_remote_fn(_get_bounds)
        blocks = sorted_ds.get_internal_block_refs()

        logger.info("[setup] Computing block range bounds.")
        bounds = ray.get([get_bounds.remote(b, key) for b in blocks])

        # Keep only non-empty blocks. `_upper_bounds[i]` is the largest key of
        # `_non_empty_blocks[i]`; `_lower_bound` is the smallest key of the
        # first non-empty block (and thus of the whole sorted dataset).
        self._non_empty_blocks = []
        self._lower_bound = None
        self._upper_bounds = []
        for block_ref, block_bounds in zip(blocks, bounds):
            if not block_bounds:
                # `_get_bounds` returns None for an empty block.
                continue
            self._non_empty_blocks.append(block_ref)
            if self._lower_bound is None:
                self._lower_bound = block_bounds[0]
            self._upper_bounds.append(block_bounds[1])

        logger.info("[setup] Creating %s random access workers.", num_workers)
        ctx = DataContext.get_current()
        # Honor a user-configured scheduling strategy; otherwise spread the
        # workers out.
        if ctx.scheduling_strategy != DEFAULT_SCHEDULING_STRATEGY:
            scheduling_strategy = ctx.scheduling_strategy
        else:
            scheduling_strategy = "SPREAD"
        self._workers = [
            _RandomAccessWorker.options(scheduling_strategy=scheduling_strategy).remote(
                key
            )
            for _ in range(num_workers)
        ]
        (
            self._block_to_workers_map,
            self._worker_to_blocks_map,
        ) = self._compute_block_to_worker_assignments()

        logger.info(
            "[setup] Worker to blocks assignment: %s", self._worker_to_blocks_map
        )
        # Block until every worker has fetched the blocks assigned to it.
        ray.get(
            [
                w.assign_blocks.remote(
                    {
                        i: self._non_empty_blocks[i]
                        for i in self._worker_to_blocks_map[w]
                    }
                )
                for w in self._workers
            ]
        )

        logger.info("[setup] Finished assigning blocks to workers.")
        self._build_time = time.perf_counter() - start

    def _compute_block_to_worker_assignments(self):
        """Decide which workers serve which non-empty blocks.

        Returns:
            A ``(block_to_workers, worker_to_blocks)`` tuple of defaultdicts:
            the first maps a block index to the list of worker handles that
            hold it, the second maps a worker handle to its block indices.
        """
        # Return values.
        block_to_workers: dict[int, List["ray.ActorHandle"]] = defaultdict(list)
        worker_to_blocks: dict["ray.ActorHandle", List[int]] = defaultdict(list)

        # Aux data structures.
        loc_to_workers: dict[str, List["ray.ActorHandle"]] = defaultdict(list)
        # `ping` returns the node id each worker is running on.
        locs = ray.get([w.ping.remote() for w in self._workers])
        for worker, loc in zip(self._workers, locs):
            loc_to_workers[loc].append(worker)
        block_locs = ray.experimental.get_object_locations(self._non_empty_blocks)

        # First, try to assign all blocks to all workers at its location.
        for block_idx, block in enumerate(self._non_empty_blocks):
            block_info = block_locs[block]
            for loc in block_info.get("node_ids", []):
                for worker in loc_to_workers[loc]:
                    block_to_workers[block_idx].append(worker)
                    worker_to_blocks[worker].append(block_idx)

        # Randomly assign any leftover blocks to at least one worker.
        # TODO: the load balancing here could be improved.
        for block_idx in range(len(self._non_empty_blocks)):
            if not block_to_workers[block_idx]:
                worker = random.choice(self._workers)
                block_to_workers[block_idx].append(worker)
                worker_to_blocks[worker].append(block_idx)

        return block_to_workers, worker_to_blocks

    def get_async(self, key: Any) -> ObjectRef[Any]:
        """Asynchronously finds the record for a single key.

        Args:
            key: The key of the record to find.

        Returns:
            ObjectRef containing the record (in pydict form), or None if not found.
        """
        block_index = self._find_le(key)
        if block_index is None:
            # The key is outside the indexed range; no worker needs to be asked.
            return ray.put(None)
        return self._worker_for(block_index).get.remote(block_index, key)

    def multiget(self, keys: List[Any]) -> List[Optional[Any]]:
        """Synchronously find the records for a list of keys.

        Args:
            keys: List of keys to find the records for.

        Returns:
            List of found records (in pydict form), or None for missing records.
        """
        # Group keys by the block that may contain them so that each block is
        # queried with a single remote call.
        batches = defaultdict(list)
        for k in keys:
            batches[self._find_le(k)].append(k)

        futures = {}
        for index, keybatch in batches.items():
            if index is None:
                # Keys outside the indexed range are reported as missing.
                continue
            futures[index] = self._worker_for(index).multiget.remote(
                [index] * len(keybatch), keybatch
            )

        results = {}
        for index, fut in futures.items():
            values = ray.get(fut)
            for k, v in zip(batches[index], values):
                results[k] = v
        # Keys that were never looked up fall back to None via `dict.get`.
        return [results.get(k) for k in keys]

    def stats(self) -> str:
        """Returns a string containing access timing information."""
        stats = ray.get([w.stats.remote() for w in self._workers])
        total_time = sum(s["total_time"] for s in stats)
        accesses = [s["num_accesses"] for s in stats]
        blocks = [s["num_blocks"] for s in stats]
        total_accesses = sum(accesses)

        msg = "RandomAccessDataset:\n"
        msg += "- Build time: {}s\n".format(round(self._build_time, 2))
        msg += "- Num workers: {}\n".format(len(stats))
        msg += "- Blocks per worker: {} min, {} max, {} mean\n".format(
            min(blocks), max(blocks), int(sum(blocks) / len(blocks))
        )
        msg += "- Accesses per worker: {} min, {} max, {} mean\n".format(
            min(accesses), max(accesses), int(sum(accesses) / len(accesses))
        )
        # Divide by the real access count; `max(1, ...)` only guards against
        # division by zero when nothing has been accessed yet.
        msg += "- Mean access time: {}us\n".format(
            int(total_time / max(1, total_accesses) * 1e6)
        )
        return msg

    def _worker_for(self, block_index: int):
        """Pick, at random, one of the workers holding the given block."""
        return random.choice(self._block_to_workers_map[block_index])

    def _find_le(self, x: Any) -> Optional[int]:
        """Return the index of the block whose key range may contain ``x``.

        This is the first block whose upper bound is >= ``x``. Returns None
        when ``x`` is above every upper bound (including when there are no
        blocks) or below the dataset's lower bound.
        """
        i = bisect.bisect_left(self._upper_bounds, x)
        if i >= len(self._upper_bounds) or x < self._lower_bound:
            return None
        return i
@ray.remote(num_cpus=0)
class _RandomAccessWorker:
    """Actor that holds a subset of the sorted blocks and serves key lookups."""

    def __init__(self, key_field):
        # Mapping of block index -> block; populated by `assign_blocks`.
        self.blocks = None
        self.key_field = key_field
        # Counters reported by `stats`.
        self.num_accesses = 0
        self.total_time = 0

    def assign_blocks(self, block_ref_dict):
        """Fetch and store the blocks referenced by ``{index: block_ref}``."""
        self.blocks = {k: ray.get(ref) for k, ref in block_ref_dict.items()}

    def get(self, block_index, key):
        """Return the row for ``key`` in the given block, or None if absent."""
        start = time.perf_counter()
        result = self._get(block_index, key)
        self.total_time += time.perf_counter() - start
        self.num_accesses += 1
        return result

    def multiget(self, block_indices, keys):
        """Return one row (or None if absent) per ``(block_index, key)`` pair.

        The whole call is counted as a single access in the statistics.
        """
        start = time.perf_counter()
        block = self.blocks[block_indices[0]]
        if len(set(block_indices)) == 1 and isinstance(block, pa.Table):
            # Fast path: use np.searchsorted for vectorized search on a single
            # block.
            col = block[self.key_field]
            indices = np.searchsorted(col, keys)
            acc = BlockAccessor.for_block(block)
            num_rows = len(col)
            result = []
            for pos, key in zip(indices, keys):
                pos = int(pos)
                # searchsorted yields an insertion point, not a match: only
                # return the row when it actually carries the requested key,
                # so that missing keys yield None just like `_get` does.
                if pos < num_rows and col[pos].as_py() == key:
                    result.append(acc._get_row(pos))
                else:
                    result.append(None)
        else:
            result = [self._get(i, k) for i, k in zip(block_indices, keys)]
        self.total_time += time.perf_counter() - start
        self.num_accesses += 1
        return result

    def ping(self):
        """Return the id of the node this worker is running on."""
        return ray.get_runtime_context().get_node_id()

    def stats(self) -> dict:
        """Return the block count and access counters of this worker."""
        return {
            "num_blocks": len(self.blocks),
            "num_accesses": self.num_accesses,
            "total_time": self.total_time,
        }

    def _get(self, block_index, key):
        """Binary-search one block for ``key``; return its row or None."""
        if block_index is None:
            return None
        block = self.blocks[block_index]
        column = block[self.key_field]
        if isinstance(block, pa.Table):
            # Arrow columns yield Arrow scalars; wrap them so that bisect
            # compares plain Python values.
            column = _ArrowListWrapper(column)
        i = _binary_search_find(column, key)
        if i is None:
            return None
        acc = BlockAccessor.for_block(block)
        return acc._get_row(i)


def _binary_search_find(column, x):
    """Return the index of the first element equal to ``x`` in the sorted
    ``column``, or None if ``x`` is not present."""
    i = bisect.bisect_left(column, x)
    if i != len(column) and column[i] == x:
        return i
    return None


class _ArrowListWrapper:
    """Sequence view over an Arrow column whose items are Python values."""

    def __init__(self, arrow_col):
        self.arrow_col = arrow_col

    def __getitem__(self, i):
        return self.arrow_col[i].as_py()

    def __len__(self):
        return len(self.arrow_col)


def _get_bounds(block, key):
    """Return ``(first, last)`` values of column ``key`` in ``block``.

    Returns None for an empty block. For Arrow tables the values are
    converted to Python objects.
    """
    if len(block) == 0:
        return None
    b = (block[key][0], block[key][len(block) - 1])
    if isinstance(block, pa.Table):
        b = (b[0].as_py(), b[1].as_py())
    return b