"""umongo fields"""
import collections
import datetime as dt
from bson import DBRef, ObjectId, Decimal128
import marshmallow as ma
# from .registerer import retrieve_document
from .document import DocumentImplementation
from .exceptions import NotRegisteredDocumentError, DocumentDefinitionError
from .template import get_template
from .data_objects import Reference, List, Dict
from . import marshmallow_bonus as ma_bonus_fields
from .abstract import BaseField, I18nErrorDict
from .i18n import gettext as _
# Public API of this module. Names kept as comments are marshmallow fields
# that are not republished here.
__all__ = (
    # 'RawField',
    # 'MappingField',
    # 'TupleField',
    'StringField',
    'UUIDField',
    'NumberField',
    'IntegerField',
    'DecimalField',
    'BooleanField',
    'FloatField',
    'DateTimeField',
    'NaiveDateTimeField',
    'AwareDateTimeField',
    # 'TimeField',
    'DateField',
    # 'TimeDeltaField',
    'UrlField',
    'URLField',
    'EmailField',
    'StrField',
    'BoolField',
    'IntField',
    'DictField',
    'ListField',
    'ConstantField',
    # 'PluckField'
    'ObjectIdField',
    'ReferenceField',
    'GenericReferenceField',
    'EmbeddedField',
)
# Republish supported marshmallow fields
# class RawField(BaseField, ma.fields.Raw):
# pass
class StringField(BaseField, ma.fields.String):
    """Republished marshmallow ``String`` field.

    Mixes ``BaseField`` into ``ma.fields.String``; no extra behavior is
    defined here.
    """
class UUIDField(BaseField, ma.fields.UUID):
    """Republished marshmallow ``UUID`` field.

    Mixes ``BaseField`` into ``ma.fields.UUID``; no extra behavior is
    defined here.
    """
class NumberField(BaseField, ma.fields.Number):
    """Republished marshmallow ``Number`` field.

    Mixes ``BaseField`` into ``ma.fields.Number``; no extra behavior is
    defined here.
    """
class IntegerField(BaseField, ma.fields.Integer):
    """Republished marshmallow ``Integer`` field.

    Mixes ``BaseField`` into ``ma.fields.Integer``; no extra behavior is
    defined here.
    """
class DecimalField(BaseField, ma.fields.Decimal):
    """Decimal field whose MongoDB representation is a ``bson.Decimal128``."""

    def _serialize_to_mongo(self, obj):
        """Wrap ``obj`` in a ``Decimal128`` for storage."""
        as_decimal128 = Decimal128(obj)
        return as_decimal128

    def _deserialize_from_mongo(self, value):
        """Return the result of calling ``to_decimal()`` on the stored value."""
        as_decimal = value.to_decimal()
        return as_decimal
class BooleanField(BaseField, ma.fields.Boolean):
    """Republished marshmallow ``Boolean`` field.

    Mixes ``BaseField`` into ``ma.fields.Boolean``; no extra behavior is
    defined here.
    """
class FloatField(BaseField, ma.fields.Float):
    """Republished marshmallow ``Float`` field.

    Mixes ``BaseField`` into ``ma.fields.Float``; no extra behavior is
    defined here.
    """
def _round_to_millisecond(datetime):
"""Round a datetime to millisecond precision
MongoDB stores datetimes with a millisecond precision.
For consistency, use the same precision in the object representation.
"""
microseconds = round(datetime.microsecond, -3)
if microseconds == 1000000:
return datetime.replace(microsecond=0) + dt.timedelta(seconds=1)
return datetime.replace(microsecond=microseconds)
class DateTimeField(BaseField, ma.fields.DateTime):
    """Datetime field whose deserialized values are rounded to the millisecond."""

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` and round the result to the millisecond.

        ``datetime`` instances are taken as they are; any other input is
        first parsed by the parent marshmallow field.
        """
        if not isinstance(value, dt.datetime):
            value = super()._deserialize(value, attr, data, **kwargs)
        return _round_to_millisecond(value)
class NaiveDateTimeField(BaseField, ma.fields.NaiveDateTime):
    """Naive datetime field whose deserialized values are rounded to the millisecond."""

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` and round the result to the millisecond.

        ``datetime`` instances are taken as they are; any other input is
        first parsed by the parent marshmallow field.
        """
        if not isinstance(value, dt.datetime):
            value = super()._deserialize(value, attr, data, **kwargs)
        return _round_to_millisecond(value)
class AwareDateTimeField(BaseField, ma.fields.AwareDateTime):
    """Aware datetime field whose deserialized values are rounded to the millisecond."""

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` and round the result to the millisecond.

        ``datetime`` instances are taken as they are; any other input is
        first parsed by the parent marshmallow field.
        """
        if not isinstance(value, dt.datetime):
            value = super()._deserialize(value, attr, data, **kwargs)
        return _round_to_millisecond(value)

    def _deserialize_from_mongo(self, value):
        """Tag the stored datetime as UTC, then convert it to
        ``default_timezone`` when one is set.
        """
        utc_value = value.replace(tzinfo=dt.timezone.utc)
        if self.default_timezone is None:
            return utc_value
        return utc_value.astimezone(self.default_timezone)
