ray.data.Dataset.split_at_indices

Dataset.split_at_indices(indices: List[int]) -> List[ray.data.dataset.Dataset[ray.data.block.T]]

Split the dataset at the given indices (like np.split).

Examples

>>> import ray
>>> ds = ray.data.range(10)
>>> d1, d2, d3 = ds.split_at_indices([2, 5])
>>> d1.take()
[0, 1]
>>> d2.take()
[2, 3, 4]
>>> d3.take()
[5, 6, 7, 8, 9]

Time complexity: O(num splits)

See also: Dataset.split, Dataset.split_proportionately

Parameters

indices – Sorted list of integers indicating the positions at which to split the dataset. If an index exceeds the length of the dataset, an empty dataset is returned for the split beyond it.

Returns

The dataset splits, one more than the number of indices.
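
The index semantics described above (like np.split) can be illustrated on an ordinary Python list. This is a minimal sketch of the behavior only, not Ray's implementation: `split_at_indices_sketch` is a hypothetical helper, and its validation of sorted, non-negative indices is an assumption based on the parameter description.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_at_indices_sketch(rows: Sequence[T], indices: List[int]) -> List[List[T]]:
    """Hypothetical helper: mimic the split semantics on an in-memory sequence."""
    # Assumption: indices must be sorted and non-negative.
    if any(i < 0 for i in indices):
        raise ValueError("indices must be non-negative")
    if list(indices) != sorted(indices):
        raise ValueError("indices must be sorted")

    # n indices define n + 1 half-open ranges: [0, i1), [i1, i2), ..., [in, len).
    bounds = [0, *indices, len(rows)]
    # Slicing past the end of the sequence yields an empty list, which mirrors
    # the "empty dataset" behavior for out-of-range indices.
    return [list(rows[start:end]) for start, end in zip(bounds, bounds[1:])]


rows = list(range(10))
print(split_at_indices_sketch(rows, [2, 5]))
# [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]
print(split_at_indices_sketch(rows, [2, 15]))
# [[0, 1], [2, 3, 4, 5, 6, 7, 8, 9], []]
```

The second call shows the out-of-range case: index 15 is past the end of the ten rows, so the final split is empty while the earlier splits keep all the data.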