"""umongo.frameworks.pymongo: uMongo implementation for the pymongo driver."""

import collections
from contextlib import contextmanager
from contextvars import ContextVar

import marshmallow as ma

from pymongo.cursor import Cursor
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError

from umongo.builder import BaseBuilder
from umongo.data_objects import Reference
from umongo.document import DocumentImplementation
from umongo.exceptions import (
    DeleteError,
    NoneReferenceError,
    NotCreatedError,
    UpdateError,
)
from umongo.fields import DictField, EmbeddedField, ListField, ReferenceField
from umongo.instance import Instance
from umongo.query_mapper import map_query

from .tools import (
    cook_find_filter,
    cook_find_projection,
    remove_cls_field_from_embedded_docs,
)

# Client session used by PyMongoDocument's database calls: each of them passes
# ``session=SESSION.get()`` (None when no session is active).
# ``PyMongoInstance.session`` sets it for the duration of a ``with`` block.
SESSION = ContextVar("session", default=None)


# pymongo.Cursor defines a __del__ method, hence mongomock's WrappedCursor must
# not inherit from pymongo.Cursor (only from this base) otherwise garbage
# collection will crash...
class BaseWrappedCursor:
    """Proxy around a raw driver cursor that yields documents instead of dicts.

    Attribute reads not resolved on the proxy, and all attribute writes, are
    forwarded to ``raw_cursor``. Iteration, ``next`` and indexing turn each raw
    result into a document through ``document_cls.build_from_mongo``.
    """

    __slots__ = ("document_cls", "raw_cursor")

    def __init__(self, document_cls, cursor, *args, **kwargs):
        # Such a cunning plan my lord !
        # Subclasses may also inherit from a driver Cursor, but we don't call
        # its __init__ because we act as a proxy to the underlying raw_cursor.
        # __setattr__ forwards writes to the raw cursor, so our own slots are
        # filled through their descriptors. Use this class's descriptors rather
        # than reaching for a subclass defined later in the module.
        BaseWrappedCursor.raw_cursor.__set__(self, cursor)
        BaseWrappedCursor.document_cls.__set__(self, document_cls)

    def __getattr__(self, name):
        # Only called when normal lookup fails. For our own slots that means
        # they were never set (e.g. instance created without __init__);
        # forwarding would then call __getattr__ again, forever.
        if name in BaseWrappedCursor.__slots__:
            raise AttributeError(name)
        return getattr(self.raw_cursor, name)

    def __setattr__(self, name, value):
        return setattr(self.raw_cursor, name, value)

    def __getitem__(self, index):
        if isinstance(index, slice):
            elems = self.raw_cursor[index]
            return (
                self.document_cls.build_from_mongo(elem, use_cls=True) for elem in elems
            )
        elem = self.raw_cursor[index]
        return self.document_cls.build_from_mongo(elem, use_cls=True)

    def __next__(self):
        elem = next(self.raw_cursor)
        return self.document_cls.build_from_mongo(elem, use_cls=True)

    # pymongo-style ``cursor.next()`` must also yield documents. Without this
    # alias it would resolve to the driver cursor's own ``next`` (or be
    # forwarded to the raw cursor) and hand back raw dicts.
    next = __next__

    def __iter__(self):
        for elem in self.raw_cursor:
            yield self.document_cls.build_from_mongo(elem, use_cls=True)


class WrappedCursor(BaseWrappedCursor, Cursor):
    """Document-yielding proxy that is also a :class:`pymongo.cursor.Cursor`.

    ``BaseWrappedCursor`` is listed first so its methods take precedence over
    ``Cursor``'s. ``Cursor.__init__`` is never called; the instance proxies the
    raw cursor passed to ``BaseWrappedCursor.__init__``.
    """

    __slots__ = ()


