# Source code for umongo.fields

"""umongo fields"""

import collections
import datetime as dt

import marshmallow as ma

from bson import DBRef, Decimal128, ObjectId
from bson.errors import InvalidId

from . import marshmallow_bonus as ma_bonus_fields
from .abstract import BaseField, I18nErrorDict
from .data_objects import Dict, List, Reference
from .document import DocumentImplementation
from .exceptions import DocumentDefinitionError, NotRegisteredDocumentError
from .i18n import gettext as _
from .template import get_template

# Public API of this module. Names are grouped by kind below and sorted
# (ASCII order) so the exported tuple stays ordered as entries are added.
__all__ = tuple(
    sorted(
        (
            # Scalar fields
            "StringField",
            "UUIDField",
            "NumberField",
            "IntegerField",
            "DecimalField",
            "BooleanField",
            "FloatField",
            "UrlField",
            "EmailField",
            "ConstantField",
            "ObjectIdField",
            # Date / time fields
            "DateTimeField",
            "NaiveDateTimeField",
            "AwareDateTimeField",
            "DateField",
            # Container fields
            "DictField",
            "ListField",
            "EmbeddedField",
            # Reference fields
            "ReferenceField",
            "GenericReferenceField",
            # Short aliases
            "URLField",
            "StrField",
            "BoolField",
            "IntField",
        ),
    ),
)


# Republish supported marshmallow fields


# class RawField(BaseField, ma.fields.Raw):
#     pass


class StringField(BaseField, ma.fields.String):
    """umongo version of :class:`marshmallow.fields.String`."""
class UUIDField(BaseField, ma.fields.UUID):
    """umongo version of :class:`marshmallow.fields.UUID`."""
class NumberField(BaseField, ma.fields.Number):
    """umongo version of :class:`marshmallow.fields.Number`."""
class IntegerField(BaseField, ma.fields.Integer):
    """umongo version of :class:`marshmallow.fields.Integer`."""
class DecimalField(BaseField, ma.fields.Decimal):
    """Decimal field persisted in MongoDB as a BSON ``Decimal128``."""

    def _serialize_to_mongo(self, obj):
        """Wrap the Python value into a ``Decimal128`` for storage."""
        stored = Decimal128(obj)
        return stored

    def _deserialize_from_mongo(self, value):
        """Turn the stored ``Decimal128`` back into a Python decimal."""
        decimal_value = value.to_decimal()
        return decimal_value
class BooleanField(BaseField, ma.fields.Boolean):
    """umongo version of :class:`marshmallow.fields.Boolean`."""
class FloatField(BaseField, ma.fields.Float):
    """umongo version of :class:`marshmallow.fields.Float`."""
def _round_to_millisecond(datetime):
    """Round a datetime to millisecond precision.

    MongoDB stores datetimes with a millisecond precision. For consistency,
    use the same precision in the object representation.

    ``round`` on an int rounds halves to even (e.g. 122_500 µs -> 122_000 µs,
    123_500 µs -> 124_000 µs). The timezone info, if any, is preserved.
    """
    rounded_us = round(datetime.microsecond, -3)
    if rounded_us < 1_000_000:
        return datetime.replace(microsecond=rounded_us)
    # Rounding overflowed into the next second (e.g. 999_600 µs -> 1_000_000 µs),
    # which ``replace`` cannot represent: carry it over explicitly.
    return datetime.replace(microsecond=0) + dt.timedelta(seconds=1)
class DateTimeField(BaseField, ma.fields.DateTime):
    """Datetime field whose values are rounded to millisecond precision."""

    def _deserialize(self, value, attr, data, **kwargs):
        # datetime objects are taken as-is; anything else is parsed by marshmallow
        parsed = (
            value
            if isinstance(value, dt.datetime)
            else super()._deserialize(value, attr, data, **kwargs)
        )
        return _round_to_millisecond(parsed)