# class TimeField(BaseField, ma.fields.Time):
# pass
class DateField(BaseField, ma.fields.Date):
    """This field converts a date to a datetime to store it as a BSON Date"""

    def _deserialize(self, value, attr, data, **kwargs):
        """Return ``value`` unchanged when it is already a ``date``,
        otherwise let the parent marshmallow field parse it.

        Note that ``datetime.datetime`` is a subclass of ``datetime.date``,
        so datetime instances are also returned unchanged.
        """
        if isinstance(value, dt.date):
            return value
        # Forward the extra keyword arguments to the parent, as the other
        # date/time fields of this module do.
        return super()._deserialize(value, attr, data, **kwargs)

    def _serialize_to_mongo(self, obj):
        """Build a ``datetime`` from the year, month and day of ``obj``
        (time part left at its default) for storage.
        """
        return dt.datetime(obj.year, obj.month, obj.day)

    def _deserialize_from_mongo(self, value):
        """Return the ``date()`` part of the stored datetime."""
        return value.date()
# class TimeDeltaField(BaseField, ma.fields.TimeDelta):
# pass
class UrlField(BaseField, ma.fields.Url):
    """Republished marshmallow ``Url`` field.

    Mixes ``BaseField`` into ``ma.fields.Url``; no extra behavior is
    defined here.
    """
class EmailField(BaseField, ma.fields.Email):
    """Republished marshmallow ``Email`` field.

    Mixes ``BaseField`` into ``ma.fields.Email``; no extra behavior is
    defined here.
    """
class ConstantField(BaseField, ma.fields.Constant):
    """Republished marshmallow ``Constant`` field.

    Mixes ``BaseField`` into ``ma.fields.Constant``; no extra behavior is
    defined here.
    """
