# -*- coding: utf-8 -*-
import marshmallow as ma
from marshmallow.compat import with_metaclass, iteritems
from .convert import ModelConverter
from .fields import get_primary_keys
class TableSchemaOpts(ma.SchemaOpts):
    """Options class for `TableSchema`.

    In addition to the standard marshmallow ``class Meta`` options, the
    following attributes are read:

    - ``table``: The SQLAlchemy table to generate the `Schema` fields from.
      When it is left unset, no fields are generated from a table.
    - ``model_converter``: `ModelConverter` class used to convert the
      SQLAlchemy table to marshmallow fields; defaults to `ModelConverter`.
    - ``include_fk``: Whether to include foreign key fields; defaults to
      `False`.
    """

    def __init__(self, meta, *args, **kwargs):
        super(TableSchemaOpts, self).__init__(meta, *args, **kwargs)
        # Each option is optional on ``class Meta``; absent ones take the
        # defaults given here.
        self.table = getattr(meta, 'table', None)
        self.model_converter = getattr(meta, 'model_converter', ModelConverter)
        self.include_fk = getattr(meta, 'include_fk', False)
class ModelSchemaOpts(ma.SchemaOpts):
    """Options class for `ModelSchema`.

    In addition to the standard marshmallow ``class Meta`` options, the
    following attributes are read:

    - ``model``: The SQLAlchemy model to generate the `Schema` fields from.
      When it is left unset, no fields are generated from a model.
    - ``sqla_session``: SQLAlchemy session used for deserialization. Optional;
      a session can also be passed to the Schema's constructor or to `load`.
    - ``model_converter``: `ModelConverter` class used to convert the
      SQLAlchemy model to marshmallow fields; defaults to `ModelConverter`.
    - ``include_fk``: Whether to include foreign key fields; defaults to
      `False`.
    """

    def __init__(self, meta, *args, **kwargs):
        super(ModelSchemaOpts, self).__init__(meta, *args, **kwargs)
        # Each option is optional on ``class Meta``; absent ones take the
        # defaults given here.
        self.model = getattr(meta, 'model', None)
        self.sqla_session = getattr(meta, 'sqla_session', None)
        self.model_converter = getattr(meta, 'model_converter', ModelConverter)
        self.include_fk = getattr(meta, 'include_fk', False)
class SchemaMeta(ma.schema.SchemaMeta):
    """Metaclass shared by `ModelSchema` and `TableSchema`.

    Extends marshmallow's metaclass so that, besides the fields declared on
    the class body (or inherited), fields converted from SQLAlchemy by the
    ``Meta.model_converter`` class are included. Subclasses select the source
    of the converted fields by overriding `get_fields`.
    """

    # override SchemaMeta
    @classmethod
    def get_declared_fields(mcs, klass, cls_fields, inherited_fields, dict_cls):
        """Updates declared fields with fields converted from the SQLAlchemy
        model or table configured in the schema's ``class Meta``.

        Fields declared explicitly on the schema (or inherited) take
        precedence over converted fields of the same name.
        """
        opts = klass.opts
        converter = opts.model_converter(schema_cls=klass)
        base_fields = super(SchemaMeta, mcs).get_declared_fields(
            klass, cls_fields, inherited_fields, dict_cls
        )
        declared_fields = mcs.get_fields(converter, opts, base_fields, dict_cls)
        # Explicit/inherited declarations overwrite auto-generated fields.
        declared_fields.update(base_fields)
        return declared_fields

    @classmethod
    def get_fields(mcs, converter, opts, base_fields, dict_cls):
        """Return the fields generated from SQLAlchemy for this schema.

        The signature matches the call made in `get_declared_fields`.
        The base implementation generates no fields and returns an empty
        ``dict_cls`` instance, so this metaclass is usable on its own;
        subclasses override it to convert a model or a table.
        """
        return dict_cls()
class TableSchemaMeta(SchemaMeta):
    """Metaclass for `TableSchema`: converts ``Meta.table`` into fields."""

    @classmethod
    def get_fields(mcs, converter, opts, base_fields, dict_cls):
        """Return fields converted from ``opts.table``.

        Returns an empty ``dict_cls`` instance when no table is configured.
        """
        table = opts.table
        if table is None:
            return dict_cls()
        return converter.fields_for_table(
            table,
            fields=opts.fields,
            exclude=opts.exclude,
            include_fk=opts.include_fk,
            base_fields=base_fields,
            dict_cls=dict_cls,
        )
