250 lines
8.3 KiB
Python

"""
Since "Model" word would be very confusing when used in django context, this
module basically makes an alias for it named "Schema" and adds extra whistles to
be able to work with django querysets and managers.
The schema is a bit smarter than a standard pydantic Model because it can handle
dotted attributes and resolver methods. For example::
class UserSchema(User):
name: str
initials: str
boss: str = Field(None, alias="boss.first_name")
@staticmethod
def resolve_name(obj):
return f"{obj.first_name} {obj.last_name}"
"""
import warnings
from typing import (
Any,
Callable,
Dict,
Type,
TypeVar,
Union,
no_type_check,
)
import pydantic
from django.db.models import Manager, QuerySet
from django.db.models.fields.files import FieldFile
from django.template import Variable, VariableDoesNotExist
from pydantic import BaseModel, Field, ValidationInfo, model_validator, validator
from pydantic._internal._model_construction import ModelMetaclass
from pydantic.functional_validators import ModelWrapValidatorHandler
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from typing_extensions import dataclass_transform
from ninja.signature.utils import get_args_names, has_kwargs
from ninja.types import DictStrAny
# [major, minor] of the installed pydantic, e.g. "2.5.3" -> [2, 5].
pydantic_version = list(map(int, pydantic.VERSION.split(".")[:2]))
# Enforce the requirement with an explicit raise instead of ``assert``:
# assertions are stripped under ``python -O``, which would silently let an
# incompatible pydantic through.
if pydantic_version < [2, 0]:
    raise ImportError(f"Pydantic 2.0+ required, found {pydantic.VERSION}")

__all__ = ["BaseModel", "Field", "validator", "DjangoGetter", "Schema"]

# Type variable standing for "this Schema subclass" in classmethod signatures.
S = TypeVar("S", bound="Schema")
class DjangoGetter:
    """Attribute proxy that pydantic validates in place of the raw input object.

    A missing-attribute lookup is answered, in order, by:

    1. a resolver registered for that name in ``schema_cls._ninja_resolvers``;
    2. key lookup, when the wrapped object is a ``dict`` (a missing key raises
       ``AttributeError``);
    3. plain ``getattr`` on the wrapped object;
    4. Django template ``Variable`` resolution (used for dotted names such as
       ``"boss.first_name"``).

    Whatever is found is normalized by ``_convert_result`` before returning.
    """

    __slots__ = ("_obj", "_schema_cls", "_context", "__dict__")

    def __init__(self, obj: Any, schema_cls: Type[S], context: Any = None):
        self._obj = obj
        self._schema_cls = schema_cls
        self._context = context

    def __getattr__(self, key: str) -> Any:
        resolver = self._schema_cls._ninja_resolvers.get(key)
        if resolver:
            return self._convert_result(resolver(getter=self))

        source = self._obj
        if isinstance(source, dict):
            if key not in source:
                raise AttributeError(key)
            return self._convert_result(source[key])

        try:
            found = getattr(source, key)
        except AttributeError:
            # Fall back to Django's template-variable lookup.
            # TODO: constructing Variable(key) is slower than resolving it,
            # so these instances are worth caching.
            try:
                found = Variable(key).resolve(source)
            except VariableDoesNotExist as exc:
                raise AttributeError(key) from exc
        return self._convert_result(found)

    def _convert_result(self, result: Any) -> Any:
        """Normalize a looked-up value into something pydantic can validate.

        Managers and querysets are materialized into lists, any other callable
        is invoked with no arguments, and a ``FieldFile`` becomes its ``url``
        (or ``None`` when it is empty/falsy). Everything else passes through.
        """
        if isinstance(result, Manager):
            return list(result.all())
        # If ``QuerySet`` exposes ``__origin__`` (a generic alias), check
        # against the underlying class instead.
        queryset_cls = getattr(QuerySet, "__origin__", QuerySet)
        if isinstance(result, queryset_cls):
            return list(result)
        if callable(result):
            return result()
        if isinstance(result, FieldFile):
            return result.url if result else None
        return result

    def __repr__(self) -> str:
        return "<DjangoGetter: " + repr(self._obj) + ">"
