Custom Datasources

Ray Datasets supports multiple ways to create a dataset, so you can easily ingest data in common formats from popular sources. If the datasource you want to read from is not in the built-in list, you can implement a custom one for your use case. In this guide, we walk through how to build your own custom datasource, using MongoDB as an example. By the end of the guide, you will have a MongoDatasource that you can use to create a dataset as follows:

# Read from custom MongoDB datasource to create a dataset.
ds = ray.data.read_datasource(
    MongoDatasource(),
    uri=MY_URI,
    database=MY_DATABASE,
    collection=MY_COLLECTION,
    pipelines=MY_PIPELINES
)

# Write the dataset to custom MongoDB datasource.
ds.write_datasource(
    MongoDatasource(), uri=MY_URI, database=MY_DATABASE, collection=MY_COLLECTION
)

Tip

There are a few MongoDB concepts involved here. The URI points to a MongoDB instance, which hosts databases and collections. A collection is analogous to a table in a SQL database. MongoDB also has a pipeline concept, which expresses document processing as a series of stages (e.g. match documents with a predicate, sort the results, and then select a few fields). The results of executing the pipelines are used to create a dataset.
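To make the pipeline concept concrete, here is a minimal sketch of a single pipeline with the three stages mentioned above. The field names (status, created_at, name) are hypothetical and chosen only for illustration; $match, $sort, and $project are standard MongoDB aggregation stages.

```python
# A hypothetical three-stage pipeline. Each stage is a dict with one key
# naming the stage operator; the field names are made up for illustration.
example_pipeline = [
    # Stage 1: keep only documents that satisfy a predicate.
    {"$match": {"status": "active"}},
    # Stage 2: sort the remaining documents, newest first.
    {"$sort": {"created_at": -1}},
    # Stage 3: select a few fields and drop the rest.
    {"$project": {"_id": 0, "name": 1, "created_at": 1}},
]

# The stages run in list order; collect the operator names to show that.
stage_names = [next(iter(stage)) for stage in example_pipeline]
print(stage_names)
```

A pipeline like this is a plain Python list of dicts, which is why the guide types it as List[Dict].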

A custom datasource is an implementation of Datasource. In this example, we call it MongoDatasource. At a high level, it has two core parts to build out:

  • Read support with create_reader()

  • Write support with do_write()

Here are the key design choices we will make in this guide:

  • MongoDB connector: We use PyMongo to connect to MongoDB.

  • MongoDB to Arrow conversion: We use PyMongoArrow to convert MongoDB execution results into Arrow tables, which Datasets supports as a data format.

  • Parallel execution: We ask the user to provide a list of MongoDB pipelines, with each corresponding to a partition of the MongoDB collection, which will be executed in parallel with ReadTask.

For example, suppose you have a MongoDB collection with 4 documents, which have a partition_field with values 0, 1, 2, 3. You can compose two MongoDB pipelines (each handled by a ReadTask) as follows to read the collection in parallel:

# A list of pipelines. Each pipeline is a series of stages, typed as List[Dict].
my_pipelines = [
    # The first pipeline: match documents in partition range [0, 2)
    [
        {
            "$match": {
                "partition_field": {
                    "$gte": 0,
                    "$lt": 2,
                }
            }
        }
    ],
    # The second pipeline: match documents in partition range [2, 4)
    [
        {
            "$match": {
                "partition_field": {
                    "$gte": 2,
                    "$lt": 4,
                }
            }
        }
    ],
]
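Writing such pipelines by hand becomes tedious with more partitions. As a sketch, the helper below generates one $match pipeline per half-open range. Note that make_range_pipelines is a hypothetical helper, not part of Ray or PyMongo, and it assumes the partition field holds integers in [lower, upper).

```python
from typing import Any, Dict, List

Pipeline = List[Dict[str, Any]]


def make_range_pipelines(
    field: str, lower: int, upper: int, num_partitions: int
) -> List[Pipeline]:
    """Split [lower, upper) into contiguous ranges, one pipeline per range."""
    if num_partitions < 1:
        raise ValueError("num_partitions must be at least 1")
    if upper <= lower:
        raise ValueError("upper must be greater than lower")

    total = upper - lower
    # Integer boundaries that cover [lower, upper) with no gaps or overlaps.
    bounds = [lower + (total * i) // num_partitions for i in range(num_partitions + 1)]

    pipelines: List[Pipeline] = []
    for start, end in zip(bounds, bounds[1:]):
        if start == end:
            # More partitions than integer values: skip the empty range.
            continue
        pipelines.append([{"$match": {field: {"$gte": start, "$lt": end}}}])
    return pipelines


# Reproduces the two pipelines written out by hand above.
generated = make_range_pipelines("partition_field", 0, 4, 2)
print(generated)
```

Because every range is half-open, each document whose field value lies in [lower, upper) matches exactly one pipeline, so no document is read twice.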

Read support

To support reading, we implement create_reader(), which returns a Reader implementation for MongoDB. This Reader creates a list of ReadTask objects for the given list of MongoDB pipelines. Each ReadTask returns a list of blocks when called, and each ReadTask is executed on a remote worker to parallelize the execution.

The Ray Datasets documentation describes the block concept and the block APIs in more detail.

First, let’s handle a single MongoDB pipeline, which is the unit of execution in ReadTask. We need to connect to MongoDB, execute the pipeline against it, and then convert results into Arrow format. We use PyMongo and PyMongoArrow to achieve this.

from ray.data.block import Block

# This connects to MongoDB, executes the pipeline against it, converts the result
# into Arrow format and returns the result as a Block.
def _read_single_partition(
    uri, database, collection, pipeline, schema, kwargs
) -> Block:
    import pymongo
    from pymongoarrow.api import aggregate_arrow_all

    client = pymongo.MongoClient(uri)
    # Read more about this API here:
    # https://mongo-arrow.readthedocs.io/en/stable/api/api.html#pymongoarrow.api.aggregate_arrow_all
    return aggregate_arrow_all(
        client[database][collection], pipeline, schema=schema, **kwargs
    )

With this building block, we can apply it to each of the provided MongoDB pipelines. Below, we construct a _MongoDatasourceReader by subclassing Reader and implementing __init__ and get_read_tasks.

In __init__, we pass in the arguments that are eventually used to execute each MongoDB pipeline in _read_single_partition.

In get_read_tasks, we construct a ReadTask object for each pipeline. Each ReadTask takes a no-arg read function and a BlockMetadata as arguments. The BlockMetadata contains metadata such as the number of rows, size in bytes, and schema that we know about the block before executing the read task; the no-arg read function is just a wrapper around _read_single_partition. get_read_tasks returns a list of ReadTask objects, and these tasks are executed on remote workers. The Ray Datasets documentation has more details about read execution.

from typing import Any, Dict, List, Optional
from ray.data.datasource.datasource import Datasource, Reader, ReadTask
from ray.data.block import BlockMetadata

class _MongoDatasourceReader(Reader):
    # This is constructed by the MongoDatasource, which will supply these args
    # about MongoDB.
    def __init__(self, uri, database, collection, pipelines, schema, kwargs):
        self._uri = uri
        self._database = database
        self._collection = collection
        self._pipelines = pipelines
        self._schema = schema
        self._kwargs = kwargs

    # Create a list of ``ReadTask``, one for each pipeline (i.e. a partition of
    # the MongoDB collection). Those tasks will be executed in parallel.
    # Note: ``parallelism``, which normally indicates how many ``ReadTask`` to
    # return, has no effect here, since we map each pipeline to one ``ReadTask``.
    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        read_tasks: List[ReadTask] = []
        for pipeline in self._pipelines:
            # The metadata about the block that we know prior to actually executing
            # the read task.
            metadata = BlockMetadata(
                num_rows=None,
                size_bytes=None,
                schema=self._schema,
                input_files=None,
                exec_stats=None,
            )
            # Supply a no-arg read function (which returns a block) and pre-read
            # block metadata.
            read_task = ReadTask(
                lambda uri=self._uri, database=self._database,
                       collection=self._collection, pipeline=pipeline,
                       schema=self._schema, kwargs=self._kwargs: [
                    _read_single_partition(
                        uri, database, collection, pipeline, schema, kwargs
                    )
                ],
                metadata,
            )
            read_tasks.append(read_task)
        return read_tasks
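The read function above binds its inputs as lambda default arguments (for example pipeline=pipeline) rather than referring to the loop variable directly. A likely reason is Python's late-binding closures: a lambda that captures a loop variable sees its final value once the loop has finished. The pure-Python sketch below illustrates the difference; build_readers is a hypothetical function and strings stand in for real pipelines.

```python
from typing import Callable, List, Tuple


def build_readers(
    pipelines: List[str],
) -> Tuple[List[Callable[[], str]], List[Callable[[], str]]]:
    """Build no-arg readers two ways to compare how each captures ``pipeline``."""
    late_bound = []
    early_bound = []
    for pipeline in pipelines:
        # Captures the variable itself: looked up only when the lambda runs.
        late_bound.append(lambda: pipeline)
        # Captures the current value: default arguments are evaluated when
        # the lambda is defined, so each reader keeps its own pipeline.
        early_bound.append(lambda pipeline=pipeline: pipeline)
    return late_bound, early_bound


late, early = build_readers(["p0", "p1", "p2"])
late_results = [read() for read in late]
early_results = [read() for read in early]
print(late_results)   # every reader sees the last pipeline
print(early_results)  # each reader sees its own pipeline
```

Without the default-argument binding, every read task would end up executing the last pipeline in the list.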

Now, we have finished implementing support for reading from a custom datasource! Let’s move on to implementing support for writing back to the custom datasource.

Write support

Similar to read support, we start by handling a single block. Again, we use PyMongo and PyMongoArrow to interact with MongoDB.

# This connects to MongoDB and writes a block into it.
# Note that this is an insertion, i.e. each record in the block is treated as
# a new document in MongoDB (so existing documents are not mutated).
def _write_single_block(uri, database, collection, block: Block):
    import pymongo
    from pymongoarrow.api import write

    client = pymongo.MongoClient(uri)
    # Read more about this API here:
    # https://mongo-arrow.readthedocs.io/en/stable/api/api.html#pymongoarrow.api.write
    write(client[database][collection], block)

Unlike read support, we do not need to implement a custom interface.

Below, we implement a helper function to parallelize writing, which is expected to return a list of Ray ObjectRefs. This helper function will later be used in the implementation of do_write().

In short, the below function spawns multiple Ray remote tasks and returns their futures (object refs).

from ray.data._internal.remote_fn import cached_remote_fn
from ray.types import ObjectRef
from ray.data.datasource.datasource import WriteResult

# This writes a list of blocks into MongoDB. Each block is handled by a task and
# tasks are executed in parallel.
def _write_multiple_blocks(
    blocks: List[ObjectRef[Block]],
    metadata: List[BlockMetadata],
    ray_remote_args: Optional[Dict[str, Any]],
    uri,
    database,
    collection,
) -> List[ObjectRef[WriteResult]]:
    # The ``cached_remote_fn`` turns ``_write_single_block`` into a Ray
    # remote function. ``ray_remote_args`` is optional, so fall back to an
    # empty dict when it is None.
    write_block = cached_remote_fn(_write_single_block).options(
        **(ray_remote_args or {})
    )
    write_tasks = []
    for block in blocks:
        # Launch a Ray remote task for each block.
        write_task = write_block.remote(uri, database, collection, block)
        write_tasks.append(write_task)
    return write_tasks
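The fan-out pattern used here (submit one task per block, return the futures, and let the caller wait on them) can be tried without a Ray cluster or a MongoDB instance. The sketch below is only an analogy: it swaps Ray remote tasks for a standard-library ThreadPoolExecutor and uses a hypothetical InMemoryCollection class in place of a MongoDB collection.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict, List

Row = Dict[str, Any]


class InMemoryCollection:
    """Hypothetical stand-in for a MongoDB collection (insert-only)."""

    def __init__(self) -> None:
        self.documents: List[Row] = []
        self._lock = threading.Lock()

    def insert_many(self, rows: List[Row]) -> int:
        # Guard the shared list, since several tasks insert concurrently.
        with self._lock:
            self.documents.extend(rows)
        return len(rows)


def write_multiple_blocks(
    executor: ThreadPoolExecutor,
    collection: InMemoryCollection,
    blocks: List[List[Row]],
) -> List[Future]:
    # Submit one task per block and return the futures without waiting,
    # mirroring how the Ray version returns object refs.
    return [executor.submit(collection.insert_many, block) for block in blocks]


blocks = [[{"x": 0}, {"x": 1}], [{"x": 2}], [{"x": 3}, {"x": 4}]]
collection = InMemoryCollection()
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = write_multiple_blocks(executor, collection, blocks)
    # The caller decides when to block on the results.
    written = sum(future.result() for future in futures)
print(written)
```

As in the Ray version, the order in which blocks land in the collection is not guaranteed, which is acceptable here because each record is inserted as an independent document.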

Putting it all together

With _MongoDatasourceReader and _write_multiple_blocks above, we are ready to implement create_reader() and do_write(), and put together a MongoDatasource.

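# MongoDB datasource, for reading from and writing to MongoDB.
class MongoDatasource(Datasource):
    # ``schema`` is optional and any extra keyword arguments are collected
    # into ``kwargs``, so callers only need to pass the MongoDB location and
    # the pipelines.
    def create_reader(
        self, uri, database, collection, pipelines, schema=None, **kwargs
    ) -> Reader:
        return _MongoDatasourceReader(
            uri, database, collection, pipelines, schema, kwargs
        )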

    def do_write(
        self,
        blocks: List[ObjectRef[Block]],
        metadata: List[BlockMetadata],
        ray_remote_args: Optional[Dict[str, Any]],
        uri,
        database,
        collection,
    ) -> List[ObjectRef[WriteResult]]:
        return _write_multiple_blocks(
            blocks, metadata, ray_remote_args, uri, database, collection
        )

Now you can create a Ray Dataset from MongoDB and write it back, just like with any other datasource!

# Read from MongoDB datasource and create a dataset.
# The args are passed to MongoDatasource.create_reader().
ds = ray.data.read_datasource(
    MongoDatasource(),
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection",
    pipelines=my_pipelines,  # See the example definition of ``my_pipelines`` above
)

# Data preprocessing with Dataset APIs here
# ...

# Write the dataset back to MongoDB datasource.
# The args are passed to MongoDatasource.do_write().
ds.write_datasource(
    MongoDatasource(),
    uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
    database="my_db",
    collection="my_collection"
)