class ModelSchemaMeta(SchemaMeta):
    """Metaclass for `ModelSchema`: converts ``Meta.model`` into fields."""

    @classmethod
    def get_fields(mcs, converter, opts, base_fields, dict_cls):
        """Return fields converted from ``opts.model``.

        Returns an empty ``dict_cls`` instance when no model is configured.
        """
        model = opts.model
        if model is None:
            return dict_cls()
        return converter.fields_for_model(
            model,
            fields=opts.fields,
            exclude=opts.exclude,
            include_fk=opts.include_fk,
            base_fields=base_fields,
            dict_cls=dict_cls,
        )
class TableSchema(with_metaclass(TableSchemaMeta, ma.Schema)):
    """Base class for SQLAlchemy table-based Schemas.

    Fields are generated from the SQLAlchemy table given as the ``table``
    option of the schema's ``class Meta``.

    Example: ::

        from marshmallow_sqlalchemy import TableSchema
        from mymodels import engine, users

        class UserSchema(TableSchema):
            class Meta:
                table = users

        schema = UserSchema()

        select = users.select().limit(1)
        user = engine.execute(select).fetchone()
        serialized = schema.dump(user).data
    """

    OPTIONS_CLASS = TableSchemaOpts
class ModelSchema(with_metaclass(ModelSchemaMeta, ma.Schema)):
    """Base class for SQLAlchemy model-based Schemas.

    Example: ::

        from marshmallow_sqlalchemy import ModelSchema
        from mymodels import User, session

        class UserSchema(ModelSchema):
            class Meta:
                model = User

        schema = UserSchema()

        user = schema.load({'name': 'Bill'}, session=session)
        existing_user = schema.load({'name': 'Bill'}, instance=User.query.first())

    :param session: Optional SQLAlchemy session; may be overridden in `load`.
    :param instance: Optional existing instance to modify; may be overridden in `load`.
    """

    OPTIONS_CLASS = ModelSchemaOpts

    def __init__(self, *args, **kwargs):
        self._session = kwargs.pop('session', None)
        self.instance = kwargs.pop('instance', None)
        super(ModelSchema, self).__init__(*args, **kwargs)

    @property
    def session(self):
        """Session used for deserialization.

        The session set on this schema (constructor, `load`, `validate` or
        this setter) if any, else ``Meta.sqla_session``.
        """
        return self._session or self.opts.sqla_session

    @session.setter
    def session(self, session):
        self._session = session

    def _use_session(self, session, error_message):
        """Adopt ``session`` when given, then require that a session exists.

        :raises ValueError: With ``error_message`` if no session is available.
        """
        # A session passed here persists on the schema for later calls.
        self._session = session or self._session
        if not self.session:
            raise ValueError(error_message)

    def get_instance(self, data):
        """Retrieve an existing record by primary key(s).

        :param data: Deserialized data mapping.
        :return: The matching model instance, or ``None`` when a primary key
            value is missing from ``data`` or no row matches.
        """
        props = get_primary_keys(self.opts.model)
        filters = {prop.key: data.get(prop.key) for prop in props}
        # Never query with an empty filter: it would match an arbitrary row.
        if not filters or None in filters.values():
            return None
        return self.session.query(self.opts.model).filter_by(**filters).first()

    @ma.post_load
    def make_instance(self, data):
        """Deserialize data to an instance of the model. Update an existing row
        if specified in `self.instance` or loaded by primary key(s) in the data;
        else create a new row.

        :param data: Data to deserialize.
        """
        # Compare against None explicitly: a model instance may be falsy.
        instance = self.instance
        if instance is None:
            instance = self.get_instance(data)
        if instance is None:
            return self.opts.model(**data)
        for key, value in iteritems(data):
            setattr(instance, key, value)
        return instance

    def load(self, data, session=None, instance=None, *args, **kwargs):
        """Deserialize data to internal representation.

        :param data: Data to deserialize.
        :param session: Optional SQLAlchemy session. When given it replaces
            the schema's session for this and subsequent calls.
        :param instance: Optional existing instance to modify during this
            call only; overrides the instance passed to the constructor.
        :raises ValueError: If no session is available.
        """
        self._use_session(session, 'Deserialization requires a session')
        previous_instance = self.instance
        if instance is not None:
            self.instance = instance
        try:
            return super(ModelSchema, self).load(data, *args, **kwargs)
        finally:
            # Restore instead of clearing, so an instance given to the
            # constructor keeps being used by later loads, while one passed
            # to this call does not leak into them.
            self.instance = previous_instance

    def validate(self, data, session=None, *args, **kwargs):
        """Validate ``data`` using marshmallow's `Schema.validate`.

        :param data: Data to validate.
        :param session: Optional SQLAlchemy session. When given it replaces
            the schema's session for this and subsequent calls.
        :raises ValueError: If no session is available.
        """
        self._use_session(session, 'Validation requires a session')
        return super(ModelSchema, self).validate(data, *args, **kwargs)