class PyMongoDocument(DocumentImplementation):
    """Document implementation for the synchronous pymongo driver.

    Every database call made here passes ``session=SESSION.get()`` so it runs
    in the session opened by ``PyMongoInstance.session``, if any.
    """

    __slots__ = ()
    cursor_cls = WrappedCursor  # Easier to customize this for mongomock this way

    opts = DocumentImplementation.opts

    def reload(self):
        """Retrieve and replace document's data by the ones in database.

        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        doesn't exist in database.
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        ret = self.collection.find_one(self.pk, session=SESSION.get())
        if ret is None:
            raise NotCreatedError("Document doesn't exists in database")
        # Start from a brand new data proxy, then load the stored data into it.
        self._data = self.DataProxy()
        self._data.from_mongo(ret)

    def commit(self, io_validate_all=False, conditions=None, replace=False):
        """Commit the document in database.
        If the document doesn't already exist it will be inserted, otherwise
        it will be updated.

        :param io_validate_all: Validate all field instead of only changed ones.
        :param conditions: Only perform commit if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.UpdateError` if the
            conditions are not satisfied. The mapping is copied, never
            modified.
        :param replace: Replace the document rather than update.
        :return: A :class:`pymongo.results.UpdateResult` or
            :class:`pymongo.results.InsertOneResult` depending of the operation,
            or None if the document exists and nothing needs to be written.
        :raises umongo.exceptions.NotCreatedError: if ``conditions`` is given
            for a document that is not created yet.
        :raises marshmallow.ValidationError: on a unique index violation
            involving only fields known to the schema.
        """
        try:
            if self.is_created:
                if self.is_modified() or replace:
                    # Copy so the caller's ``conditions`` mapping is not mutated
                    # by the ``_id`` / pre_update additions below.
                    query = dict(conditions) if conditions else {}
                    query["_id"] = self.pk
                    # pre_update can provide additional query filter and/or
                    # modify the fields' values
                    additional_filter = self.pre_update()
                    if additional_filter:
                        query.update(map_query(additional_filter, self.schema.fields))
                    self.required_validate()
                    self.io_validate(validate_all=io_validate_all)
                    if replace:
                        payload = self._data.to_mongo(update=False)
                        ret = self.collection.replace_one(
                            query,
                            payload,
                            session=SESSION.get(),
                        )
                    else:
                        payload = self._data.to_mongo(update=True)
                        ret = self.collection.update_one(
                            query,
                            payload,
                            session=SESSION.get(),
                        )
                    # No match means the document is gone or the conditions
                    # / additional filter did not hold.
                    if ret.matched_count != 1:
                        raise UpdateError(ret)
                    self.post_update(ret)
                else:
                    ret = None
            elif conditions:
                raise NotCreatedError(
                    "Document must already exist in database to use `conditions`.",
                )
            else:
                self.pre_insert()
                self.required_validate()
                self.io_validate(validate_all=io_validate_all)
                payload = self._data.to_mongo(update=False)
                ret = self.collection.insert_one(payload, session=SESSION.get())
                # TODO: check ret ?
                self._data.set(self.pk_field, ret.inserted_id)
                self.is_created = True
                self.post_insert(ret)
        except DuplicateKeyError as exc:
            # The error details may be None or lack the index key pattern; then
            # there is nothing to map back to umongo fields, so propagate the
            # original error rather than a KeyError/TypeError.
            key_pattern = (exc.details or {}).get("keyPattern")
            if not key_pattern:
                raise
            # Sort value to make testing easier for compound indexes
            keys = sorted(key_pattern)
            try:
                fields = [self.schema.fields[k] for k in keys]
            except KeyError:
                # A key in the index is unknown from umongo
                raise exc from None
            if len(keys) == 1:
                msg = fields[0].error_messages["unique"]
                raise ma.ValidationError({keys[0]: msg}) from exc
            raise ma.ValidationError(
                {
                    k: f.error_messages["unique_compound"].format(fields=keys)
                    for k, f in zip(keys, fields, strict=True)
                },
            ) from exc
        self._data.clear_modified()
        return ret

    def delete(self, conditions=None):
        """Remove the document from database.

        :param conditions: Only perform delete if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.DeleteError` if the
            conditions are not satisfied. The mapping is copied, never
            modified.
        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        is not created (i.e. ``doc.is_created`` is False)
        Raises :class:`umongo.exceptions.DeleteError` if the document
        doesn't exist in database.

        :return: A :class:`pymongo.results.DeleteResult`
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        # Copy so the caller's ``conditions`` mapping is not mutated.
        query = dict(conditions) if conditions else {}
        query["_id"] = self.pk
        # pre_delete can provide additional query filter
        additional_filter = self.pre_delete()
        if additional_filter:
            query.update(map_query(additional_filter, self.schema.fields))
        ret = self.collection.delete_one(query, session=SESSION.get())
        if ret.deleted_count != 1:
            raise DeleteError(ret)
        self.is_created = False
        self.post_delete(ret)
        return ret

    def io_validate(self, validate_all=False):
        """Run the io_validators of the document's fields.

        :param validate_all: If False only run the io_validators of the
            fields that have been modified.
        """
        if validate_all:
            _io_validate_data_proxy(self.schema, self._data)
        else:
            _io_validate_data_proxy(
                self.schema,
                self._data,
                partial=self._data.get_modified_fields(),
            )

    @classmethod
    def find_one(cls, filter=None, projection=None, **kwargs):
        """Find a single document in database.

        :return: A document instance, or None if nothing matches.
        """
        filter = cook_find_filter(cls, filter)
        if projection:
            projection = cook_find_projection(cls, projection)
        ret = cls.collection.find_one(
            filter,
            projection=projection,
            session=SESSION.get(),
            **kwargs,
        )
        if ret is not None:
            ret = cls.build_from_mongo(ret, use_cls=True)
        return ret

    @classmethod
    def find(cls, filter=None, *args, **kwargs):
        """Find a list document in database.

        Returns a cursor that provide Documents.
        """
        filter = cook_find_filter(cls, filter)
        raw_cursor = cls.collection.find(filter, *args, session=SESSION.get(), **kwargs)
        return cls.cursor_cls(cls, raw_cursor)

    @classmethod
    def count_documents(cls, filter=None, **kwargs):
        """Get the number of documents in this collection.

        Unlike pymongo's collection.count_documents, filter is optional and
        defaults to an empty filter.
        """
        filter = cook_find_filter(cls, filter or {})
        return cls.collection.count_documents(filter, session=SESSION.get(), **kwargs)

    @classmethod
    def ensure_indexes(cls):
        """Check&create if needed the Document's indexes in database"""
        if cls.indexes:
            cls.collection.create_indexes(cls.indexes, session=SESSION.get())