class NaiveDateTimeField(BaseField, ma.fields.NaiveDateTime):
    """Naive datetime field whose values are rounded to millisecond precision.

    NOTE(review): datetime instances skip marshmallow's ``_deserialize``
    entirely, so presumably they are not checked for naiveness.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        # datetime objects are taken as-is; anything else is parsed by marshmallow
        parsed = (
            value
            if isinstance(value, dt.datetime)
            else super()._deserialize(value, attr, data, **kwargs)
        )
        return _round_to_millisecond(parsed)
class AwareDateTimeField(BaseField, ma.fields.AwareDateTime):
    """Timezone-aware datetime field.

    Values are rounded to millisecond precision on deserialization. Values
    read from MongoDB are interpreted as UTC and, when ``default_timezone`` is
    set, converted to that timezone.
    """

    def _deserialize(self, value, attr, data, **kwargs):
        # datetime objects are taken as-is; anything else is parsed by marshmallow
        if isinstance(value, dt.datetime):
            ret = value
        else:
            ret = super()._deserialize(value, attr, data, **kwargs)
        return _round_to_millisecond(ret)

    def _deserialize_from_mongo(self, value):
        if value.tzinfo is None:
            # Naive value from the driver: it represents UTC, label it as such.
            value = value.replace(tzinfo=dt.timezone.utc)
        else:
            # Already-aware value (presumably a tz-aware client, possibly set up
            # with a non-UTC zone): convert it instead of relabeling, which
            # would shift the instant.
            value = value.astimezone(dt.timezone.utc)
        if self.default_timezone is not None:
            value = value.astimezone(self.default_timezone)
        return value
# class TimeField(BaseField, ma.fields.Time): # pass
class DateField(BaseField, ma.fields.Date):
    """This field converts a date to a datetime to store it as a BSON Date"""

    def _deserialize(self, value, attr, data, **kwargs):
        # ``datetime`` is a subclass of ``date``: check it first and keep only
        # the date part, which is all that ``_serialize_to_mongo`` stores and
        # ``_deserialize_from_mongo`` returns.
        if isinstance(value, dt.datetime):
            return value.date()
        if isinstance(value, dt.date):
            return value
        return super()._deserialize(value, attr, data, **kwargs)

    def _serialize_to_mongo(self, obj):
        # BSON has no date-only type: store midnight of that day
        return dt.datetime(obj.year, obj.month, obj.day)

    def _deserialize_from_mongo(self, value):
        return value.date()
# class TimeDeltaField(BaseField, ma.fields.TimeDelta): # pass
class UrlField(BaseField, ma.fields.Url):
    """umongo version of :class:`marshmallow.fields.Url`."""
class EmailField(BaseField, ma.fields.Email):
    """umongo version of :class:`marshmallow.fields.Email`."""
class ConstantField(BaseField, ma.fields.Constant):
    """umongo version of :class:`marshmallow.fields.Constant`."""
class DictField(BaseField, ma.fields.Dict):
    """Dict field whose content is held in a umongo :class:`Dict` data object.

    When set, ``key_field`` / ``value_field`` convert keys and values on their
    way to and from MongoDB.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def cast_value_or_callable(key_field, value_field, value):
            # Wrap defaults into a Dict data object. Callables are wrapped
            # lazily so each call still builds a fresh object. ``missing`` and
            # ``None`` defaults are kept as-is rather than wrapped.
            if value is ma.missing or value is None:
                return value
            if callable(value):
                return lambda: Dict(key_field, value_field, value())
            return Dict(key_field, value_field, value)

        self.dump_default = cast_value_or_callable(
            self.key_field,
            self.value_field,
            self.dump_default,
        )
        self.load_default = cast_value_or_callable(
            self.key_field,
            self.value_field,
            self.load_default,
        )

    def _deserialize(self, value, attr, data, **kwargs):
        value = super()._deserialize(value, attr, data, **kwargs)
        return Dict(self.key_field, self.value_field, value)

    def _serialize_to_mongo(self, obj):
        # A None dict is not written to MongoDB at all
        if obj is None:
            return ma.missing
        return {
            (self.key_field.serialize_to_mongo(k) if self.key_field else k): (
                self.value_field.serialize_to_mongo(v) if self.value_field else v
            )
            for k, v in obj.items()
        }

    def _deserialize_from_mongo(self, value):
        # Empty/None stored value -> empty Dict data object
        if value:
            return Dict(
                self.key_field,
                self.value_field,
                {
                    (self.key_field.deserialize_from_mongo(k) if self.key_field else k): (
                        self.value_field.deserialize_from_mongo(v)
                        if self.value_field
                        else v
                    )
                    for k, v in value.items()
                },
            )
        return Dict(self.key_field, self.value_field)

    def as_marshmallow_field(self):
        """Return a pure marshmallow ``Dict`` field equivalent to this field.

        Both the key and the value fields are converted, so no umongo field
        instance ends up inside the pure-marshmallow schema.
        """
        field_kwargs = self._extract_marshmallow_field_params()
        key_ma_field = self.key_field
        # Keep a plain marshmallow key field (or None) untouched
        if hasattr(key_ma_field, "as_marshmallow_field"):
            key_ma_field = key_ma_field.as_marshmallow_field()
        if self.value_field:
            inner_ma_field = self.value_field.as_marshmallow_field()
        else:
            inner_ma_field = None
        m_field = ma.fields.Dict(
            key_ma_field,
            inner_ma_field,
            metadata=self.metadata,
            **field_kwargs,
        )
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        """Run the value field's required check on every dict value.

        :raises ma.ValidationError: with errors keyed as ``{key: {"value": ...}}``.
        """
        # A None dict has no values to check
        if value is None or not hasattr(self.value_field, "_required_validate"):
            return
        required_validate = self.value_field._required_validate
        errors = collections.defaultdict(dict)
        for key, val in value.items():
            try:
                required_validate(val)
            except ma.ValidationError as exc:
                errors[key]["value"] = exc.messages
        if errors:
            raise ma.ValidationError(errors)
