ray.data.DatasetPipeline.split_at_indices

DatasetPipeline.split_at_indices(indices: List[int]) -> List[ray.data.dataset_pipeline.DatasetPipeline[ray.data.block.T]]

Split the datasets within the pipeline at the given indices (like np.split).

This splits each dataset contained within this pipeline, producing len(indices) + 1 pipelines. The first pipeline contains the [0, indices[0]) slice from each dataset, the second pipeline contains the [indices[0], indices[1]) slice from each dataset, and so on; the final pipeline contains the [indices[-1], self.count()) slice from each dataset.

Examples

>>> import ray
>>> p1, p2, p3 = ray.data.range(8).repeat(2).split_at_indices([2, 5])
>>> p1.take() 
[0, 1, 0, 1]
>>> p2.take() 
[2, 3, 4, 2, 3, 4]
>>> p3.take() 
[5, 6, 7, 5, 6, 7]
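
The per-dataset slicing described above can be modeled in plain Python, without Ray. This is only a sketch of the semantics, not how Ray implements the method: split_windows_at_indices is a hypothetical helper, each dataset in the pipeline is represented as an ordinary sequence, and the rejection of negative indices is an assumption of the sketch.

```python
from typing import List, Sequence


def split_windows_at_indices(
    windows: Sequence[Sequence[int]], indices: List[int]
) -> List[List[int]]:
    """Plain-Python model of the slicing semantics (hypothetical helper, not Ray code)."""
    # Assumption of this sketch: reject negative indices, which would
    # otherwise be interpreted as offsets from the end by Python slicing.
    if any(i < 0 for i in indices):
        raise ValueError("indices must be non-negative")
    # The documentation requires the indices to be sorted.
    if list(indices) != sorted(indices):
        raise ValueError("indices must be sorted")

    # Slice boundaries: [0, i0), [i0, i1), ..., [i_last, end).
    starts = [0] + list(indices)
    stops = list(indices) + [None]

    # One output per slice, so len(indices) + 1 outputs in total.
    outputs: List[List[int]] = [[] for _ in starts]
    for window in windows:
        rows = list(window)
        for out, start, stop in zip(outputs, starts, stops):
            # Python slicing clamps out-of-range bounds, so an index past the
            # end of a window contributes no rows to the later outputs.
            out.extend(rows[start:stop])
    return outputs


# Two windows of range(8), mirroring ray.data.range(8).repeat(2).
windows = [range(8), range(8)]
p1, p2, p3 = split_windows_at_indices(windows, [2, 5])
print(p1)
print(p2)
print(p3)
```

With these inputs the three lists match the take() results shown above. In this model, an index beyond the end of a window yields an empty output: split_windows_at_indices([range(3)], [1, 10]) returns [[0], [1, 2], []].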

Time complexity: O(num splits)

See also: DatasetPipeline.split

Parameters

indices – List of sorted integers which indicate where the pipeline will be split. If an index exceeds the length of the pipeline, an empty pipeline will be returned.

Returns

The pipeline splits.