# Run multiple validators and collect all errors in one
def _run_validators(validators, field, value):
    if not hasattr(validators, "__iter__"):
        validators(field, value)
    else:
        errors = []
        for validator in validators:
            try:
                validator(field, value)
            except ma.ValidationError as exc:
                errors.extend(exc.messages)
        if errors:
            raise ma.ValidationError(errors)


def _io_validate_data_proxy(schema, data_proxy, partial=None):
    errors = {}
    for name, field in schema.fields.items():
        if partial and name not in partial:
            continue
        value = data_proxy.get(name)
        if value is ma.missing:
            continue
        try:
            if field.io_validate_recursive:
                field.io_validate_recursive(field, value)
            if field.io_validate:
                _run_validators(field.io_validate, field, value)
        except ma.ValidationError as exc:
            errors[name] = exc.messages
    if errors:
        raise ma.ValidationError(errors)


def _reference_io_validate(field, value):
    if value is None:
        return
    if not value.exists:
        raise ma.ValidationError(
            value.error_messages["not_found"].format(
                document=value.document_cls.__name__,
            ),
        )


def _list_io_validate(field, value):
    if not value:
        return
    errors = {}
    validators = field.inner.io_validate
    if not validators:
        return
    for idx, val in enumerate(value):
        try:
            _run_validators(validators, field.inner, val)
        except ma.ValidationError as exc:
            errors[idx] = exc.messages
    if errors:
        raise ma.ValidationError(errors)


def _dict_io_validate(field, value):
    if not value or not field.value_field:
        return
    errors = collections.defaultdict(dict)
    validators = field.value_field.io_validate
    if not validators:
        return
    for key, val in value.items():
        try:
            _run_validators(validators, field.value_field, val)
        except ma.ValidationError as exc:
            errors[key]["value"] = exc.messages
    if errors:
        raise ma.ValidationError(errors)


def _embedded_document_io_validate(field, value):
    if not value:
        return
    _io_validate_data_proxy(value.schema, value._data)


