import collections
from contextlib import contextmanager
from contextvars import ContextVar
import marshmallow as ma
from pymongo.cursor import Cursor
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError
from umongo.builder import BaseBuilder
from umongo.data_objects import Reference
from umongo.document import DocumentImplementation
from umongo.exceptions import (
DeleteError,
NoneReferenceError,
NotCreatedError,
UpdateError,
)
from umongo.fields import DictField, EmbeddedField, ListField, ReferenceField
from umongo.instance import Instance
from umongo.query_mapper import map_query
from .tools import (
cook_find_filter,
cook_find_projection,
remove_cls_field_from_embedded_docs,
)
# Context variable holding the session to forward to database calls.
# It defaults to None; PyMongoInstance.session() sets it for the duration of
# its ``with`` block, and the document methods read it via SESSION.get().
SESSION = ContextVar("session", default=None)
# pymongo.Cursor defines a __del__ method, so the mongomock flavour of
# WrappedCursor must not inherit from it, otherwise garbage collection will
# crash. The shared proxy logic therefore lives in this Cursor-free base.
class BaseWrappedCursor:
    """Proxy around a raw cursor that yields documents instead of raw data.

    Items read from the wrapped cursor are converted with
    ``document_cls.build_from_mongo(..., use_cls=True)``. Any attribute that
    is not one of the two slots is read from / written to the wrapped cursor.
    """

    __slots__ = ("document_cls", "raw_cursor")

    def __init__(self, document_cls, cursor, *args, **kwargs):
        # Deliberately no super().__init__(): this object only proxies
        # ``cursor``. The slot descriptors are invoked directly because
        # __setattr__ below forwards every assignment to the raw cursor.
        WrappedCursor.raw_cursor.__set__(self, cursor)
        WrappedCursor.document_cls.__set__(self, document_cls)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: delegate to the raw cursor.
        return getattr(self.raw_cursor, name)

    def __setattr__(self, name, value):
        # Every attribute assignment lands on the raw cursor.
        return setattr(self.raw_cursor, name, value)

    def __getitem__(self, index):
        """Return one built document, or a generator of them for a slice."""
        fetched = self.raw_cursor[index]
        if not isinstance(index, slice):
            return self.document_cls.build_from_mongo(fetched, use_cls=True)
        return (
            self.document_cls.build_from_mongo(raw, use_cls=True) for raw in fetched
        )

    def __next__(self):
        """Advance the raw cursor and return the built document."""
        raw = next(self.raw_cursor)
        return self.document_cls.build_from_mongo(raw, use_cls=True)

    def __iter__(self):
        """Iterate over the raw cursor, building a document for each item."""
        for raw in self.raw_cursor:
            yield self.document_cls.build_from_mongo(raw, use_cls=True)
class WrappedCursor(BaseWrappedCursor, Cursor):
    """Document-building cursor proxy that is also a pymongo ``Cursor`` subclass.

    All behaviour comes from ``BaseWrappedCursor``; no extra slots are declared.
    """

    __slots__ = ()
