import collections
from contextvars import ContextVar
from contextlib import contextmanager
from pymongo.database import Database
from pymongo.cursor import Cursor
from pymongo.errors import DuplicateKeyError
import marshmallow as ma
from ..builder import BaseBuilder
from ..instance import Instance
from ..document import DocumentImplementation
from ..data_objects import Reference
from ..exceptions import NotCreatedError, UpdateError, DeleteError, NoneReferenceError
from ..fields import ReferenceField, ListField, DictField, EmbeddedField
from ..query_mapper import map_query
from .tools import cook_find_filter, cook_find_projection, remove_cls_field_from_embedded_docs
# Client session forwarded to pymongo by the document operations below.
# ``None`` (the default) means "no session"; PyMongoInstance.session() sets it
# for the duration of its ``with`` block.
SESSION: ContextVar = ContextVar("session", default=None)
# pymongo's Cursor defines a __del__ method, so mongomock's wrapped cursor must
# not inherit from Cursor, otherwise garbage collection crashes. The proxy
# logic therefore lives in this Cursor-free base class.
class BaseWrappedCursor:
    """Proxy around a raw cursor that turns each result into a document.

    Attribute reads not found on the proxy and all attribute writes are
    forwarded to ``raw_cursor``. Results are built with
    ``document_cls.build_from_mongo(elem, use_cls=True)``.
    """

    __slots__ = ('raw_cursor', 'document_cls')

    def __init__(self, document_cls, cursor, *args, **kwargs):
        # Such a cunning plan my lord !
        # Subclasses may inherit from Cursor but don't call its __init__
        # because we act as a proxy to the underlying raw_cursor.
        # ``__setattr__`` forwards writes to the raw cursor, so the slots are
        # filled through their descriptors. Use this class's own descriptors
        # (not a subclass's) so every subclass works without extra coupling.
        BaseWrappedCursor.raw_cursor.__set__(self, cursor)
        BaseWrappedCursor.document_cls.__set__(self, document_cls)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If one of our own slots is
        # unset (instance created without __init__), reading self.raw_cursor
        # below would re-enter __getattr__ endlessly; fail cleanly instead.
        if name in BaseWrappedCursor.__slots__:
            raise AttributeError(name)
        return getattr(self.raw_cursor, name)

    def __setattr__(self, name, value):
        # Every attribute write goes to the raw cursor
        return setattr(self.raw_cursor, name, value)

    def __getitem__(self, index):
        """Index or slice the raw cursor.

        A slice returns a generator of documents (not a cursor); any other
        index returns a single document.
        """
        if isinstance(index, slice):
            elems = self.raw_cursor[index]
            return (self.document_cls.build_from_mongo(elem, use_cls=True)
                    for elem in elems)
        elem = self.raw_cursor[index]
        return self.document_cls.build_from_mongo(elem, use_cls=True)

    def __next__(self):
        elem = next(self.raw_cursor)
        return self.document_cls.build_from_mongo(elem, use_cls=True)

    def __iter__(self):
        for elem in self.raw_cursor:
            yield self.document_cls.build_from_mongo(elem, use_cls=True)
class WrappedCursor(BaseWrappedCursor, Cursor):
    """Document-building cursor proxy that is also a pymongo ``Cursor``.

    BaseWrappedCursor comes first in the MRO, so its proxy methods take
    precedence over Cursor's. No extra slots are declared.
    """
    __slots__ = ()