class PyMongoReference(Reference):
    """Reference resolved through the referenced document class's collection."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache for the fetched document, filled lazily by ``fetch``.
        self._document = None

    def fetch(self, no_data=False, force_reload=False, projection=None):
        """Return the referenced document, querying the database if needed.

        The fetched document is cached on the reference and returned by later
        calls unless ``force_reload`` is set.

        :param no_data: Not used by this implementation.
        :param force_reload: Ignore the cached document and query again.
        :param projection: Projection passed to ``document_cls.find_one``.
        :raises umongo.exceptions.NoneReferenceError: if ``pk`` is None.
        :raises marshmallow.ValidationError: if no document matches ``pk``.
        """
        if not self._document or force_reload:
            if self.pk is None:
                raise NoneReferenceError("Cannot retrieve a None Reference")
            self._document = self.document_cls.find_one(self.pk, projection=projection)
            if not self._document:
                raise ma.ValidationError(
                    self.error_messages["not_found"].format(
                        document=self.document_cls.__name__,
                    ),
                )
        return self._document

    @property
    def exists(self):
        """Whether a document with this reference's ``pk`` is in the database.

        Queries in the current session (``SESSION``), like every other document
        query in this module. The check is then consistent with writes made
        through that session, e.g. inside a transaction.
        """
        return (
            self.document_cls.collection.find_one(
                self.pk,
                projection={"_id": True},
                session=SESSION.get(),
            )
            is not None
        )


class PyMongoBuilder(BaseBuilder):
    """Builder wiring pymongo-specific documents, references and io-validators."""

    BASE_DOCUMENT_CLS = PyMongoDocument

    def _patch_field(self, field):
        """Normalize ``field.io_validate`` to a list and attach pymongo hooks.

        Container and embedded fields get an ``io_validate_recursive`` hook.
        Reference fields get the existence check appended to their
        io-validators and use :class:`PyMongoReference` as reference class.
        """
        super()._patch_field(field)

        declared = field.io_validate
        if not declared:
            normalized = []
        elif hasattr(declared, "__iter__"):
            normalized = list(declared)
        else:
            normalized = [declared]
        field.io_validate = normalized

        # Independent checks, kept as separate ``if`` statements on purpose.
        if isinstance(field, ListField):
            field.io_validate_recursive = _list_io_validate
        if isinstance(field, DictField):
            field.io_validate_recursive = _dict_io_validate
        if isinstance(field, ReferenceField):
            field.io_validate.append(_reference_io_validate)
            field.reference_cls = PyMongoReference
        if isinstance(field, EmbeddedField):
            field.io_validate_recursive = _embedded_document_io_validate


class PyMongoInstance(Instance):
    """:class:`umongo.instance.Instance` implementation for pymongo"""

    BUILDER_CLS = PyMongoBuilder

    @staticmethod
    def is_compatible_with(db):
        """Return True if ``db`` is a :class:`pymongo.database.Database`."""
        return isinstance(db, Database)

    @contextmanager
    def session(self):
        """Start a client session and make it the current one.

        While the ``with`` block runs, ``SESSION.get()`` returns this session,
        so document operations reading it use the session. The previous value
        of ``SESSION`` is restored on exit. The session itself is ended by
        ``start_session``'s context manager.

        :yield: The :class:`pymongo.client_session.ClientSession`.
        """
        with self.db.client.start_session() as session:
            # Set before ``try`` so ``token`` is always bound in ``finally``.
            token = SESSION.set(session)
            try:
                yield session
            finally:
                SESSION.reset(token)


class PyMongoMigrationInstance(PyMongoInstance):
    """PyMongo instance with migration features"""

    def migrate_2_to_3(self):
        """Migrate database from umongo 2 to umongo 3

        - EmbeddedDocument _cls field is only set if child of concrete embedded document

        For every stored document of each concrete, non-child document class,
        the raw data is passed to ``remove_cls_field_from_embedded_docs``
        together with the names of the concrete, non-child embedded documents.
        The result is written back with ``replace_one``.

        :raises umongo.exceptions.UpdateError: if a replaced document was not
            matched.
        """
        concrete_not_children = [
            name
            for name, ed in self._embedded_lookup.items()
            if not ed.opts.is_child and not ed.opts.abstract
        ]
        for doc_cls in self._doc_lookup.values():
            # Skip abstract and child document classes. Children presumably
            # share their parent's collection (assumption - confirm).
            if doc_cls.opts.abstract or doc_cls.opts.is_child:
                continue
            for raw in doc_cls.collection.find():
                migrated = remove_cls_field_from_embedded_docs(raw, concrete_not_children)
                ret = doc_cls.collection.replace_one({"_id": migrated["_id"]}, migrated)
                if ret.matched_count != 1:
                    raise UpdateError(ret)