Dataset.split(n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None) -> List[ray.data.dataset.Dataset[ray.data.block.T]]

Split the dataset into n disjoint pieces.

This returns a list of sub-datasets that can be passed to Ray tasks and actors to read the dataset's records in parallel.


>>> import ray
>>> ds = ray.data.range(100) 
>>> workers = ... 
>>> # Split up a dataset to process over `n` worker actors.
>>> shards = ds.split(len(workers), locality_hints=workers) 
>>> for shard, worker in zip(shards, workers): 
...     worker.consume.remote(shard) 

Time complexity: O(1)

See also: Dataset.split_at_indices, Dataset.split_proportionately

Parameters

  • n – Number of child datasets to return.

  • equal – Whether to guarantee each split has an equal number of records. This may drop records if they cannot be divided equally among the splits.

  • locality_hints – [Experimental] A list of Ray actor handles of size n. The system will try to co-locate the blocks of the i-th dataset with the i-th actor to maximize data locality.


Returns

A list of n disjoint dataset splits.
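As a rough illustration of the equal=True semantics, the sketch below (plain Python, not Ray's actual implementation; the helper name equal_split_sizes is hypothetical) shows how records that cannot be divided evenly across the splits end up dropped:

```python
def equal_split_sizes(num_records: int, n: int) -> list:
    """Per-split record counts when all splits must be equal in size.

    Mirrors the documented behavior of Dataset.split(n, equal=True):
    any remainder records that cannot be divided evenly are dropped.
    (Illustrative sketch only, not Ray's implementation.)
    """
    per_split = num_records // n  # equal share for every split
    return [per_split] * n        # num_records % n records are dropped

# 100 records split into 3 equal shards: each shard gets 33 records,
# and the 1 leftover record is dropped.
sizes = equal_split_sizes(100, 3)
print(sizes, 100 - sum(sizes))
```

With equal=False (the default), Ray instead distributes all records, so split sizes may differ by a few records but nothing is dropped.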