import logging
from typing import Dict, List, Optional, TYPE_CHECKING

from import Datasource, Reader, ReadTask, WriteResult
from import (
from import DelegatingBlockBuilder
from import TaskContext
from ray.util.annotations import PublicAPI
from typing import Iterable

    import pymongoarrow.api

logger = logging.getLogger(__name__)

@PublicAPI(stability="alpha")
class MongoDatasource(Datasource):
    """Datasource for reading from and writing to MongoDB.

    Examples:
        >>> import ray
        >>> from ray.data.datasource import MongoDatasource # doctest: +SKIP
        >>> import pyarrow as pa # doctest: +SKIP
        >>> from pymongoarrow.api import Schema # doctest: +SKIP
        >>> ds = ray.data.read_datasource( # doctest: +SKIP
        ...     MongoDatasource(), # doctest: +SKIP
        ...     uri="mongodb://localhost:27017", # doctest: +SKIP
        ...     database="my_db", # doctest: +SKIP
        ...     collection="my_collection", # doctest: +SKIP
        ...     schema=Schema({"col1": pa.string(), "col2": pa.int64()}), # doctest: +SKIP
        ... ) # doctest: +SKIP
    """

    def create_reader(self, **kwargs) -> Reader:
        """Create the reader for this datasource.

        Args:
            **kwargs: Forwarded unchanged to ``_MongoDatasourceReader``
                (``uri``, ``database``, ``collection``, and optionally
                ``pipeline``, ``schema`` and extra aggregate arguments).

        Returns:
            A ``_MongoDatasourceReader`` built from ``kwargs``.
        """
        return _MongoDatasourceReader(**kwargs)

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
        uri: str,
        database: str,
        collection: str,
    ) -> WriteResult:
        """Write all ``blocks`` into an existing MongoDB collection.

        The input blocks are combined into a single block, converted to Arrow
        and written with ``pymongoarrow.api.write``.

        Args:
            blocks: The blocks to write.
            ctx: The task context (not used by this implementation).
            uri: The MongoDB connection URI.
            database: Name of the destination database; it must already exist.
            collection: Name of the destination collection; it must already
                exist.

        Returns:
            The string ``"ok"`` once the write has completed.

        Raises:
            ValueError: If ``database`` or ``collection`` does not exist.
        """
        import pymongo

        client = pymongo.MongoClient(uri)
        try:
            # Fail fast, before consuming any input, if the target is missing.
            _validate_database_collection_exist(client, database, collection)

            from pymongoarrow.api import write as write_arrow

            # Combine every input block so that all of them are written, not
            # just the last one yielded by the iterator (and so that an empty
            # iterator does not leave the block variable unbound).
            builder = DelegatingBlockBuilder()
            for block in blocks:
                builder.add_block(block)
            combined = builder.build()

            arrow_table = BlockAccessor.for_block(combined).to_arrow()
            write_arrow(client[database][collection], arrow_table)
        finally:
            # Release the connection pool even when validation or the write
            # raises.
            client.close()

        # TODO: decide if we want to return richer object when the task
        # succeeds.
        return "ok"
