352 lines
13 KiB
Python
352 lines
13 KiB
Python
import inspect
|
|
import warnings
|
|
from collections import defaultdict, namedtuple
|
|
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
|
|
|
import pydantic
|
|
from django.http import HttpResponse
|
|
from pydantic.fields import FieldInfo
|
|
from pydantic_core import PydanticUndefined
|
|
from typing_extensions import Annotated, get_args, get_origin
|
|
|
|
from ninja import UploadedFile
|
|
from ninja.compatibility.util import UNION_TYPES
|
|
from ninja.errors import ConfigError
|
|
from ninja.params.models import (
|
|
Body,
|
|
File,
|
|
Form,
|
|
Param,
|
|
Path,
|
|
Query,
|
|
TModel,
|
|
TModels,
|
|
_MultiPartBody,
|
|
)
|
|
from ninja.signature.utils import get_path_param_names, get_typed_signature
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "ViewSignature",
    "is_pydantic_model",
    "is_collection_type",
    "detect_collection_fields",
]
|
|
|
|
# One resolved argument of a view function:
#   name          - the argument name as written in the view signature
#   alias         - the alias declared on the source, falling back to `name`
#   source        - the object describing where the value is read from
#   annotation    - the (possibly inferred) type annotation of the argument
#   is_collection - whether the annotation was detected as a collection type
FuncParam = namedtuple(
    "FuncParam",
    "name alias source annotation is_collection",
)
|
|
|
|
|
|
class ViewSignature:
    """Map the parameters of a view function onto request input sources.

    On construction the typed signature of the view is walked, every usable
    parameter is turned into a ``FuncParam`` and the parameters are grouped by
    the class of their source into dynamically created model classes, which
    are stored in ``self.models``.
    """

    FLATTEN_PATH_SEP = (
        "\x1e"  # ASCII Record Separator. IE: not generally used in query names
    )
    # Name of the view argument annotated with ``HttpResponse`` (if any).
    response_arg: Optional[str] = None

    def __init__(self, path: str, view_func: Callable[..., Any]) -> None:
        """Analyse ``view_func`` as a handler registered under ``path``.

        Args:
            path: URL path pattern; the names of its path parameters are
                extracted with ``get_path_param_names``.
            view_func: the view callable to inspect.

        Raises:
            ConfigError: if an un-annotated parameter has a pydantic model
                class as its default (``name=Model`` instead of
                ``name: Model``).
        """
        self.view_func = view_func
        self.signature = get_typed_signature(self.view_func)
        self.path = path
        self.path_params_names = get_path_param_names(path)
        self.docstring = inspect.cleandoc(view_func.__doc__ or "")
        self.has_kwargs = False

        self.params: List[FuncParam] = []
        for name, arg in self.signature.parameters.items():
            if name == "request":
                # TODO: maybe better assert that 1st param is request or check by type?
                # maybe even have attribute like `has_request`
                # so that users can ignore passing request if not needed
                continue

            if arg.kind == arg.VAR_KEYWORD:
                # Skipping **kwargs
                self.has_kwargs = True
                continue

            if arg.kind == arg.VAR_POSITIONAL:
                # Skipping *args
                continue

            if arg.annotation is HttpResponse:
                # Remember which argument wants the response object; it is not
                # an input parameter.
                self.response_arg = name
                continue

            if (
                arg.annotation is inspect.Parameter.empty
                and isinstance(arg.default, type)
                and issubclass(arg.default, pydantic.BaseModel)
            ):
                raise ConfigError(
                    f"Looks like you are using `{name}={arg.default.__name__}` instead of `{name}: {arg.default.__name__}` (annotation)"
                )

            func_param = self._get_param_type(name, arg)
            self.params.append(func_param)

        if hasattr(view_func, "_ninja_contribute_args"):
            # _ninja_contribute_args is a special attribute
            # which allows developers to create custom function params
            # inside decorators or other functions
            for p_name, p_type, p_source in view_func._ninja_contribute_args:
                self.params.append(
                    FuncParam(p_name, p_source.alias or p_name, p_source, p_type, False)
                )

        self.models: TModels = self._create_models()

        self._validate_view_path_params()

    def _validate_view_path_params(self) -> None:
        """verify all path params are present in the path model fields

        Emits a ``UserWarning`` attributed to the file and line of the view
        function when a name from the URL path is missing from the flatten map
        of the "path" model (or when there is no such model at all).
        """
        if self.path_params_names:
            path_model = next(
                (m for m in self.models if m.__ninja_param_source__ == "path"), None
            )
            missing = tuple(
                sorted(
                    name
                    for name in self.path_params_names
                    if not (path_model and name in path_model.__ninja_flatten_map__)
                )
            )
            if missing:
                warnings.warn_explicit(
                    UserWarning(
                        f"Field(s) {missing} are in the view path, but were not found in the view signature."
                    ),
                    category=None,
                    filename=inspect.getfile(self.view_func),
                    lineno=inspect.getsourcelines(self.view_func)[1],
                    source=None,
                )

    def _create_models(self) -> TModels:
        """Build one model class per parameter source class.

        Parameters are grouped by ``type(param.source)``. For every group a
        new class named ``<SourceClass>Params`` is created on top of the
        source's ``_model`` base, carrying the field definitions, annotations
        and the ``__ninja_*`` bookkeeping attributes set below.
        """
        params_by_source_cls: Dict[Any, List[FuncParam]] = defaultdict(list)
        for param in self.params:
            param_source_cls = type(param.source)
            params_by_source_cls[param_source_cls].append(param)

        # A Body parameter combined with File or Form parameters is re-keyed
        # under _MultiPartBody.
        is_multipart_response_with_body = Body in params_by_source_cls and (
            File in params_by_source_cls or Form in params_by_source_cls
        )
        if is_multipart_response_with_body:
            params_by_source_cls[_MultiPartBody] = params_by_source_cls.pop(Body)

        result = []
        for param_cls, args in params_by_source_cls.items():
            cls_name: str = param_cls.__name__ + "Params"
            attrs: Dict[str, Any] = {i.name: i.source for i in args}
            attrs["__ninja_param_source__"] = param_cls._param_source()
            attrs["__ninja_flatten_map_reverse__"] = {}

            if attrs["__ninja_param_source__"] == "file":
                pass

            elif attrs["__ninja_param_source__"] in {
                "form",
                "query",
                "header",
                "cookie",
                "path",
            }:
                flatten_map = self._args_flatten_map(args)
                attrs["__ninja_flatten_map__"] = flatten_map
                attrs["__ninja_flatten_map_reverse__"] = {
                    v: (k,) for k, v in flatten_map.items()
                }

            else:
                assert attrs["__ninja_param_source__"] == "body"
                if is_multipart_response_with_body:
                    attrs["__ninja_body_params__"] = {
                        i.alias: i.annotation for i in args
                    }
                else:
                    # ::TODO:: this is still sus. build some test cases
                    attrs["__read_from_single_attr__"] = (
                        args[0].name if len(args) == 1 else None
                    )

            # adding annotations
            attrs["__annotations__"] = {i.name: i.annotation for i in args}

            # collection fields:
            attrs["__ninja_collection_fields__"] = detect_collection_fields(
                args, attrs.get("__ninja_flatten_map__", {})
            )

            base_cls = param_cls._model
            model_cls = type(cls_name, (base_cls,), attrs)
            # TODO: https://pydantic-docs.helpmanual.io/usage/models/#dynamic-model-creation - check if anything special in create_model method that I did not use
            result.append(model_cls)
        return result

    def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]:
        """Map every flat input name to the attribute path it populates.

        Plain parameters map ``alias -> (alias,)``; pydantic-model parameters
        contribute one entry per (nested) field, keyed by the field name.

        Raises:
            ConfigError: if two parameters or fields resolve to the same name.
        """
        flatten_map: Dict[str, Tuple[str, ...]] = {}
        # flat name -> name of the view parameter that first claimed it
        arg_names: Dict[str, str] = {}
        for arg in args:
            if is_pydantic_model(arg.annotation):
                for name, path in self._model_flatten_map(arg.annotation, arg.alias):
                    if name in flatten_map:
                        raise ConfigError(
                            f"Duplicated name: '{name}' in params: '{arg_names[name]}' & '{arg.name}'"
                        )
                    flatten_map[name] = tuple(path.split(self.FLATTEN_PATH_SEP))
                    arg_names[name] = arg.name
            else:
                name = arg.alias
                if name in flatten_map:
                    raise ConfigError(
                        f"Duplicated name: '{name}' also in '{arg_names[name]}'"
                    )
                flatten_map[name] = (name,)
                # Record the parameter name (not the alias) so that duplicate
                # errors point at the view argument, as in the model branch.
                arg_names[name] = arg.name

        return flatten_map

    def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
        """Yield ``(field_name, joined_path)`` for every leaf field of ``model``.

        ``joined_path`` is ``prefix`` followed by the field names leading to
        the leaf, joined with ``FLATTEN_PATH_SEP``. Fields whose annotation is
        itself a pydantic model are descended into recursively.
        """
        field: FieldInfo
        for attr, field in model.model_fields.items():
            field_name = field.alias or attr
            name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
            if is_pydantic_model(field.annotation):
                yield from self._model_flatten_map(field.annotation, name)  # type: ignore
            else:
                yield field_name, name

    def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
        """Resolve the source, alias and annotation of one view argument.

        Resolution order: ``UploadedFile`` (or a collection of it) becomes a
        ``File``; an explicit ``Param`` default is used as is; a name that
        appears in the URL path becomes a ``Path``; collections and pydantic
        models become ``Body``; everything else becomes ``Query``.

        Raises:
            AssertionError: if a path parameter declares a default value.
        """
        # Imported locally so the method does not depend on a module-level
        # import being present.
        import copy

        empty = self.signature.empty
        annotation = arg.annotation
        default = arg.default

        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            if isinstance(args[1], Param):
                prev_default = default
                annotation, default = args
                if prev_default is not empty:
                    # The Param instance is stored inside the Annotated alias
                    # and is shared by every place that alias is used, so it
                    # must be copied before the default is written onto it.
                    default = copy.copy(default)
                    default.default = prev_default

        if annotation is empty:
            # No annotation: infer the type from the default value.
            if default is empty:
                annotation = str
            elif isinstance(default, Param):
                annotation = type(default.default)
            else:
                annotation = type(default)

        if annotation == PydanticUndefined.__class__:
            # TODO: ^ check why is that so
            annotation = str

        if annotation == type(None) or annotation == type(Ellipsis):  # noqa
            annotation = str

        is_collection = is_collection_type(annotation)

        # get_args() returns () for un-parameterised annotations such as a
        # bare `list`, where reading `__args__` would raise AttributeError.
        item_types = get_args(annotation)
        holds_uploaded_files = (
            is_collection and bool(item_types) and item_types[0] == UploadedFile
        )
        if annotation == UploadedFile or holds_uploaded_files:
            # People often forgot to mark UploadedFile as a File, so we better assign it automatically
            if default is empty or default is None:
                file_default = ... if default is empty else default
                return FuncParam(
                    name, name, File(file_default), annotation, is_collection
                )

        # 1) if type of the param is defined as one of the Param's subclasses - we just use that definition
        if isinstance(default, Param):
            param_source = default

        # 2) if param name is a part of the path parameter
        elif name in self.path_params_names:
            assert (
                default is empty
            ), f"'{name}' is a path param, default not allowed"
            param_source = Path(...)

        # 3) if param is a collection, or annotation is part of pydantic model:
        elif is_collection or is_pydantic_model(annotation):
            if default is empty:
                param_source = Body(...)
            else:
                param_source = Body(default)

        # 4) the last case is query param
        else:
            if default is empty:
                param_source = Query(...)
            else:
                param_source = Query(default)

        return FuncParam(
            name, param_source.alias or name, param_source, annotation, is_collection
        )
