ray.data.Dataset.streaming_train_test_split

Dataset.streaming_train_test_split(test_size: float, *, split_type: Literal['hash', 'random'] = 'random', hash_column: str | None = None, seed: int | None = None, **ray_remote_kwargs) → Tuple[Dataset, Dataset]

Split the dataset into train and test subsets in a streaming manner, without materializing it first. This method is recommended for large datasets.
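
One way to picture a streaming split is the pure-Python sketch below. It assumes that each returned subset behaves like a lazy, filtered view over the same source rather than a materialized copy. The names `lazy_split` and `is_test` are hypothetical and are not part of Ray's API; the point is only that neither subset is built until it is iterated, and each subset re-reads the source when consumed.

```python
from typing import Callable, Iterable, Iterator, Tuple


def lazy_split(
    source: Callable[[], Iterable[int]],
    is_test: Callable[[int], bool],
) -> Tuple[Callable[[], Iterator[int]], Callable[[], Iterator[int]]]:
    """Hypothetical sketch: return train/test views that re-stream `source`."""

    def train() -> Iterator[int]:
        # Re-read the source and keep rows the rule does NOT send to test.
        return (row for row in source() if not is_test(row))

    def test() -> Iterator[int]:
        # Re-read the source and keep rows the rule sends to test.
        return (row for row in source() if is_test(row))

    return train, test


# Hypothetical deterministic rule: ids with remainder 3 (mod 4) go to test.
train, test = lazy_split(lambda: range(8), lambda row: row % 4 == 3)
print(list(train()))  # [0, 1, 2, 4, 5, 6]
print(list(test()))   # [3, 7]
```

Because each view applies the rule on its own pass over the source, the rule has to give the same answer for a given row every time; otherwise a row could show up in both subsets or in neither.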

The split type can be either “hash” or “random”:

  • “random”: The dataset is split into random train and test subsets.

  • “hash”: The dataset is split into train and test subsets based on the hash of each row's value in hash_column.
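
The difference between the two split types can be sketched as two per-row decision rules. This is an illustration under stated assumptions, not Ray's implementation: the hash rule below uses `zlib.crc32` to map a key to a bucket in [0, 1), and the random rule draws one number per row from a seeded `random.Random`. Ray's actual hashing and random number generation may differ, so this sketch will not reproduce the outputs of the examples on this page. The helper names `hash_is_test` and `split_rows` are hypothetical.

```python
import random
import zlib
from typing import Dict, List, Optional, Tuple

Row = Dict[str, int]


def hash_is_test(key: object, test_size: float) -> bool:
    # crc32 is stable across processes, unlike Python's salted hash() for str.
    bucket = zlib.crc32(str(key).encode("utf-8")) / 2**32  # value in [0, 1)
    return bucket < test_size


def split_rows(
    rows: List[Row],
    test_size: float,
    split_type: str = "random",
    hash_column: Optional[str] = None,
    seed: Optional[int] = None,
) -> Tuple[List[Row], List[Row]]:
    """Hypothetical sketch of the two split rules; not Ray's implementation."""
    rng = random.Random(seed)
    train: List[Row] = []
    test: List[Row] = []
    for row in rows:
        if split_type == "hash":
            # Depends only on the key, so equal keys always land in the same subset.
            to_test = hash_is_test(row[hash_column], test_size)
        else:
            # Depends on the seed and on how many draws preceded this row.
            to_test = rng.random() < test_size
        (test if to_test else train).append(row)
    return train, test


rows = [{"id": i} for i in range(8)]
train, test = split_rows(rows, test_size=0.25, split_type="hash", hash_column="id")
print([r["id"] for r in train], [r["id"] for r in test])
```

In this sketch the subset sizes only approximate `test_size` on small inputs, and with the hash rule all rows that share a key value end up in the same subset.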

Tip

Make sure to set the preserve_order flag in the ExecutionOptions to True so that the split is deterministic across pipeline executions. Otherwise, rows that land in the test set in one execution can end up in the train set in another, and vice versa. You can set it with ray.data.DataContext.get_current().execution_options.preserve_order = True.
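
To see why row order can matter, the sketch below models a random split, as an assumption, as one seeded draw per row taken in the order rows arrive, and a hash split as a function of the key alone. Reversing the arrival order stands in for a pipeline execution that emits blocks in a different order. The helpers `random_test_ids` and `hash_test_ids` are hypothetical and do not reflect Ray's internals.

```python
import random
import zlib
from typing import List, Set


def random_test_ids(ids: List[int], test_size: float, seed: int) -> Set[int]:
    # Assumed model: the i-th arriving row gets the i-th draw from a seeded RNG.
    rng = random.Random(seed)
    return {i for i in ids if rng.random() < test_size}


def hash_test_ids(ids: List[int], test_size: float) -> Set[int]:
    # Assignment depends only on the key, never on arrival order.
    return {i for i in ids if zlib.crc32(str(i).encode("utf-8")) / 2**32 < test_size}


ids = list(range(100))
reordered = ids[::-1]  # same rows, different arrival order

# Same order and same seed: the random split is reproducible.
assert random_test_ids(ids, 0.25, seed=0) == random_test_ids(ids, 0.25, seed=0)

# Different order, same seed: a row may switch subsets between runs.
moved = random_test_ids(ids, 0.25, seed=0) ^ random_test_ids(reordered, 0.25, seed=0)
print("rows assigned differently after reordering:", len(moved))

# The hash split is unaffected by arrival order.
assert hash_test_ids(ids, 0.25) == hash_test_ids(reordered, 0.25)
```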

Examples

Random split example:

>>> import ray
>>> ctx = ray.data.DataContext.get_current()
>>> ctx.execution_options.preserve_order = True
>>> ds = ray.data.range(8)
>>> train, test = ds.streaming_train_test_split(test_size=0.25, seed=0)
>>> train.count()
6
>>> test.count()
2
>>> ctx.execution_options.preserve_order = False

Hash split example:

>>> import ray
>>> ds = ray.data.range(8)
>>> train, test = ds.streaming_train_test_split(test_size=0.25, split_type="hash", hash_column="id")
>>> train.take_batch()
{'id': array([0, 2, 3, 4, 5, 6])}
>>> test.take_batch()
{'id': array([1, 7])}

Parameters:
  • test_size – The proportion of the dataset to include in the test split. Must be between 0.0 and 1.0.

  • split_type – The type of split to perform. Can be “hash” or “random”.

  • hash_column – The column to use for the hash split. Required for hash split and ignored for random split.

  • seed – The seed to use for the random split. Ignored for hash split.

  • **ray_remote_kwargs – Additional kwargs to pass to the Ray remote function.
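
The parameter descriptions above imply a few argument checks. The hypothetical `check_split_args` helper below encodes them as a sketch; it is not Ray's code, and Ray's actual error types and messages may differ. Whether the endpoints 0.0 and 1.0 are accepted is not stated here, so rejecting them is an assumption.

```python
from typing import Optional


def check_split_args(
    test_size: float,
    split_type: str = "random",
    hash_column: Optional[str] = None,
) -> None:
    """Hypothetical validation mirroring the documented parameter rules."""
    # Assumption: the endpoints are rejected because they leave one subset empty.
    if not 0.0 < test_size < 1.0:
        raise ValueError(f"test_size must be between 0.0 and 1.0, got {test_size}")
    if split_type not in ("hash", "random"):
        raise ValueError(f"split_type must be 'hash' or 'random', got {split_type!r}")
    # hash_column is required for a hash split; a random split ignores it.
    if split_type == "hash" and hash_column is None:
        raise ValueError("hash_column is required when split_type='hash'")


check_split_args(0.25)                                        # valid random split
check_split_args(0.25, split_type="hash", hash_column="id")  # valid hash split
```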

Returns:

The train and test subsets, as a tuple of two Dataset objects.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.