class _MongoDatasourceReader(Reader):
    """Reader that splits a MongoDB collection into ``_id``-range read tasks."""

    def __init__(
        self,
        uri: str,
        database: str,
        collection: str,
        pipeline: Optional[List[Dict]] = None,
        schema: Optional["pymongoarrow.api.Schema"] = None,
        **mongo_args,
    ):
        """Connect to MongoDB and record the read configuration.

        Args:
            uri: The MongoDB connection URI.
            database: Name of the database to read from; it must exist.
            collection: Name of the collection to read from; it must exist.
            pipeline: Aggregation pipeline to run. When missing or empty, a
                pipeline matching every document that has an ``_id`` is used.
            schema: Schema passed to ``aggregate_arrow_all`` for each
                partition.
            **mongo_args: Extra keyword arguments forwarded to
                ``aggregate_arrow_all``.

        Raises:
            ValueError: If ``database`` or ``collection`` does not exist.
        """
        import pymongo

        self._uri = uri
        self._database = database
        self._collection = collection
        self._pipeline = pipeline
        self._schema = schema
        self._mongo_args = mongo_args
        # If pipeline is unspecified, read the entire collection.
        if not pipeline:
            # ``$exists`` takes a boolean; use a real ``True`` rather than the
            # string "true", which only works through truthiness coercion.
            self._pipeline = [{"$match": {"_id": {"$exists": True}}}]

        self._client = pymongo.MongoClient(uri)
        _validate_database_collection_exist(self._client, database, collection)

        # Average document size reported by ``collstats``; used to estimate
        # the byte size of each partition in ``get_read_tasks``.
        coll_stats = self._client[database].command("collstats", collection)
        self._avg_obj_size = coll_stats["avgObjSize"]

    def estimate_inmemory_data_size(self) -> Optional[int]:
        """Return the estimated in-memory size; currently always ``None``."""
        # TODO(jian): Add memory size estimation to improve auto-tune of parallelism.
        return None

    def _get_match_query(self, pipeline: List[Dict]) -> Dict:
        """Return the ``$match`` filter of the first pipeline stage.

        Returns an empty dict when the pipeline is empty or its first stage
        is not a ``$match``.
        """
        if not pipeline or "$match" not in pipeline[0]:
            return {}
        return pipeline[0]["$match"]

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        """Partition the collection by ``_id`` and build one task per partition.

        The documents selected by the pipeline's leading ``$match`` are
        bucketed on ``_id`` with ``$bucketAuto`` into ``parallelism`` buckets.
        Each bucket becomes a ``ReadTask`` that runs the pipeline restricted
        to that bucket's ``_id`` range.

        Args:
            parallelism: Number of buckets requested from ``$bucketAuto``.

        Returns:
            One ``ReadTask`` per bucket returned by MongoDB.
        """
        from bson.objectid import ObjectId

        coll = self._client[self._database][self._collection]
        match_query = self._get_match_query(self._pipeline)
        partitions_ids = list(
            coll.aggregate(
                [
                    {"$match": match_query},
                    {"$bucketAuto": {"groupBy": "$_id", "buckets": parallelism}},
                ],
                allowDiskUse=True,
            )
        )

        def make_block(
            uri: str,
            database: str,
            collection: str,
            pipeline: List[Dict],
            min_id: ObjectId,
            max_id: ObjectId,
            right_closed: bool,
            schema: "pymongoarrow.api.Schema",
            kwargs: dict,
        ) -> Block:
            """Read the documents of one ``_id`` range as a single block."""
            import pymongo
            from pymongoarrow.api import aggregate_arrow_all

            # A range query over the partition.
            upper_op = "$lte" if right_closed else "$lt"
            match = [{"$match": {"_id": {"$gte": min_id, upper_op: max_id}}}]

            client = pymongo.MongoClient(uri)
            try:
                return aggregate_arrow_all(
                    client[database][collection],
                    match + pipeline,
                    schema=schema,
                    **kwargs,
                )
            finally:
                # Each read task opens its own client; close it so tasks do
                # not leak connections.
                client.close()

        last_index = len(partitions_ids) - 1
        read_tasks: List[ReadTask] = []
        for i, partition in enumerate(partitions_ids):
            metadata = BlockMetadata(
                num_rows=partition["count"],
                # Estimate: document count times the average document size.
                size_bytes=partition["count"] * self._avg_obj_size,
                schema=None,
                input_files=None,
                exec_stats=None,
            )
            make_block_args = (
                self._uri,
                self._database,
                self._collection,
                self._pipeline,
                partition["_id"]["min"],
                partition["_id"]["max"],
                # Only the last partition includes its upper bound; the
                # others are half-open so adjacent ranges do not overlap.
                i == last_index,
                self._schema,
                self._mongo_args,
            )
            # Bind the arguments as a default value so every task keeps its
            # own partition rather than the loop's final one.
            read_task = ReadTask(
                lambda args=make_block_args: [make_block(*args)],
                metadata,
            )
            read_tasks.append(read_task)

        return read_tasks


def _validate_database_collection_exist(client, database: str, collection: str):
    """Check that ``database`` and ``collection`` exist on the server.

    Args:
        client: A connected MongoDB client exposing ``list_database_names``
            and item access to databases.
        database: Name of the database to look for.
        collection: Name of the collection to look for inside ``database``.

    Raises:
        ValueError: If the database, or the collection within it, is missing.
    """
    db_names = client.list_database_names()
    if database not in db_names:
        raise ValueError(f"The destination database {database} doesn't exist.")
    collection_names = client[database].list_collection_names()
    if collection not in collection_names:
        raise ValueError(f"The destination collection {collection} doesn't exist.")