class PyMongoDocument(DocumentImplementation):
    """Document implementation that talks to MongoDB through pymongo.

    Every database call made by these methods forwards
    ``session=SESSION.get()``, i.e. whatever the module-level ``SESSION``
    context variable currently holds (``None`` by default).
    """

    __slots__ = ()

    cursor_cls = WrappedCursor  # Easier to customize this for mongomock this way
    opts = DocumentImplementation.opts

    def reload(self):
        """Retrieve and replace document's data by the ones in database.

        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        doesn't exist in database.
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        ret = self.collection.find_one(self.pk, session=SESSION.get())
        if ret is None:
            raise NotCreatedError("Document doesn't exists in database")
        # Replace the data proxy with a new one filled from the fetched data.
        self._data = self.DataProxy()
        self._data.from_mongo(ret)

    def commit(self, io_validate_all=False, conditions=None, replace=False):
        """Commit the document in database.

        If the document doesn't already exist it will be inserted, otherwise
        it will be updated.

        :param io_validate_all: Validate all field instead of only changed ones.
        :param conditions: Only perform commit if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.UpdateError` if the
            conditions are not satisfied. The mapping passed in is not
            modified.
        :param replace: Replace the document rather than update.
        :return: A :class:`pymongo.results.UpdateResult` or
            :class:`pymongo.results.InsertOneResult` depending of the
            operation, or ``None`` when the document is already created and
            there is nothing to write.

        Raises :class:`umongo.exceptions.NotCreatedError` if ``conditions``
        is given for a document that is not created yet.
        Raises :class:`marshmallow.ValidationError` when the write hits a
        duplicate key on fields known to the schema.
        """
        try:
            if self.is_created:
                if self.is_modified() or replace:
                    # Work on a copy so the caller's ``conditions`` mapping is
                    # not mutated by the ``_id`` and pre_update additions.
                    query = dict(conditions) if conditions else {}
                    query["_id"] = self.pk
                    # pre_update can provide additional query filter and/or
                    # modify the fields' values
                    additional_filter = self.pre_update()
                    if additional_filter:
                        query.update(map_query(additional_filter, self.schema.fields))
                    self.required_validate()
                    self.io_validate(validate_all=io_validate_all)
                    if replace:
                        payload = self._data.to_mongo(update=False)
                        ret = self.collection.replace_one(
                            query,
                            payload,
                            session=SESSION.get(),
                        )
                    else:
                        payload = self._data.to_mongo(update=True)
                        ret = self.collection.update_one(
                            query,
                            payload,
                            session=SESSION.get(),
                        )
                    # No match means the document is gone or the conditions /
                    # additional filter were not satisfied.
                    if ret.matched_count != 1:
                        raise UpdateError(ret)
                    self.post_update(ret)
                else:
                    ret = None
            elif conditions:
                raise NotCreatedError(
                    "Document must already exist in database to use `conditions`.",
                )
            else:
                self.pre_insert()
                self.required_validate()
                self.io_validate(validate_all=io_validate_all)
                payload = self._data.to_mongo(update=False)
                ret = self.collection.insert_one(payload, session=SESSION.get())
                # TODO: check ret ?
                self._data.set(self.pk_field, ret.inserted_id)
                self.is_created = True
                self.post_insert(ret)
        except DuplicateKeyError as exc:
            # The error document may be missing or may not carry a
            # "keyPattern" entry; there is then nothing to translate into a
            # validation error, so surface the original exception instead of
            # failing with an unrelated KeyError/TypeError.
            key_pattern = (exc.details or {}).get("keyPattern")
            if not key_pattern:
                raise
            # Sort value to make testing easier for compound indexes
            keys = sorted(key_pattern.keys())
            try:
                fields = [self.schema.fields[k] for k in keys]
            except KeyError:
                # A key in the index is unknown from umongo
                raise exc from None
            if len(keys) == 1:
                msg = fields[0].error_messages["unique"]
                raise ma.ValidationError({keys[0]: msg}) from exc
            raise ma.ValidationError(
                {
                    k: f.error_messages["unique_compound"].format(fields=keys)
                    for k, f in zip(keys, fields, strict=True)
                },
            ) from exc
        self._data.clear_modified()
        return ret

    def delete(self, conditions=None):
        """Remove the document from database.

        :param conditions: Only perform delete if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.DeleteError` if the
            conditions are not satisfied. The mapping passed in is not
            modified.

        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        is not created (i.e. ``doc.is_created`` is False)
        Raises :class:`umongo.exceptions.DeleteError` if the document
        doesn't exist in database.

        :return: A :class:`pymongo.results.DeleteResult`
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        # Copy so the caller's ``conditions`` mapping is left untouched.
        query = dict(conditions) if conditions else {}
        query["_id"] = self.pk
        # pre_delete can provide additional query filter
        additional_filter = self.pre_delete()
        if additional_filter:
            query.update(map_query(additional_filter, self.schema.fields))
        ret = self.collection.delete_one(query, session=SESSION.get())
        if ret.deleted_count != 1:
            raise DeleteError(ret)
        self.is_created = False
        self.post_delete(ret)
        return ret

    def io_validate(self, validate_all=False):
        """Run the io_validators of the document's fields.

        :param validate_all: If False only run the io_validators of the
            fields that have been modified.
        """
        if validate_all:
            _io_validate_data_proxy(self.schema, self._data)
        else:
            _io_validate_data_proxy(
                self.schema,
                self._data,
                partial=self._data.get_modified_fields(),
            )

    @classmethod
    def find_one(cls, filter=None, projection=None, **kwargs):
        """Find a single document in database.

        :param filter: Query filter, passed through ``cook_find_filter``.
        :param projection: Optional projection, passed through
            ``cook_find_projection`` when truthy.
        :param kwargs: Forwarded to the collection's ``find_one``.
        :return: The document built from the result, or ``None`` if nothing
            matched.
        """
        filter = cook_find_filter(cls, filter)
        if projection:
            projection = cook_find_projection(cls, projection)
        ret = cls.collection.find_one(
            filter,
            projection=projection,
            session=SESSION.get(),
            **kwargs,
        )
        if ret is not None:
            ret = cls.build_from_mongo(ret, use_cls=True)
        return ret

    @classmethod
    def find(cls, filter=None, *args, **kwargs):
        """Find a list of documents in database.

        Returns a cursor (``cls.cursor_cls``) that provides Documents.
        Extra positional and keyword arguments are forwarded to the
        collection's ``find``.
        """
        filter = cook_find_filter(cls, filter)
        raw_cursor = cls.collection.find(filter, *args, session=SESSION.get(), **kwargs)
        return cls.cursor_cls(cls, raw_cursor)

    @classmethod
    def count_documents(cls, filter=None, **kwargs):
        """Get the number of documents in this collection.

        Unlike pymongo's collection.count_documents, filter is optional and
        defaults to an empty filter.
        """
        filter = cook_find_filter(cls, filter or {})
        return cls.collection.count_documents(filter, session=SESSION.get(), **kwargs)

    @classmethod
    def ensure_indexes(cls):
        """Check&create if needed the Document's indexes in database"""
        if cls.indexes:
            cls.collection.create_indexes(cls.indexes, session=SESSION.get())
