ray.data.Dataset.filter

Dataset.filter(fn: Callable[[Dict[str, Any]], bool] | Callable[[Dict[str, Any]], Iterator[bool]] | _CallableClassProtocol, *, compute: str | ComputeStrategy = None, concurrency: int | Tuple[int, int] | None = None, **ray_remote_args) → Dataset

Filter out rows that don’t satisfy the given predicate.

You can pass either a function or a callable class as the predicate. For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses stateful Ray actors. For more information, see Stateful Transforms.

Tip

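If you can express your predicate with NumPy or pandas operations, Dataset.map_batches() might be faster, because it evaluates the predicate on whole batches instead of one row at a time. To implement a filter this way, drop the rows that fail the predicate from each batch.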
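A minimal sketch of the batch-based approach, assuming the NumPy batch format (a dict mapping column names to NumPy arrays) and the id column produced by ray.data.range(). The function name keep_even_rows is hypothetical. The block runs the function on a hand-built batch so it works without a Ray cluster; with Ray Data you would pass it as ds.map_batches(keep_even_rows, batch_format="numpy").

```python
import numpy as np

def keep_even_rows(batch: dict) -> dict:
    # Build one boolean mask for the whole batch instead of testing row by row.
    mask = batch["id"] % 2 == 0
    # Apply the same mask to every column so the rows stay aligned.
    return {column: values[mask] for column, values in batch.items()}

# Simulate a single batch locally.
batch = {"id": np.arange(10)}
filtered = keep_even_rows(batch)
print(filtered["id"])  # [0 2 4 6 8]
```

Masking every column with the same array is what keeps multi-column rows intact after filtering.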

Examples

>>> import ray
>>> ds = ray.data.range(100)
>>> ds.filter(lambda row: row["id"] % 2 == 0).take_all()
[{'id': 0}, {'id': 2}, {'id': 4}, ...]

Time complexity: O(dataset size / parallelism)

Parameters:
  • fn – The predicate to apply to each row, or a class type that can be instantiated to create such a callable.

  • compute – This argument is deprecated. Use the concurrency argument instead.

  • concurrency – The number of Ray workers to use concurrently. For a fixed-sized worker pool of size n, specify concurrency=n. For an autoscaling worker pool from m to n workers, specify concurrency=(m, n).

  • ray_remote_args – Additional resource requirements to request from Ray (for example, num_gpus=1 to request GPUs for the map tasks).