class DictField(BaseField, ma.fields.Dict):
    """Dict field whose Python-side values are ``Dict`` data objects."""

    def __init__(self, *args, **kwargs):
        """Initialize the field, then wrap ``default`` and ``missing`` into
        ``Dict`` data objects (or into callables producing them).
        """
        super().__init__(*args, **kwargs)
        key_field, value_field = self.key_field, self.value_field

        def as_dict_data(raw):
            # `ma.missing` is a sentinel and must be left untouched
            if raw is ma.missing:
                return raw
            # Callables stay lazy: wrap whatever they produce when called
            if callable(raw):
                return lambda: Dict(key_field, value_field, raw())
            return Dict(key_field, value_field, raw)

        self.default = as_dict_data(self.default)
        self.missing = as_dict_data(self.missing)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize through the parent field and wrap the result in a ``Dict``."""
        plain = super()._deserialize(value, attr, data, **kwargs)
        return Dict(self.key_field, self.value_field, plain)

    def _serialize_to_mongo(self, obj):
        """Return a plain dict with keys and values serialized for MongoDB.

        ``None`` yields ``ma.missing``. Keys (resp. values) are passed through
        untouched when no key (resp. value) field is set.
        """
        if obj is None:
            return ma.missing
        key_field, value_field = self.key_field, self.value_field
        mongo_dict = {}
        for key, val in obj.items():
            mongo_key = key_field.serialize_to_mongo(key) if key_field else key
            mongo_val = value_field.serialize_to_mongo(val) if value_field else val
            mongo_dict[mongo_key] = mongo_val
        return mongo_dict

    def _deserialize_from_mongo(self, value):
        """Build a ``Dict`` from a stored mapping; a falsy ``value`` yields
        a ``Dict`` constructed without initial content.
        """
        key_field, value_field = self.key_field, self.value_field
        if not value:
            return Dict(key_field, value_field)
        loaded = {}
        for key, val in value.items():
            py_key = key_field.deserialize_from_mongo(key) if key_field else key
            py_val = value_field.deserialize_from_mongo(val) if value_field else val
            loaded[py_key] = py_val
        return Dict(key_field, value_field, loaded)

    def as_marshmallow_field(self):
        """Return a ``ma.fields.Dict`` equivalent whose value field has been
        converted with its own ``as_marshmallow_field``.
        """
        field_kwargs = self._extract_marshmallow_field_params()
        inner_ma_field = self.value_field.as_marshmallow_field() if self.value_field else None
        # NOTE(review): `key_field` is handed over as-is whereas `value_field`
        # is converted to a marshmallow field - confirm this is intended.
        m_field = ma.fields.Dict(
            self.key_field, inner_ma_field, metadata=self.metadata, **field_kwargs)
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        """Run the value field's ``_required_validate`` on every value.

        Does nothing when the value field has no such method.

        :raises ma.ValidationError: with messages stored as
            ``{key: {"value": messages}}`` for each failing entry.
        """
        if not hasattr(self.value_field, '_required_validate'):
            return
        check = self.value_field._required_validate
        errors = collections.defaultdict(dict)
        for key, item in value.items():
            try:
                check(item)
            except ma.ValidationError as exc:
                errors[key]["value"] = exc.messages
        if errors:
            raise ma.ValidationError(errors)
class ListField(BaseField, ma.fields.List):
    """List field whose Python-side values are ``List`` data objects."""

    def __init__(self, *args, **kwargs):
        """Initialize the field, then wrap ``default`` and ``missing`` into
        ``List`` data objects (or into callables producing them).
        """
        super().__init__(*args, **kwargs)
        inner = self.inner

        def as_list_data(raw):
            # `ma.missing` is a sentinel and must be left untouched
            if raw is ma.missing:
                return raw
            # Callables stay lazy: wrap whatever they produce when called
            if callable(raw):
                return lambda: List(inner, raw())
            return List(inner, raw)

        self.default = as_list_data(self.default)
        self.missing = as_list_data(self.missing)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize through the parent field and wrap the result in a ``List``."""
        plain = super()._deserialize(value, attr, data, **kwargs)
        return List(self.inner, plain)

    def _serialize_to_mongo(self, obj):
        """Return a plain list with every item serialized by the inner field.

        ``None`` yields ``ma.missing``.
        """
        if obj is None:
            return ma.missing
        to_mongo = self.inner.serialize_to_mongo
        return [to_mongo(item) for item in obj]

    def _deserialize_from_mongo(self, value):
        """Build a ``List`` from a stored sequence; a falsy ``value`` yields
        a ``List`` constructed without initial content.
        """
        if not value:
            return List(self.inner)
        from_mongo = self.inner.deserialize_from_mongo
        return List(self.inner, [from_mongo(item) for item in value])

    def map_to_field(self, mongo_path, path, func):
        """Apply a function to every field in the schema"""
        inner = self.inner
        func(mongo_path, path, inner)
        # Recurse when the inner field is itself a container of fields
        if hasattr(inner, 'map_to_field'):
            inner.map_to_field(mongo_path, path, func)

    def as_marshmallow_field(self):
        """Return a ``ma.fields.List`` equivalent built around the inner
        field's own ``as_marshmallow_field`` result.
        """
        field_kwargs = self._extract_marshmallow_field_params()
        m_field = ma.fields.List(
            self.inner.as_marshmallow_field(), metadata=self.metadata, **field_kwargs)
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        """Run the inner field's ``_required_validate`` on every item.

        Does nothing when the inner field has no such method.

        :raises ma.ValidationError: with messages keyed by item index.
        """
        if not hasattr(self.inner, '_required_validate'):
            return
        check = self.inner._required_validate
        errors = {}
        for index, item in enumerate(value):
            try:
                check(item)
            except ma.ValidationError as exc:
                errors[index] = exc.messages
        if errors:
            raise ma.ValidationError(errors)