class PyMongoDocument(DocumentImplementation):
    """Document implementation backed by a pymongo collection.

    The database calls made by this class pass the session held in the
    module-level ``SESSION`` context variable (``None`` when unset).
    """

    __slots__ = ()

    cursor_cls = WrappedCursor  # Easier to customize this for mongomock this way

    opts = DocumentImplementation.opts

    def reload(self):
        """
        Retrieve and replace document's data by the ones in database.

        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        doesn't exist in database.
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        ret = self.collection.find_one(self.pk, session=SESSION.get())
        if ret is None:
            raise NotCreatedError("Document doesn't exists in database")
        # Start from a fresh proxy so no previous (possibly modified) value survives
        self._data = self.DataProxy()
        self._data.from_mongo(ret)

    def commit(self, io_validate_all=False, conditions=None, replace=False):
        """
        Commit the document in database.

        If the document doesn't already exist it will be inserted, otherwise
        it will be updated.

        :param io_validate_all: Validate all field instead of only changed ones.
        :param conditions: Only perform commit if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.UpdateError` if the
            conditions are not satisfied. The given dict is copied, never
            modified.
        :param replace: Replace the document rather than update.
        :return: A :class:`pymongo.results.UpdateResult` or
            :class:`pymongo.results.InsertOneResult` depending of the
            operation, or None if the document exists but nothing changed
            (and ``replace`` is False).

        Raises :class:`umongo.exceptions.NotCreatedError` if ``conditions``
        is used on a document not yet created.
        Raises :class:`marshmallow.ValidationError` when a unique index on
        fields known to the schema is violated; any other duplicate key
        error propagates unchanged.
        """
        try:
            if self.is_created:
                if self.is_modified() or replace:
                    # Copy so the caller's ``conditions`` dict is left untouched
                    query = dict(conditions) if conditions else {}
                    query['_id'] = self.pk
                    # pre_update can provide additional query filter and/or
                    # modify the fields' values
                    additional_filter = self.pre_update()
                    if additional_filter:
                        query.update(map_query(additional_filter, self.schema.fields))
                    self.required_validate()
                    self.io_validate(validate_all=io_validate_all)
                    if replace:
                        payload = self._data.to_mongo(update=False)
                        ret = self.collection.replace_one(
                            query, payload, session=SESSION.get())
                    else:
                        payload = self._data.to_mongo(update=True)
                        ret = self.collection.update_one(
                            query, payload, session=SESSION.get())
                    if ret.matched_count != 1:
                        raise UpdateError(ret)
                    self.post_update(ret)
                else:
                    ret = None
            elif conditions:
                raise NotCreatedError(
                    'Document must already exist in database to use `conditions`.'
                )
            else:
                self.pre_insert()
                self.required_validate()
                self.io_validate(validate_all=io_validate_all)
                payload = self._data.to_mongo(update=False)
                ret = self.collection.insert_one(payload, session=SESSION.get())
                # TODO: check ret ?
                self._data.set(self.pk_field, ret.inserted_id)
                self.is_created = True
                self.post_insert(ret)
        except DuplicateKeyError as exc:
            # ``details`` may be None or lack ``keyPattern`` (presumably
            # depending on server/mock version); without it the error cannot
            # be mapped to fields, so keep the original error.
            key_pattern = (exc.details or {}).get('keyPattern')
            if not key_pattern:
                raise
            # Sort value to make testing easier for compound indexes
            keys = sorted(key_pattern.keys())
            if any(k not in self.schema.fields for k in keys):
                # A key in the index is unknown from umongo
                raise
            fields = [self.schema.fields[k] for k in keys]
            if len(keys) == 1:
                msg = fields[0].error_messages['unique']
                raise ma.ValidationError({keys[0]: msg})
            raise ma.ValidationError({
                k: f.error_messages['unique_compound'].format(fields=keys)
                for k, f in zip(keys, fields)
            })
        self._data.clear_modified()
        return ret

    def delete(self, conditions=None):
        """
        Remove the document from database.

        :param conditions: Only perform delete if matching record in db
            satisfies condition(s) (e.g. version number).
            Raises :class:`umongo.exceptions.DeleteError` if the
            conditions are not satisfied. The given dict is copied, never
            modified.
        Raises :class:`umongo.exceptions.NotCreatedError` if the document
        is not created (i.e. ``doc.is_created`` is False)
        Raises :class:`umongo.exceptions.DeleteError` if the document
        doesn't exist in database.

        :return: A :class:`pymongo.results.DeleteResult`
        """
        if not self.is_created:
            raise NotCreatedError("Document doesn't exists in database")
        # Copy so the caller's ``conditions`` dict is left untouched
        query = dict(conditions) if conditions else {}
        query['_id'] = self.pk
        # pre_delete can provide additional query filter
        additional_filter = self.pre_delete()
        if additional_filter:
            query.update(map_query(additional_filter, self.schema.fields))
        ret = self.collection.delete_one(query, session=SESSION.get())
        if ret.deleted_count != 1:
            raise DeleteError(ret)
        self.is_created = False
        self.post_delete(ret)
        return ret

    def io_validate(self, validate_all=False):
        """
        Run the io_validators of the document's fields.

        :param validate_all: If False only run the io_validators of the
            fields that have been modified.
        """
        if validate_all:
            _io_validate_data_proxy(self.schema, self._data)
        else:
            _io_validate_data_proxy(
                self.schema, self._data, partial=self._data.get_modified_fields())

    @classmethod
    def find_one(cls, filter=None, projection=None, *args, **kwargs):
        """
        Find a single document in database.

        ``filter`` goes through ``cook_find_filter`` and a truthy
        ``projection`` through ``cook_find_projection`` before reaching
        pymongo. Returns a document instance, or None if nothing matches.
        """
        filter = cook_find_filter(cls, filter)
        if projection:
            projection = cook_find_projection(cls, projection)
        ret = cls.collection.find_one(filter, projection=projection,
                                      session=SESSION.get(), *args, **kwargs)
        if ret is not None:
            ret = cls.build_from_mongo(ret, use_cls=True)
        return ret

    @classmethod
    def find(cls, filter=None, *args, **kwargs):
        """
        Find documents in database.

        Returns a cursor (``cls.cursor_cls``) that provides Documents.
        """
        filter = cook_find_filter(cls, filter)
        raw_cursor = cls.collection.find(filter, session=SESSION.get(), *args, **kwargs)
        return cls.cursor_cls(cls, raw_cursor)

    @classmethod
    def count_documents(cls, filter=None, **kwargs):
        """
        Get the number of documents in this collection.

        Unlike pymongo's collection.count_documents, filter is optional and
        defaults to an empty filter.
        """
        filter = cook_find_filter(cls, filter or {})
        return cls.collection.count_documents(filter, session=SESSION.get(), **kwargs)

    @classmethod
    def ensure_indexes(cls):
        """
        Create the Document's indexes in database, if it declares any.
        """
        if cls.indexes:
            cls.collection.create_indexes(cls.indexes, session=SESSION.get())
def _run_validators(validators, field, value):
    """Run a single validator, or several and merge their errors into one.

    A non-iterable ``validators`` is called directly as
    ``validators(field, value)`` and its errors propagate untouched.
    Otherwise every validator runs, and the messages of each
    ``ma.ValidationError`` raised are concatenated into a single
    ``ma.ValidationError``.
    """
    if hasattr(validators, '__iter__'):
        collected = []
        for validate in validators:
            try:
                validate(field, value)
            except ma.ValidationError as err:
                collected.extend(err.messages)
        if collected:
            raise ma.ValidationError(collected)
        return
    validators(field, value)
def _io_validate_data_proxy(schema, data_proxy, partial=None):
    """Run io validation on the fields of ``schema`` held by ``data_proxy``.

    :param partial: names of the fields to validate. When falsy (None or
        empty), every field is validated.
        NOTE(review): an empty collection therefore validates all fields;
        confirm this is intended.

    Fields whose value is ``ma.missing`` are skipped. Each field runs its
    ``io_validate_recursive`` hook (if any), then its ``io_validate``
    validators (if any). Errors are gathered per field name and raised
    together as one ``ma.ValidationError``.
    """
    collected = {}
    for name, field in schema.fields.items():
        if partial and name not in partial:
            continue
        value = data_proxy.get(name)
        if value is ma.missing:
            continue
        try:
            recursive = field.io_validate_recursive
            if recursive:
                recursive(field, value)
            validators = field.io_validate
            if validators:
                _run_validators(validators, field, value)
        except ma.ValidationError as err:
            collected[name] = err.messages
    if collected:
        raise ma.ValidationError(collected)
def _reference_io_validate(field, value):
    """io validator: a non-None reference must point to an existing document.

    Raises ``ma.ValidationError`` with the reference's ``not_found`` message
    (formatted with the target document class name) when ``value.exists`` is
    false.
    """
    if value is not None and not value.exists:
        template = value.error_messages['not_found']
        raise ma.ValidationError(
            template.format(document=value.document_cls.__name__))
def _list_io_validate(field, value):
    """Run the inner field's io validators on each element of a list value.

    Empty values, and inner fields without io validators, are skipped.
    Errors are keyed by element index and raised as one
    ``ma.ValidationError``.
    """
    if not value:
        return
    inner = field.inner
    validators = inner.io_validate
    if not validators:
        return
    failures = {}
    for position, item in enumerate(value):
        try:
            _run_validators(validators, inner, item)
        except ma.ValidationError as err:
            failures[position] = err.messages
    if failures:
        raise ma.ValidationError(failures)
def _dict_io_validate(field, value):
    """Run the value field's io validators on each value of a dict.

    Nothing is done for an empty value, a field without ``value_field``, or a
    value field without io validators. Errors are reported as
    ``{key: {"value": messages}}`` in one ``ma.ValidationError``.
    """
    if not value:
        return
    value_field = field.value_field
    if not value_field:
        return
    validators = value_field.io_validate
    if not validators:
        return
    failures = collections.defaultdict(dict)
    for key, item in value.items():
        try:
            _run_validators(validators, value_field, item)
        except ma.ValidationError as err:
            failures[key]["value"] = err.messages
    if failures:
        raise ma.ValidationError(failures)
def _embedded_document_io_validate(field, value):
    """Recursive io validator for an embedded document: validate all its fields."""
    if value:
        _io_validate_data_proxy(value.schema, value._data)
class PyMongoReference(Reference):
    """Reference whose target document is loaded through pymongo."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Target document cache, filled lazily by ``fetch``
        self._document = None

    def fetch(self, no_data=False, force_reload=False, projection=None):
        """Return the referenced document, querying the database if needed.

        :param no_data: Accepted but not used by this implementation.
        :param force_reload: Query the database even if a document is cached.
        :param projection: Projection passed to ``find_one`` when a query is
            issued. NOTE: the loaded document is cached as-is, so later calls
            without ``force_reload`` return it even if it was partially loaded.
        :raises NoneReferenceError: if the reference's pk is None.
        :raises marshmallow.ValidationError: if no document matches the pk.
        """
        if self._document is None or force_reload:
            if self.pk is None:
                raise NoneReferenceError('Cannot retrieve a None Reference')
            self._document = self.document_cls.find_one(self.pk, projection=projection)
            if self._document is None:
                raise ma.ValidationError(self.error_messages['not_found'].format(
                    document=self.document_cls.__name__))
        return self._document

    @property
    def exists(self):
        """Whether the referenced document is present in database.

        A reference with a None pk never exists. Passing ``find_one(None)``
        would match any document, so that case is answered without a query.
        The current ``SESSION`` is passed, like every other query in this
        module, so the lookup runs in the same session as the caller.
        """
        if self.pk is None:
            return False
        found = self.document_cls.collection.find_one(
            self.pk, projection={'_id': True}, session=SESSION.get())
        return found is not None
class PyMongoBuilder(BaseBuilder):
    """Builder wiring pymongo-specific io validation into document fields."""

    BASE_DOCUMENT_CLS = PyMongoDocument

    def _patch_field(self, field):
        """Normalize ``field.io_validate`` to a list and attach pymongo hooks.

        - ``io_validate`` becomes a list: empty when unset/falsy, a list copy
          of an iterable, or a single validator wrapped in a list.
        - List, dict and embedded fields get their recursive io validator.
        - Reference fields get an existence check appended and use
          :class:`PyMongoReference` as reference class.
        """
        super()._patch_field(field)
        current = field.io_validate
        if not current:
            field.io_validate = []
        elif hasattr(current, '__iter__'):
            field.io_validate = list(current)
        else:
            field.io_validate = [current]
        recursive_hooks = (
            (ListField, _list_io_validate),
            (DictField, _dict_io_validate),
            (EmbeddedField, _embedded_document_io_validate),
        )
        for field_type, recursive_validator in recursive_hooks:
            if isinstance(field, field_type):
                field.io_validate_recursive = recursive_validator
        if isinstance(field, ReferenceField):
            field.io_validate.append(_reference_io_validate)
            field.reference_cls = PyMongoReference
class PyMongoInstance(Instance):
    """
    :class:`umongo.instance.Instance` implementation for pymongo
    """
    BUILDER_CLS = PyMongoBuilder

    @staticmethod
    def is_compatible_with(db):
        """Return True if ``db`` is a :class:`pymongo.database.Database`."""
        return isinstance(db, Database)

    @contextmanager
    def session(self):
        """Run the enclosed block inside a pymongo client session.

        The session is stored in the ``SESSION`` context variable while the
        block runs (the document methods of this module pass it to pymongo).
        The previous value is restored on exit, even if the block raises.

        Yields the session object returned by ``client.start_session()``.
        """
        with self.db.client.start_session() as session:
            # Set outside the try so ``finally`` never sees an unbound token
            token = SESSION.set(session)
            try:
                yield session
            finally:
                SESSION.reset(token)
class PyMongoMigrationInstance(PyMongoInstance):
    """PyMongo instance with migration features"""

    def migrate_2_to_3(self):
        """Migrate database from umongo 2 to umongo 3

        - EmbeddedDocument _cls field is only set if child of concrete embedded document

        Every stored document of each non-abstract, non-child document class
        is rewritten with ``replace_one``. Raises
        :class:`umongo.exceptions.UpdateError` if a replacement matches no
        document.
        """
        # Concrete embedded documents that are not children: their stored
        # instances should no longer carry ``_cls``.
        concrete_not_children = [
            name
            for name, embedded in self._embedded_lookup.items()
            if not (embedded.opts.is_child or embedded.opts.abstract)
        ]
        for doc_cls in self._doc_lookup.values():
            # Presumably abstract documents have no collection and children
            # share their parent's collection -- TODO confirm.
            if doc_cls.opts.abstract or doc_cls.opts.is_child:
                continue
            for raw in doc_cls.collection.find():
                cleaned = remove_cls_field_from_embedded_docs(raw, concrete_not_children)
                result = doc_cls.collection.replace_one({"_id": cleaned["_id"]}, cleaned)
                if result.matched_count != 1:
                    raise UpdateError(result)