class Resolver:
    """Callable wrapper around a ``resolve_<field>`` function of a Schema.

    Only functions declared with ``@staticmethod`` can currently be invoked.
    They receive the wrapped source object (``getter._obj``). They also get
    ``context=getter._context`` when ``has_kwargs`` reports keyword arguments
    for them, or when one of their argument names is ``context``.
    """

    __slots__ = ("_func", "_static", "_takes_context")
    _static: bool
    _func: Any
    _takes_context: bool

    def __init__(self, func: Union[Callable, staticmethod]):
        is_static = isinstance(func, staticmethod)
        self._static = is_static
        # Unwrap the staticmethod so the plain function can be inspected/called.
        self._func = func.__func__ if is_static else func
        param_names = get_args_names(self._func)
        self._takes_context = has_kwargs(self._func) or "context" in param_names

    def __call__(self, getter: DjangoGetter) -> Any:
        call_kwargs = {"context": getter._context} if self._takes_context else {}
        if not self._static:
            raise NotImplementedError(
                "Non static resolves are not supported yet"
            )  # pragma: no cover
        return self._func(getter._obj, **call_kwargs)
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ResolverMetaclass(ModelMetaclass):
    """Pydantic model metaclass that collects ``resolve_<field>`` functions.

    Resolvers inherited from the direct bases are merged first; when two bases
    define the same one, the earlier base in the class statement wins. Then
    every ``resolve_*`` entry of the class body is wrapped in ``Resolver``, if
    it is callable or a staticmethod. It is stored under the name that follows
    ``resolve_`` in the new class's ``_ninja_resolvers`` mapping.
    """

    _ninja_resolvers: Dict[str, Resolver]

    @no_type_check
    def __new__(cls, name, bases, namespace, **kwargs):
        prefix = "resolve_"
        collected = {}
        # Walk bases right-to-left so earlier bases override later ones.
        for parent in bases[::-1]:
            inherited = getattr(parent, "_ninja_resolvers", None)
            if inherited:
                collected.update(inherited)
        for attr_name, candidate in namespace.items():
            if not attr_name.startswith(prefix):
                continue
            # A staticmethod object isn't directly callable on Python <= 3.9,
            # so it has to be accepted explicitly.
            if callable(candidate) or isinstance(candidate, staticmethod):
                collected[attr_name[len(prefix):]] = Resolver(candidate)
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        new_cls._ninja_resolvers = collected
        return new_cls
class NinjaGenerateJsonSchema(GenerateJsonSchema):
    """JSON schema generator with Ninja's own handling of field defaults."""

    def default_schema(self, schema: Any) -> JsonSchemaValue:
        """Render a field-with-default schema, emitting only non-None defaults.

        Pydantic's stock implementation renders ``null`` defaults and
        ``default_factory`` values. That breaks Swagger and Django model
        callable defaults, so the behavior is fully overridden here.
        """
        inner = self.generate_inner(schema["schema"])
        has_default = "default" in schema and schema["default"] is not None
        encoded = self.encode_default(schema["default"]) if has_default else None
        # A "$ref" schema cannot carry sibling keys such as "default", so it is
        # wrapped in a single-item allOf before the default is attached.
        result = {"allOf": [inner]} if "$ref" in inner else inner
        if encoded is not None:
            result["default"] = encoded
        return result
class Schema(BaseModel, metaclass=ResolverMetaclass):
    """Base class for Ninja schemas: a pydantic model fed through DjangoGetter.

    Input objects are wrapped in ``DjangoGetter`` before validation, so
    resolvers, dotted aliases, managers/querysets and file fields are handled.
    """

    # Kept as a class-based Config: user code may subclass ``Schema.Config``.
    class Config:
        from_attributes = True  # aka orm_mode

    @model_validator(mode="wrap")
    @classmethod
    def _run_root_validator(
        cls, values: Any, handler: ModelWrapValidatorHandler[S], info: ValidationInfo
    ) -> Any:
        # With extra="forbid" or validate_assignment, pydantic may validate
        # against the __dict__ of the object it is given. Checks or changes
        # made on a DjangoGetter's __dict__ would not persist to the original
        # object, so run the handler on the raw input first in those modes.
        config = cls.model_config
        if config.get("extra") == "forbid" or config.get("validate_assignment", False):
            handler(values)
        return handler(DjangoGetter(values, cls, info.context))

    @classmethod
    def from_orm(cls: Type[S], obj: Any, **kw: Any) -> S:
        """Pydantic 1.x style alias for ``model_validate``."""
        return cls.model_validate(obj, **kw)

    def dict(self, *a: Any, **kw: Any) -> DictStrAny:
        """Pydantic 1.x style alias for ``model_dump``."""
        return self.model_dump(*a, **kw)

    @classmethod
    def json_schema(cls) -> DictStrAny:
        """Return the model's JSON schema built by ``NinjaGenerateJsonSchema``."""
        return cls.model_json_schema(schema_generator=NinjaGenerateJsonSchema)

    @classmethod
    def schema(cls) -> DictStrAny:  # type: ignore
        """Deprecated alias of ``json_schema``; emits a DeprecationWarning."""
        warnings.warn(
            ".schema() is deprecated, use .json_schema() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return cls.json_schema()