# Aliases: alternative names bound to the very same field classes
# (no subclass is created, so e.g. `StrField is StringField`).
URLField = UrlField
StrField = StringField
BoolField = BooleanField
IntField = IntegerField
class ObjectIdField(BaseField, ma_bonus_fields.ObjectId):
    """``ObjectId`` field.

    Mixes ``BaseField`` into ``marshmallow_bonus.ObjectId``; no extra
    behavior is defined here.
    """
class ReferenceField(BaseField, ma_bonus_fields.Reference):
    """Field referencing another document.

    The Python-side value is a ``reference_cls`` object; the value sent to
    MongoDB is that object's ``pk``.
    """

    def __init__(self, document, *args, reference_cls=Reference, **kwargs):
        """
        :param document: Can be a :class:`umongo.embedded_document.DocumentTemplate`,
            another instance's :class:`umongo.embedded_document.DocumentImplementation` or
            the document class name.
        :param reference_cls: class used to build the reference objects
            returned by this field (``Reference`` by default).

        .. warning:: The referenced document's _id must be an `ObjectId`.
        """
        super().__init__(*args, **kwargs)
        # TODO : check document_cls is implementation or string
        self.reference_cls = reference_cls
        # Can be the Template, Template's name or another Implementation
        if not isinstance(document, str):
            self.document = get_template(document)
        else:
            self.document = document
        # Resolved lazily by the `document_cls` property
        self._document_cls = None
        self._document_implementation_cls = DocumentImplementation

    @property
    def document_cls(self):
        """
        Return the instance's :class:`umongo.embedded_document.DocumentImplementation`
        implementing the `document` attribute.
        """
        if not self._document_cls:
            self._document_cls = self.instance.retrieve_document(self.document)
        return self._document_cls

    def _deserialize(self, value, attr, data, **kwargs):
        """Turn ``value`` into a ``reference_cls`` object (``None`` stays ``None``).

        Accepted inputs are a ``DBRef``, a ``Reference``, an instance of the
        referenced document class, or anything the parent field can
        deserialize.

        :raises ma.ValidationError: when a ``DBRef`` targets another
            collection, when a reference or document of another class is
            given, or when the given document has not been created yet.
        """
        if value is None:
            return None
        if isinstance(value, DBRef):
            if self.document_cls.collection.name != value.collection:
                # The placeholder is named, so it must be filled by keyword;
                # a positional argument would raise KeyError instead of the
                # intended validation error.
                raise ma.ValidationError(_("DBRef must be on collection `{collection}`.").format(
                    collection=self.document_cls.collection.name))
            value = value.id
        elif isinstance(value, Reference):
            if value.document_cls != self.document_cls:
                raise ma.ValidationError(_("`{document}` reference expected.").format(
                    document=self.document_cls.__name__))
            # Rebuild the reference when it is not of the configured class
            if not isinstance(value, self.reference_cls):
                value = self.reference_cls(value.document_cls, value.pk)
            return value
        elif isinstance(value, self.document_cls):
            if not value.is_created:
                raise ma.ValidationError(
                    _("Cannot reference a document that has not been created yet."))
            value = value.pk
        elif isinstance(value, self._document_implementation_cls):
            # A document, but not of the referenced class
            raise ma.ValidationError(_("`{document}` reference expected.").format(
                document=self.document_cls.__name__))
        value = super()._deserialize(value, attr, data, **kwargs)
        return self.reference_cls(self.document_cls, value)

    def _serialize_to_mongo(self, obj):
        """Return the ``pk`` of the reference."""
        return obj.pk

    def _deserialize_from_mongo(self, value):
        """Wrap the stored value in a ``reference_cls`` object."""
        return self.reference_cls(self.document_cls, value)