|
|
|
|
|
|
def is_pydantic_model(cls: Any) -> bool:
    """Tell whether ``cls`` is a pydantic model class or a union containing one.

    Union members are checked one by one, so a member that is not a class
    (for example a parameterised generic, for which ``issubclass`` raises
    ``TypeError``) does not hide a model that appears elsewhere in the union,
    and the result does not depend on the order of the members.

    Returns:
        ``True`` if ``cls`` is a subclass of ``pydantic.BaseModel`` or a union
        with at least one such member; ``False`` for anything else, including
        values that are not classes at all.
    """
    if get_origin(cls) in UNION_TYPES:
        return any(is_pydantic_model(member) for member in get_args(cls))
    try:
        return issubclass(cls, pydantic.BaseModel)
    except TypeError:
        # `cls` is not a class (generic alias, instance, string, ...).
        return False
|
|
|
|
|
|
def is_collection_type(annotation: Any) -> bool:
    """Tell whether ``annotation`` denotes a list, set or tuple.

    Unions count as collections when at least one member does. Parameterised
    generics are judged by their origin; bare classes by subclassing; any
    other object by an ``isinstance`` check against the collection types.
    """
    container_types = (List, list, set, tuple)
    origin = get_origin(annotation)

    if origin in UNION_TYPES:
        return any(is_collection_type(member) for member in get_args(annotation))

    if origin is not None:
        return origin in container_types  # TODO: I guess we should handle only list

    if isinstance(annotation, type):
        return issubclass(annotation, container_types)
    return isinstance(annotation, container_types)