class ListField(BaseField, ma.fields.List):
    """List field whose content is held in a umongo :class:`List` data object."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def cast_value_or_callable(inner, value):
            # Wrap defaults into a List data object. Callables are wrapped
            # lazily so each call still builds a fresh object. ``missing`` and
            # ``None`` defaults are kept as-is rather than wrapped.
            if value is ma.missing or value is None:
                return value
            if callable(value):
                return lambda: List(inner, value())
            return List(inner, value)

        self.dump_default = cast_value_or_callable(self.inner, self.dump_default)
        self.load_default = cast_value_or_callable(self.inner, self.load_default)

    def _deserialize(self, value, attr, data, **kwargs):
        value = super()._deserialize(value, attr, data, **kwargs)
        return List(self.inner, value)

    def _serialize_to_mongo(self, obj):
        # A None list is not written to MongoDB at all
        if obj is None:
            return ma.missing
        return [self.inner.serialize_to_mongo(each) for each in obj]

    def _deserialize_from_mongo(self, value):
        # Empty/None stored value -> empty List data object
        if value:
            return List(
                self.inner,
                [self.inner.deserialize_from_mongo(each) for each in value],
            )
        return List(self.inner)

    def map_to_field(self, mongo_path, path, func):
        """Apply a function to every field in the schema"""
        # The inner field shares the list's paths (no per-index segment)
        func(mongo_path, path, self.inner)
        if hasattr(self.inner, "map_to_field"):
            self.inner.map_to_field(mongo_path, path, func)

    def as_marshmallow_field(self):
        """Return a pure marshmallow ``List`` field equivalent to this field."""
        field_kwargs = self._extract_marshmallow_field_params()
        inner_ma_field = self.inner.as_marshmallow_field()
        m_field = ma.fields.List(inner_ma_field, metadata=self.metadata, **field_kwargs)
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        """Run the inner field's required check on every item.

        :raises ma.ValidationError: with errors keyed by item index.
        """
        # A None list has no items to check
        if value is None or not hasattr(self.inner, "_required_validate"):
            return
        required_validate = self.inner._required_validate
        errors = {}
        for i, sub_value in enumerate(value):
            try:
                required_validate(sub_value)
            except ma.ValidationError as exc:
                errors[i] = exc.messages
        if errors:
            raise ma.ValidationError(errors)
# Aliases: shorter names for some of the fields above
URLField, StrField, BoolField, IntField = (
    UrlField,
    StringField,
    BooleanField,
    IntegerField,
)
class ObjectIdField(BaseField, ma_bonus_fields.ObjectId):
    """Field holding a BSON ``ObjectId`` (built on ``marshmallow_bonus.ObjectId``)."""
class ReferenceField(BaseField, ma_bonus_fields.Reference):
    """Field referencing a document of a given class through its primary key."""

    def __init__(self, document, *args, reference_cls=Reference, **kwargs):
        """
        :param document: Can be a :class:`umongo.document.DocumentTemplate`,
            another instance's :class:`umongo.document.DocumentImplementation`
            or the document class name.
        :param reference_cls: Class used to build the reference objects
            (defaults to :class:`umongo.data_objects.Reference`).

        .. warning:: The referenced document's _id must be an `ObjectId`.
        """
        super().__init__(*args, **kwargs)
        # TODO : check document_cls is implementation or string
        self.reference_cls = reference_cls
        # Can be the Template, Template's name or another Implementation
        if not isinstance(document, str):
            self.document = get_template(document)
        else:
            self.document = document
        # Resolved lazily by `document_cls`
        self._document_cls = None
        self._document_implementation_cls = DocumentImplementation

    @property
    def document_cls(self):
        """Return the instance's :class:`umongo.document.DocumentImplementation`
        implementing the `document` attribute.
        """
        if not self._document_cls:
            self._document_cls = self.instance.retrieve_document(self.document)
        return self._document_cls

    def _deserialize(self, value, attr, data, **kwargs):
        """Build a reference from a DBRef, a Reference, a created document or a pk.

        :raises ma.ValidationError: if the value targets another collection or
            document class, or references a document not created yet.
        """
        if value is None:
            return None
        if isinstance(value, DBRef):
            if self.document_cls.collection.name != value.collection:
                raise ma.ValidationError(
                    # The placeholder is named: it must be passed by keyword
                    _("DBRef must be on collection `{collection}`.").format(
                        collection=self.document_cls.collection.name,
                    ),
                )
            value = value.id
        elif isinstance(value, Reference):
            if value.document_cls != self.document_cls:
                raise ma.ValidationError(
                    _("`{document}` reference expected.").format(
                        document=self.document_cls.__name__,
                    ),
                )
            # Re-wrap references built with another reference class
            if not isinstance(value, self.reference_cls):
                value = self.reference_cls(value.document_cls, value.pk)
            return value
        elif isinstance(value, self.document_cls):
            if not value.is_created:
                raise ma.ValidationError(
                    _("Cannot reference a document that has not been created yet."),
                )
            value = value.pk
        elif isinstance(value, self._document_implementation_cls):
            # A document, but not of the referenced class
            raise ma.ValidationError(
                _("`{document}` reference expected.").format(
                    document=self.document_cls.__name__,
                ),
            )
        value = super()._deserialize(value, attr, data, **kwargs)
        return self.reference_cls(self.document_cls, value)

    def _serialize_to_mongo(self, obj):
        # Only the referenced document's pk is stored
        return obj.pk

    def _deserialize_from_mongo(self, value):
        return self.reference_cls(self.document_cls, value)
class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference):
    """Field referencing a document of any registered class.

    Stored in MongoDB as ``{"_id": pk, "_cls": class name}`` and serialized as
    ``{"id": str(pk), "cls": class name}``.
    """

    def __init__(self, *args, reference_cls=Reference, **kwargs):
        super().__init__(*args, **kwargs)
        self.reference_cls = reference_cls
        self._document_implementation_cls = DocumentImplementation

    def _document_cls(self, class_name):
        """Return the document class registered under ``class_name``.

        :raises ma.ValidationError: if no such document is registered.
        """
        try:
            return self.instance.retrieve_document(class_name)
        except NotRegisteredDocumentError as exc:
            raise ma.ValidationError(
                _("Unknown document `{document}`.").format(document=class_name),
            ) from exc

    def _serialize(self, value, attr, obj, **kwargs):
        # `**kwargs` matches marshmallow's `Field._serialize` signature
        if value is None:
            return None
        return {"id": str(value.pk), "cls": value.document_cls.__name__}

    def _deserialize(self, value, attr, data, **kwargs):
        """Build a reference from a Reference, a created document or an
        ``{"id": ..., "cls": ...}`` dict.

        :raises ma.ValidationError: on any other or malformed input.
        """
        if value is None:
            return None
        if isinstance(value, Reference):
            # Re-wrap references built with another reference class
            if not isinstance(value, self.reference_cls):
                value = self.reference_cls(value.document_cls, value.pk)
            return value
        if isinstance(value, self._document_implementation_cls):
            if not value.is_created:
                raise ma.ValidationError(
                    _("Cannot reference a document that has not been created yet."),
                )
            return self.reference_cls(value.__class__, value.pk)
        if isinstance(value, dict):
            if value.keys() != {"cls", "id"}:
                raise ma.ValidationError(
                    _("Generic reference must have `id` and `cls` fields."),
                )
            try:
                _id = ObjectId(value["id"])
            # ObjectId raises InvalidId for malformed strings and TypeError
            # for unsupported types (not ValueError)
            except (InvalidId, TypeError) as exc:
                raise ma.ValidationError(_("Invalid `id` field.")) from exc
            document_cls = self._document_cls(value["cls"])
            return self.reference_cls(document_cls, _id)
        raise ma.ValidationError(_("Invalid value for generic reference field."))

    def _serialize_to_mongo(self, obj):
        return {"_id": obj.pk, "_cls": obj.document_cls.__name__}

    def _deserialize_from_mongo(self, value):
        document_cls = self._document_cls(value["_cls"])
        return self.reference_cls(document_cls, value["_id"])
class EmbeddedField(BaseField, ma.fields.Nested):
    """Field holding an embedded document."""

    def __init__(self, embedded_document, *args, **kwargs):
        """:param embedded_document: Can be a
        :class:`umongo.embedded_document.EmbeddedDocumentTemplate`,
        another instance's :class:`umongo.embedded_document.EmbeddedDocumentImplementation`
        or the embedded document class name.
        """
        # Don't need to pass `nested` attribute given it is overloaded
        super().__init__(None, *args, **kwargs)
        # Try to retrieve the template if possible for consistency
        if not isinstance(embedded_document, str):
            self.embedded_document = get_template(embedded_document)
        else:
            self.embedded_document = embedded_document
        # Resolved lazily by `embedded_document_cls`
        self._embedded_document_cls = None

    @property
    def nested(self):
        # Overload `nested` attribute to be able to fetch it lazily
        return self.embedded_document_cls.Schema

    @nested.setter
    def nested(self, value):
        # Ignore assignments (e.g. from `ma.fields.Nested.__init__`)
        pass

    @property
    def embedded_document_cls(self):
        """Return the instance's :class:`umongo.embedded_document.EmbeddedDocumentImplementation`
        implementing the `embedded_document` attribute.

        :raises DocumentDefinitionError: if the embedded document is abstract.
        """
        if not self._embedded_document_cls:
            embedded_document_cls = self.instance.retrieve_embedded_document(
                self.embedded_document,
            )
            if embedded_document_cls.opts.abstract:
                raise DocumentDefinitionError(
                    "EmbeddedField doesn't accept abstract embedded document",
                )
            self._embedded_document_cls = embedded_document_cls
        return self._embedded_document_cls

    def _serialize(self, value, attr, obj, **kwargs):
        # `**kwargs` matches marshmallow's `Field._serialize` signature
        if value is None:
            return None
        return value.dump()

    def _deserialize(self, value, attr, data, **kwargs):
        """Build an embedded document from a dict (or pass one through).

        A ``cls`` key selects a registered child class when the embedded
        document has offspring.

        :raises ma.ValidationError: on non-dict input or unknown ``cls``.
        """
        embedded_document_cls = self.embedded_document_cls
        if isinstance(value, embedded_document_cls):
            return value
        if not isinstance(value, dict):
            raise ma.ValidationError({"_schema": ["Invalid input type."]})
        # Handle inheritance deserialization here using `cls` field as hint
        if embedded_document_cls.opts.offspring and "cls" in value:
            # Work on a copy so the caller's input keeps its "cls" key
            value = dict(value)
            to_use_cls_name = value.pop("cls")
            if not any(
                o.__name__ == to_use_cls_name
                for o in embedded_document_cls.opts.offspring
            ):
                raise ma.ValidationError(
                    _("Unknown document `{document}`.").format(
                        document=to_use_cls_name,
                    ),
                )
            try:
                to_use_cls = (
                    embedded_document_cls.opts.instance.retrieve_embedded_document(
                        to_use_cls_name,
                    )
                )
            except NotRegisteredDocumentError as exc:
                raise ma.ValidationError(str(exc)) from exc
            return to_use_cls(**value)
        return embedded_document_cls(**value)

    def _serialize_to_mongo(self, obj):
        return obj.to_mongo()

    def _deserialize_from_mongo(self, value):
        return self.embedded_document_cls.build_from_mongo(value)

    def _validate_missing(self, value):
        # Overload default to handle recursive check
        super()._validate_missing(value)
        errors = {}
        if value is ma.missing:

            def get_sub_value(_):
                return ma.missing

        elif isinstance(value, dict):
            # value is a dict for deserialization
            def get_sub_value(key):
                return value.get(key, ma.missing)

        elif isinstance(value, self.embedded_document_cls):
            # value is a valid EmbeddedDocument
            def get_sub_value(key):
                return value._data.get(key)

        else:
            # value is invalid, just return and let `_deserialize`
            # raises an error about this
            return
        for name, field in self.embedded_document_cls.schema.fields.items():
            sub_value = get_sub_value(name)
            # `_validate_missing` doesn't check for required fields here, so we
            # can safely skip missing values
            if sub_value is ma.missing:
                continue
            try:
                field._validate_missing(sub_value)
            except ma.ValidationError as exc:
                errors[name] = exc.messages
        if errors:
            raise ma.ValidationError(errors)

    def map_to_field(self, mongo_path, path, func):
        """Apply a function to every field in the schema"""
        for name, field in self.embedded_document_cls.schema.fields.items():
            cur_path = f"{path}.{name}"
            # The MongoDB path uses the field's `attribute` when it is set
            cur_mongo_path = f"{mongo_path}.{field.attribute or name}"
            func(cur_mongo_path, cur_path, field)
            if hasattr(field, "map_to_field"):
                field.map_to_field(cur_mongo_path, cur_path, func)

    def as_marshmallow_field(self):
        # Overwrite default `as_marshmallow_field` to handle nesting
        field_kwargs = self._extract_marshmallow_field_params()
        nested_ma_schema = self.embedded_document_cls.schema.as_marshmallow_schema()
        m_field = ma.fields.Nested(
            nested_ma_schema,
            metadata=self.metadata,
            **field_kwargs,
        )
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        # A None embedded document has nothing to check
        if value is None:
            return
        value.required_validate()