class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference):
    """Field referencing a document of any registered class.

    Serialized form is ``{'id': ..., 'cls': ...}``; the MongoDB form is
    ``{'_id': ..., '_cls': ...}``.
    """

    def __init__(self, *args, reference_cls=Reference, **kwargs):
        """
        :param reference_cls: class used to build the reference objects
            returned by this field (``Reference`` by default).
        """
        super().__init__(*args, **kwargs)
        self.reference_cls = reference_cls
        self._document_implementation_cls = DocumentImplementation

    def _document_cls(self, class_name):
        """Look up the document class registered as ``class_name``.

        :raises ma.ValidationError: when no such document is registered.
        """
        try:
            return self.instance.retrieve_document(class_name)
        except NotRegisteredDocumentError as exc:
            raise ma.ValidationError(_('Unknown document `{document}`.').format(
                document=class_name)) from exc

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``{'id': str(pk), 'cls': class name}`` (``None`` stays ``None``).

        Extra keyword arguments are accepted and ignored.
        """
        if value is None:
            return None
        return {'id': str(value.pk), 'cls': value.document_cls.__name__}

    def _deserialize(self, value, attr, data, **kwargs):
        """Turn ``value`` into a ``reference_cls`` object (``None`` stays ``None``).

        Accepted inputs are a ``Reference``, a created document, or a dict
        with exactly the keys ``id`` and ``cls``.

        :raises ma.ValidationError: for any other input, for a document not
            created yet, for a malformed dict, an invalid ``id`` or an
            unknown ``cls``.
        """
        if value is None:
            return None
        if isinstance(value, Reference):
            # Rebuild the reference when it is not of the configured class
            if not isinstance(value, self.reference_cls):
                value = self.reference_cls(value.document_cls, value.pk)
            return value
        if isinstance(value, self._document_implementation_cls):
            if not value.is_created:
                raise ma.ValidationError(
                    _("Cannot reference a document that has not been created yet."))
            return self.reference_cls(value.__class__, value.pk)
        if isinstance(value, dict):
            # Imported here to keep this fix self-contained in the block
            from bson.errors import InvalidId
            if value.keys() != {'cls', 'id'}:
                raise ma.ValidationError(_("Generic reference must have `id` and `cls` fields."))
            try:
                _id = ObjectId(value['id'])
            except (TypeError, ValueError, InvalidId) as exc:
                # Catching only ValueError is believed to miss both
                # bson.errors.InvalidId (malformed strings) and TypeError
                # (non-string ids), letting them escape as non-validation
                # errors.
                raise ma.ValidationError(_("Invalid `id` field.")) from exc
            document_cls = self._document_cls(value['cls'])
            return self.reference_cls(document_cls, _id)
        raise ma.ValidationError(_("Invalid value for generic reference field."))

    def _serialize_to_mongo(self, obj):
        """Return ``{'_id': pk, '_cls': class name}`` for storage."""
        return {'_id': obj.pk, '_cls': obj.document_cls.__name__}

    def _deserialize_from_mongo(self, value):
        """Build a ``reference_cls`` object from the stored ``_cls`` / ``_id`` pair."""
        document_cls = self._document_cls(value['_cls'])
        return self.reference_cls(document_cls, value['_id'])
class EmbeddedField(BaseField, ma.fields.Nested):
    """Field holding an embedded document."""

    def __init__(self, embedded_document, *args, **kwargs):
        """
        :param embedded_document: Can be a
            :class:`umongo.embedded_document.EmbeddedDocumentTemplate`,
            another instance's :class:`umongo.embedded_document.EmbeddedDocumentImplementation`
            or the embedded document class name.
        """
        # Don't need to pass `nested` attribute given it is overloaded
        super().__init__(None, *args, **kwargs)
        # Try to retrieve the template if possible for consistency
        if not isinstance(embedded_document, str):
            self.embedded_document = get_template(embedded_document)
        else:
            self.embedded_document = embedded_document
        # Resolved lazily by the `embedded_document_cls` property
        self._embedded_document_cls = None

    @property
    def nested(self):
        # Overload `nested` attribute to be able to fetch it lazily
        return self.embedded_document_cls.Schema

    @nested.setter
    def nested(self, value):
        # Assignments are deliberately ignored so that the lazy getter above
        # stays the only source of the `nested` value.
        pass

    @property
    def embedded_document_cls(self):
        """
        Return the instance's :class:`umongo.embedded_document.EmbeddedDocumentImplementation`
        implementing the `embedded_document` attribute.

        :raises DocumentDefinitionError: when the resolved embedded document
            is abstract.
        """
        if not self._embedded_document_cls:
            embedded_document_cls = self.instance.retrieve_embedded_document(self.embedded_document)
            if embedded_document_cls.opts.abstract:
                raise DocumentDefinitionError(
                    "EmbeddedField doesn't accept abstract embedded document"
                )
            self._embedded_document_cls = embedded_document_cls
        return self._embedded_document_cls

    def _serialize(self, value, attr, obj, **kwargs):
        """Return ``value.dump()`` (``None`` stays ``None``).

        Extra keyword arguments are accepted and ignored.
        """
        if value is None:
            return None
        return value.dump()

    def _deserialize(self, value, attr, data, **kwargs):
        """Turn ``value`` into an embedded document.

        An instance of the embedded document class is returned as-is; a dict
        is used as keyword arguments to build one. When the class has
        offspring, a ``cls`` entry in the dict selects which of them to build.

        :raises ma.ValidationError: when ``value`` is neither, or when ``cls``
            does not name a known offspring.
        """
        embedded_document_cls = self.embedded_document_cls
        if isinstance(value, embedded_document_cls):
            return value
        if not isinstance(value, dict):
            raise ma.ValidationError({'_schema': ['Invalid input type.']})
        # Handle inheritance deserialization here using `cls` field as hint
        offspring = embedded_document_cls.opts.offspring
        if offspring and 'cls' in value:
            # Work on a shallow copy: popping `cls` from the caller's own
            # dict would silently alter the input data.
            value = dict(value)
            to_use_cls_name = value.pop('cls')
            if not any(o.__name__ == to_use_cls_name for o in offspring):
                raise ma.ValidationError(_('Unknown document `{document}`.').format(
                    document=to_use_cls_name))
            try:
                to_use_cls = embedded_document_cls.opts.instance.retrieve_embedded_document(
                    to_use_cls_name)
            except NotRegisteredDocumentError as exc:
                raise ma.ValidationError(str(exc)) from exc
            return to_use_cls(**value)
        return embedded_document_cls(**value)

    def _serialize_to_mongo(self, obj):
        """Return ``obj.to_mongo()``."""
        return obj.to_mongo()

    def _deserialize_from_mongo(self, value):
        """Build the embedded document with its class' ``build_from_mongo``."""
        return self.embedded_document_cls.build_from_mongo(value)

    def _validate_missing(self, value):
        """Run the parent check, then ``_validate_missing`` of every sub-field.

        :raises ma.ValidationError: with messages keyed by sub-field name.
        """
        # Overload default to handle recursive check
        super()._validate_missing(value)
        errors = {}
        if value is ma.missing:
            def get_sub_value(_):
                return ma.missing
        elif isinstance(value, dict):
            # value is a dict for deserialization
            def get_sub_value(key):
                return value.get(key, ma.missing)
        elif isinstance(value, self.embedded_document_cls):
            # value is a valid EmbeddedDocument
            def get_sub_value(key):
                return value._data.get(key)
        else:
            # value is invalid, just return and let `_deserialize`
            # raises an error about this
            return
        for name, field in self.embedded_document_cls.schema.fields.items():
            sub_value = get_sub_value(name)
            # `_validate_missing` doesn't check for required fields here, so we
            # can safely skip missing values
            if sub_value is ma.missing:
                continue
            try:
                field._validate_missing(sub_value)
            except ma.ValidationError as exc:
                errors[name] = exc.messages
        if errors:
            raise ma.ValidationError(errors)

    def map_to_field(self, mongo_path, path, func):
        """Apply a function to every field in the schema"""
        for name, field in self.embedded_document_cls.schema.fields.items():
            cur_path = '%s.%s' % (path, name)
            # The mongo path uses the field's `attribute` when it is set
            cur_mongo_path = '%s.%s' % (mongo_path, field.attribute or name)
            func(cur_mongo_path, cur_path, field)
            if hasattr(field, 'map_to_field'):
                field.map_to_field(cur_mongo_path, cur_path, func)

    def as_marshmallow_field(self):
        """Return a ``ma.fields.Nested`` built on the embedded document
        schema's ``as_marshmallow_schema()`` result.
        """
        # Overwrite default `as_marshmallow_field` to handle nesting
        field_kwargs = self._extract_marshmallow_field_params()
        nested_ma_schema = self.embedded_document_cls.schema.as_marshmallow_schema()
        m_field = ma.fields.Nested(nested_ma_schema, metadata=self.metadata, **field_kwargs)
        m_field.error_messages = I18nErrorDict(m_field.error_messages)
        return m_field

    def _required_validate(self, value):
        """Delegate to the embedded document's ``required_validate``."""
        value.required_validate()