|
|
|
|
|
|
def detect_collection_fields(
    args: List[FuncParam], flatten_map: Dict[str, Tuple[str, ...]]
) -> List[str]:
    """
    Return the names of the inputs that ninja must treat as lists.

    Django QueryDict always stores values as lists, so ninja needs to be told
    which parameters really are lists and which are single values. The result
    holds the alias (or name) of every argument flagged as a collection, plus
    the last path element of every nested flatten-map entry whose target field
    has a collection annotation.
    """
    collection_names = [p.alias or p.name for p in args if p.is_collection]

    if not flatten_map:
        return collection_names

    params_by_alias = {p.alias: p for p in args}
    for path in flatten_map.values():
        if len(path) <= 1:
            # Top-level parameters were already handled above.
            continue

        # Walk from the parameter's annotation down to the leaf field.
        node: Any = params_by_alias[path[0]].annotation
        for step in path[1:]:
            if hasattr(node, "annotation"):
                node = node.annotation
            fields = node.model_fields
            # Prefer a field whose alias matches; otherwise look up by name.
            node = next(
                (candidate for candidate in fields.values() if candidate.alias == step),
                fields.get(step),
            )  # pragma: no cover

            node = getattr(node, "outer_type_", node)

        # if hasattr(annotation_or_field, "annotation"):
        node = node.annotation

        if is_collection_type(node):
            collection_names.append(path[-1])

    return collection_names
|