ray.data.Dataset.flat_map

Dataset.flat_map(fn: Union[Callable[[Dict[str, Any]], List[Dict[str, Any]]], Callable[[Dict[str, Any]], Iterator[List[Dict[str, Any]]]], _CallableClassProtocol], *, compute: Optional[ray.data._internal.compute.ComputeStrategy] = None, fn_constructor_args: Optional[Iterable[Any]] = None, num_cpus: Optional[float] = None, num_gpus: Optional[float] = None, **ray_remote_args) -> Dataset

Apply the given function to each row and then flatten results.

Use this method if your transformation returns multiple rows for each input row.
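The apply-then-flatten behavior can be sketched in plain Python, without Ray. The helper flat_map_rows and the function explode_tags below are hypothetical names used only for illustration; they model the semantics on an in-memory list and say nothing about how Ray executes the operation.

```python
from itertools import chain
from typing import Any, Callable, Dict, Iterable, List

Row = Dict[str, Any]


def flat_map_rows(rows: Iterable[Row], fn: Callable[[Row], Iterable[Row]]) -> List[Row]:
    """Hypothetical helper: apply fn to each row, then flatten one level."""
    return list(chain.from_iterable(fn(row) for row in rows))


def explode_tags(row: Row) -> List[Row]:
    # One output row per tag; a row with no tags contributes no output rows.
    return [{"id": row["id"], "tag": tag} for tag in row["tags"]]


rows = [
    {"id": 0, "tags": ["a", "b"]},
    {"id": 1, "tags": []},
    {"id": 2, "tags": ["c"]},
]
exploded = flat_map_rows(rows, explode_tags)
print(exploded)
# [{'id': 0, 'tag': 'a'}, {'id': 0, 'tag': 'b'}, {'id': 2, 'tag': 'c'}]
```

Note that the output can be longer or shorter than the input: the row with an empty tag list disappears entirely.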

Tip

map_batches() can also modify the number of rows. If your transformation is vectorized, like most NumPy and pandas operations, map_batches() might be faster than flat_map().
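As a rough illustration of the vectorized alternative, the sketch below duplicates rows one batch at a time. It assumes batches arrive as a dict mapping column names to NumPy arrays; the names duplicate_rows_in_batch and run_on_ray are hypothetical, and the Ray call is wrapped in a function that is not invoked here so the batch logic can be checked with NumPy alone.

```python
from typing import Dict

import numpy as np


def duplicate_rows_in_batch(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # Repeat every element of every column twice along the row axis,
    # so the columns stay aligned with each other.
    return {name: np.repeat(column, 2, axis=0) for name, column in batch.items()}


def run_on_ray():
    # Not called in this sketch; requires a Ray installation.
    # Assumes map_batches() passes dict-of-ndarray batches to the function.
    import ray

    return ray.data.range(3).map_batches(duplicate_rows_in_batch).take_all()


# Check the batch function locally on a small in-memory batch.
sample = {"id": np.arange(3)}
print(duplicate_rows_in_batch(sample)["id"].tolist())
# [0, 0, 1, 1, 2, 2]
```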

Examples

from typing import Any, Dict, List
import ray

def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    return [row] * 2

print(
    ray.data.range(3)
    .flat_map(duplicate_row)
    .take_all()
)
[{'id': 0}, {'id': 0}, {'id': 1}, {'id': 1}, {'id': 2}, {'id': 2}]

Time complexity: O(dataset size / parallelism)

Parameters
  • fn – The function or generator to apply to each record, or a class type that can be instantiated to create such a callable. Callable classes are only supported for the actor compute strategy.

  • compute – The compute strategy, either “tasks” (default) to use Ray tasks, ray.data.ActorPoolStrategy(size=n) to use a fixed-size actor pool, or ray.data.ActorPoolStrategy(min_size=m, max_size=n) for an autoscaling actor pool.

  • fn_constructor_args – Positional arguments to pass to fn’s constructor. You can only provide this if fn is a callable class. These arguments are top-level arguments in the underlying Ray actor construction task.

  • num_cpus – The number of CPUs to reserve for each parallel map worker.

  • num_gpus – The number of GPUs to reserve for each parallel map worker. For example, specify num_gpus=1 to request 1 GPU for each parallel map worker.

  • ray_remote_args – Additional resource requirements to request from Ray for each map worker.
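The parameters above can be combined as in the following sketch, which passes a callable class together with an actor pool and a constructor argument. RepeatRow and run_on_ray are hypothetical names; the call follows the signature documented on this page, and other Ray releases may expect different arguments. The Ray call is wrapped in a function that is not invoked here, so the class can be exercised without a cluster.

```python
from typing import Any, Dict, List


class RepeatRow:
    """Hypothetical callable class: emits each input row `times` times."""

    def __init__(self, times: int):
        # Receives the values passed through fn_constructor_args.
        self.times = times

    def __call__(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Copy the row so the output rows do not share one dict object.
        return [dict(row) for _ in range(self.times)]


def run_on_ray():
    # Not called in this sketch; requires a Ray installation.
    import ray

    return (
        ray.data.range(3)
        .flat_map(
            RepeatRow,                                   # class, not instance
            compute=ray.data.ActorPoolStrategy(size=2),  # fixed-size actor pool
            fn_constructor_args=(3,),                    # becomes times=3
            num_cpus=1,                                  # CPUs per map worker
        )
        .take_all()
    )


# Exercise the callable locally, outside of Ray.
repeat = RepeatRow(times=2)
print(repeat({"id": 7}))
# [{'id': 7}, {'id': 7}]
```

A callable class is a natural fit when the constructor does expensive setup, because each actor in the pool constructs the class once and reuses the instance across rows.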

See also

map_batches()

Call this method to transform batches of data.

map()

Call this method to transform one row at a time.