# Run multiple validators and collect all errors in one
def _run_validators(validators, field, value):
if not hasattr(validators, "__iter__"):
validators(field, value)
else:
errors = []
for validator in validators:
try:
validator(field, value)
except ma.ValidationError as exc:
errors.extend(exc.messages)
if errors:
raise ma.ValidationError(errors)
def _io_validate_data_proxy(schema, data_proxy, partial=None):
errors = {}
for name, field in schema.fields.items():
if partial and name not in partial:
continue
value = data_proxy.get(name)
if value is ma.missing:
continue
try:
if field.io_validate_recursive:
field.io_validate_recursive(field, value)
if field.io_validate:
_run_validators(field.io_validate, field, value)
except ma.ValidationError as exc:
errors[name] = exc.messages
if errors:
raise ma.ValidationError(errors)
def _reference_io_validate(field, value):
if value is None:
return
if not value.exists:
raise ma.ValidationError(
value.error_messages["not_found"].format(
document=value.document_cls.__name__,
),
)
def _list_io_validate(field, value):
if not value:
return
errors = {}
validators = field.inner.io_validate
if not validators:
return
for idx, val in enumerate(value):
try:
_run_validators(validators, field.inner, val)
except ma.ValidationError as exc:
errors[idx] = exc.messages
if errors:
raise ma.ValidationError(errors)
def _dict_io_validate(field, value):
if not value or not field.value_field:
return
errors = collections.defaultdict(dict)
validators = field.value_field.io_validate
if not validators:
return
for key, val in value.items():
try:
_run_validators(validators, field.value_field, val)
except ma.ValidationError as exc:
errors[key]["value"] = exc.messages
if errors:
raise ma.ValidationError(errors)
def _embedded_document_io_validate(field, value):
if not value:
return
_io_validate_data_proxy(value.schema, value._data)
class PyMongoReference(Reference):
    """Reference that resolves its target document through pymongo.

    The fetched document is cached on the instance in ``_document``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache for the fetched document; filled by ``fetch``.
        self._document = None

    def fetch(self, no_data=False, force_reload=False, projection=None):
        """Return the referenced document, loading it when needed.

        :param no_data: Accepted but not used by this implementation.
        :param force_reload: Query the database even if a document is
            already cached.
        :param projection: Projection forwarded to ``find_one`` when the
            document is loaded; it has no effect when the cached document
            is returned.
        :return: The referenced document.

        Raises :class:`umongo.exceptions.NoneReferenceError` if the
        reference's ``pk`` is ``None``.
        Raises :class:`marshmallow.ValidationError` if no document is found.
        """
        if not self._document or force_reload:
            if self.pk is None:
                raise NoneReferenceError("Cannot retrieve a None Reference")
            self._document = self.document_cls.find_one(self.pk, projection=projection)
            if not self._document:
                raise ma.ValidationError(
                    self.error_messages["not_found"].format(
                        document=self.document_cls.__name__,
                    ),
                )
        return self._document

    @property
    def exists(self):
        """Whether a document with this reference's ``pk`` is in the collection.

        Only ``_id`` is projected since the content is not needed. The
        current ``SESSION`` is forwarded, like the other database calls in
        this module, so the check runs inside an active session.
        """
        found = self.document_cls.collection.find_one(
            self.pk,
            projection={"_id": True},
            session=SESSION.get(),
        )
        return found is not None
class PyMongoBuilder(BaseBuilder):
    """Builder that produces pymongo documents and wires io validation."""

    BASE_DOCUMENT_CLS = PyMongoDocument

    def _patch_field(self, field):
        """Normalize ``field.io_validate`` and attach pymongo-specific hooks.

        After the parent patching, ``io_validate`` is always a list. List,
        dict and embedded fields get their recursive validator; reference
        fields get the existence validator and ``PyMongoReference``.
        """
        super()._patch_field(field)

        declared = field.io_validate
        if not declared:
            normalized = []
        elif hasattr(declared, "__iter__"):
            normalized = list(declared)
        else:
            normalized = [declared]
        field.io_validate = normalized

        # A later matching entry overrides an earlier one.
        recursive_validators = (
            (ListField, _list_io_validate),
            (DictField, _dict_io_validate),
            (EmbeddedField, _embedded_document_io_validate),
        )
        for field_type, recursive_validator in recursive_validators:
            if isinstance(field, field_type):
                field.io_validate_recursive = recursive_validator

        if isinstance(field, ReferenceField):
            field.io_validate.append(_reference_io_validate)
            field.reference_cls = PyMongoReference
class PyMongoInstance(Instance):
    """:class:`umongo.instance.Instance` implementation for pymongo"""

    BUILDER_CLS = PyMongoBuilder

    @staticmethod
    def is_compatible_with(db):
        """Return True if ``db`` is a :class:`pymongo.database.Database`."""
        return isinstance(db, Database)

    @contextmanager
    def session(self):
        """Open a client session and expose it through ``SESSION``.

        Yields the session started with ``self.db.client.start_session()``.
        While the ``with`` block runs, ``SESSION`` holds that session; its
        previous value is restored on exit, even if the block raises.
        """
        with self.db.client.start_session() as session:
            # Set before entering ``try`` so ``token`` is always bound when
            # the ``finally`` clause runs.
            token = SESSION.set(session)
            try:
                yield session
            finally:
                SESSION.reset(token)
class PyMongoMigrationInstance(PyMongoInstance):
    """PyMongo instance with migration features"""

    def migrate_2_to_3(self):
        """Migrate database from umongo 2 to umongo 3

        - EmbeddedDocument _cls field is only set if child of concrete
          embedded document

        Every stored document of each concrete, non-child document class is
        passed through ``remove_cls_field_from_embedded_docs`` and written
        back with ``replace_one``.

        Raises :class:`umongo.exceptions.UpdateError` if a replacement does
        not match exactly one document.
        """
        # Names of the embedded documents that are neither children nor abstract.
        concrete_roots = []
        for name, embedded in self._embedded_lookup.items():
            if embedded.opts.is_child or embedded.opts.abstract:
                continue
            concrete_roots.append(name)

        for doc_cls in self._doc_lookup.values():
            if doc_cls.opts.abstract or doc_cls.opts.is_child:
                continue
            for stored in doc_cls.collection.find():
                cleaned = remove_cls_field_from_embedded_docs(stored, concrete_roots)
                outcome = doc_cls.collection.replace_one(
                    {"_id": cleaned["_id"]},
                    cleaned,
                )
                if outcome.matched_count != 1:
                    raise UpdateError(outcome)