352 lines
13 KiB
Python

import copy
import inspect
import warnings
from collections import defaultdict, namedtuple
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import pydantic
from django.http import HttpResponse
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import Annotated, get_args, get_origin

from ninja import UploadedFile
from ninja.compatibility.util import UNION_TYPES
from ninja.errors import ConfigError
from ninja.params.models import (
    Body,
    File,
    Form,
    Param,
    Path,
    Query,
    TModel,
    TModels,
    _MultiPartBody,
)
from ninja.signature.utils import get_path_param_names, get_typed_signature
# Names exported by a star-import of this module.
__all__ = [
    "ViewSignature",
    "is_pydantic_model",
    "is_collection_type",
    "detect_collection_fields",
]
# One resolved view argument:
#   name          - the argument name as written in the view signature
#   alias         - the alias of its source, falling back to ``name``
#   source        - the parameter-source object the value is read through
#   annotation    - the (possibly inferred) type annotation
#   is_collection - whether the annotation was detected as a collection type
FuncParam = namedtuple(
    "FuncParam",
    ("name", "alias", "source", "annotation", "is_collection"),
)
class ViewSignature:
    """Resolve where each argument of a view function is read from.

    On construction the view's signature is inspected, every argument is
    turned into a ``FuncParam`` bound to a parameter source (path, query,
    body, file, ...), and one model class per source is generated and stored
    in ``models``.
    """

    FLATTEN_PATH_SEP = (
        "\x1e"  # ASCII Record Separator.  IE: not generally used in query names
    )
    response_arg: Optional[str] = None

    def __init__(self, path: str, view_func: Callable[..., Any]) -> None:
        """Inspect ``view_func`` and build the parameter models for ``path``.

        Args:
            path: URL path of the view; its parameter names are extracted
                with ``get_path_param_names``.
            view_func: the view callable to inspect.

        Raises:
            ConfigError: if a pydantic model class is passed as a default
                value of an un-annotated argument (``arg=Model`` instead of
                ``arg: Model``), or if flattened parameter names collide.
        """
        self.view_func = view_func
        self.signature = get_typed_signature(self.view_func)
        self.path = path
        self.path_params_names = get_path_param_names(path)
        self.docstring = inspect.cleandoc(view_func.__doc__ or "")
        self.has_kwargs = False

        self.params: List[FuncParam] = []
        for name, arg in self.signature.parameters.items():
            if name == "request":
                # TODO: maybe better assert that 1st param is request or check by type?
                # maybe even have attribute like `has_request`
                # so that users can ignore passing request if not needed
                continue

            if arg.kind == arg.VAR_KEYWORD:
                # Skipping **kwargs
                self.has_kwargs = True
                continue

            if arg.kind == arg.VAR_POSITIONAL:
                # Skipping *args
                continue

            if arg.annotation is HttpResponse:
                # The argument annotated with HttpResponse is remembered by
                # name and is not treated as an input parameter.
                self.response_arg = name
                continue

            if (
                arg.annotation is inspect.Parameter.empty
                and isinstance(arg.default, type)
                and issubclass(arg.default, pydantic.BaseModel)
            ):
                raise ConfigError(
                    f"Looks like you are using `{name}={arg.default.__name__}` instead of `{name}: {arg.default.__name__}` (annotation)"
                )

            self.params.append(self._get_param_type(name, arg))

        if hasattr(view_func, "_ninja_contribute_args"):
            # _ninja_contribute_args is a special attribute
            # which allows developers to create custom function params
            # inside decorators or other functions
            for p_name, p_type, p_source in view_func._ninja_contribute_args:
                self.params.append(
                    FuncParam(p_name, p_source.alias or p_name, p_source, p_type, False)
                )

        self.models: TModels = self._create_models()

        self._validate_view_path_params()

    def _validate_view_path_params(self) -> None:
        """Warn when a path parameter has no matching field in the path model.

        The warning is attributed to the view function's own file and line
        when that location can be determined; otherwise a regular warning is
        emitted so that signature inspection never fails just because the
        view's source is unavailable.
        """
        if not self.path_params_names:
            return

        path_model = next(
            (m for m in self.models if m.__ninja_param_source__ == "path"), None
        )
        missing = tuple(
            sorted(
                name
                for name in self.path_params_names
                if not (path_model and name in path_model.__ninja_flatten_map__)
            )
        )
        if not missing:
            return

        message = UserWarning(
            f"Field(s) {missing} are in the view path, but were not found in the view signature."
        )
        try:
            filename = inspect.getfile(self.view_func)
            lineno = inspect.getsourcelines(self.view_func)[1]
        except (OSError, TypeError):
            # inspect raises OSError when the source cannot be retrieved and
            # TypeError for objects without a source file (e.g. builtins).
            warnings.warn(message, stacklevel=2)
            return

        warnings.warn_explicit(
            message,
            category=None,
            filename=filename,
            lineno=lineno,
            source=None,
        )

    def _create_models(self) -> "TModels":
        """Create one model class per parameter source used by the view.

        Parameters are grouped by the class of their source object. Each
        group becomes a dynamically created subclass of that source's
        ``_model`` carrying the fields plus the ``__ninja_*`` bookkeeping
        attributes set below.
        """
        params_by_source_cls: Dict[Any, List[FuncParam]] = defaultdict(list)
        for param in self.params:
            params_by_source_cls[type(param.source)].append(param)

        # Body params that appear together with file or form params are
        # re-registered under _MultiPartBody instead of Body.
        is_multipart_with_body = Body in params_by_source_cls and (
            File in params_by_source_cls or Form in params_by_source_cls
        )
        if is_multipart_with_body:
            params_by_source_cls[_MultiPartBody] = params_by_source_cls.pop(Body)

        result = []
        for param_cls, args in params_by_source_cls.items():
            cls_name: str = param_cls.__name__ + "Params"
            attrs: Dict[str, Any] = {i.name: i.source for i in args}
            source_name = param_cls._param_source()
            attrs["__ninja_param_source__"] = source_name
            attrs["__ninja_flatten_map_reverse__"] = {}

            if source_name in {"form", "query", "header", "cookie", "path"}:
                flatten_map = self._args_flatten_map(args)
                attrs["__ninja_flatten_map__"] = flatten_map
                attrs["__ninja_flatten_map_reverse__"] = {
                    v: (k,) for k, v in flatten_map.items()
                }
            elif source_name != "file":
                # "file" needs nothing extra; the only remaining source is "body".
                assert source_name == "body"
                if is_multipart_with_body:
                    attrs["__ninja_body_params__"] = {
                        i.alias: i.annotation for i in args
                    }
                else:
                    # ::TODO:: this is still sus.  build some test cases
                    attrs["__read_from_single_attr__"] = (
                        args[0].name if len(args) == 1 else None
                    )

            # adding annotations
            attrs["__annotations__"] = {i.name: i.annotation for i in args}

            # collection fields:
            attrs["__ninja_collection_fields__"] = detect_collection_fields(
                args, attrs.get("__ninja_flatten_map__", {})
            )

            base_cls = param_cls._model
            model_cls = type(cls_name, (base_cls,), attrs)
            # TODO: https://pydantic-docs.helpmanual.io/usage/models/#dynamic-model-creation - check if anything special in create_model method that I did not use
            result.append(model_cls)
        return result

    def _args_flatten_map(
        self, args: "List[FuncParam]"
    ) -> Dict[str, Tuple[str, ...]]:
        """Map every flat input name to the attribute path it populates.

        Plain parameters map to a one-element path containing their alias.
        Parameters annotated with a pydantic model contribute one entry per
        leaf field, with a path starting at the parameter's alias.

        Raises:
            ConfigError: if two entries end up with the same flat name.
        """
        flatten_map: Dict[str, Tuple[str, ...]] = {}
        # flat name -> name of the view argument that introduced it; only
        # used to build the error messages below.
        owners: Dict[str, str] = {}
        for arg in args:
            if is_pydantic_model(arg.annotation):
                for name, path in self._model_flatten_map(arg.annotation, arg.alias):
                    if name in flatten_map:
                        raise ConfigError(
                            f"Duplicated name: '{name}' in params: '{owners[name]}' & '{arg.name}'"
                        )
                    flatten_map[name] = tuple(path.split(self.FLATTEN_PATH_SEP))
                    owners[name] = arg.name
            else:
                name = arg.alias
                if name in flatten_map:
                    raise ConfigError(
                        f"Duplicated name: '{name}' also in '{owners[name]}'"
                    )
                flatten_map[name] = (name,)
                # Record the argument name (not the alias) so later collisions
                # report the offending parameter, as the model branch does.
                owners[name] = arg.name
        return flatten_map

    def _model_flatten_map(self, model: "TModel", prefix: str) -> Generator:
        """Yield ``(field_name, path)`` for every leaf field of ``model``.

        ``path`` is ``prefix`` followed by the field names leading to the
        leaf, joined with ``FLATTEN_PATH_SEP``. Fields whose annotation is a
        pydantic model are descended into recursively.
        """
        field: FieldInfo
        for attr, field in model.model_fields.items():
            field_name = field.alias or attr
            name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
            if is_pydantic_model(field.annotation):
                yield from self._model_flatten_map(field.annotation, name)  # type: ignore
            else:
                yield field_name, name

    def _get_param_type(self, name: str, arg: inspect.Parameter) -> "FuncParam":
        """Decide which parameter source a single view argument is read from.

        Resolution order: uploaded files become ``File``; an explicit
        ``Param`` default (directly or via ``Annotated``) is used as is;
        names present in the URL path become ``Path``; collections and
        pydantic models become ``Body``; everything else becomes ``Query``.

        Raises:
            AssertionError: if a path parameter declares a default value.
        """
        empty = self.signature.empty
        annotation = arg.annotation
        default = arg.default

        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            if isinstance(args[1], Param):
                prev_default = default
                # NOTE(review): this unpacking accepts exactly one metadata
                # item after the type; more than one raises ValueError.
                annotation, default = args
                if prev_default is not empty:
                    # Copy before mutating: the Param lives inside the
                    # Annotated alias, which may be shared by several views.
                    default = copy.copy(default)
                    default.default = prev_default

        if annotation is empty:
            # No annotation: infer the type from the default value.
            if default is empty:
                annotation = str
            elif isinstance(default, Param):
                annotation = type(default.default)
            else:
                annotation = type(default)

        if annotation is PydanticUndefined.__class__:
            # TODO: ^ check why is that so
            annotation = str

        if annotation is type(None) or annotation is type(Ellipsis):
            annotation = str

        is_collection = is_collection_type(annotation)

        # get_args() returns () for un-parameterised collections such as a
        # bare ``list``, which have no usable ``__args__``.
        item_types = get_args(annotation) if is_collection else ()
        if annotation == UploadedFile or (
            bool(item_types) and item_types[0] == UploadedFile
        ):
            # People often forgot to mark UploadedFile as a File, so we better assign it automatically
            if default is empty or default is None:
                default = ... if default is empty else default
                return FuncParam(name, name, File(default), annotation, is_collection)

        # 1) if type of the param is defined as one of the Param's subclasses - we just use that definition
        if isinstance(default, Param):
            param_source = default

        # 2) if param name is a part of the path parameter
        elif name in self.path_params_names:
            assert (
                default is empty
            ), f"'{name}' is a path param, default not allowed"
            param_source = Path(...)

        # 3) if param is a collection, or annotation is part of pydantic model:
        elif is_collection or is_pydantic_model(annotation):
            param_source = Body(...) if default is empty else Body(default)

        # 4) the last case is query param
        else:
            param_source = Query(...) if default is empty else Query(default)

        return FuncParam(
            name, param_source.alias or name, param_source, annotation, is_collection
        )
def is_pydantic_model(cls: Any) -> bool:
    """Return True if ``cls`` is a pydantic model class or a union containing one.

    Union members are checked one by one, so a member that is not a class
    (for example ``List[int]``) no longer hides a model listed after it.
    Anything that cannot be passed to ``issubclass`` yields False.
    """
    try:
        if get_origin(cls) in UNION_TYPES:
            # Recurse per member: each member's TypeError is handled on its
            # own instead of aborting the scan of the remaining members.
            return any(is_pydantic_model(member) for member in get_args(cls))
        return issubclass(cls, pydantic.BaseModel)
    except TypeError:
        return False
def is_collection_type(annotation: Any) -> bool:
    """Return True if ``annotation`` denotes a list, set or tuple.

    A union counts as a collection when at least one of its members does.
    Parameterised generics are judged by their origin; everything else is
    checked with ``issubclass`` (for classes) or ``isinstance``.
    """
    origin = get_origin(annotation)
    if origin in UNION_TYPES:
        return any(is_collection_type(member) for member in get_args(annotation))

    collection_types = (List, list, set, tuple)
    if origin is not None:
        return origin in collection_types  # TODO: I guess we should handle only list
    if isinstance(annotation, type):
        return issubclass(annotation, collection_types)
    return isinstance(annotation, collection_types)
def detect_collection_fields(
    args: List[FuncParam], flatten_map: Dict[str, Tuple[str, ...]]
) -> List[str]:
    """
    Return the input names that ninja must treat as lists.

    Django QueryDict values are always lists, so ninja needs to be told which
    inputs are real collections and which are single values. The result
    contains the alias (or name) of every argument flagged as a collection,
    plus the leaf name of every nested model field in ``flatten_map`` whose
    annotation is a collection type.
    """
    collection_names = [
        param.alias or param.name for param in args if param.is_collection
    ]
    if not flatten_map:
        return collection_names

    params_by_alias = {param.alias: param for param in args}
    for path in flatten_map.values():
        if len(path) <= 1:
            # Top-level params were already handled above.
            continue

        # Walk from the argument's annotation down to the leaf field.
        node: Any = params_by_alias[path[0]].annotation
        for attr in path[1:]:
            if hasattr(node, "annotation"):
                node = node.annotation
            fields = node.model_fields
            fallback = fields.get(attr)
            # Prefer a field whose alias matches; otherwise look up by name.
            node = next(
                (candidate for candidate in fields.values() if candidate.alias == attr),
                fallback,
            )  # pragma: no cover
            node = getattr(node, "outer_type_", node)

        # if hasattr(annotation_or_field, "annotation"):
        node = node.annotation

        if is_collection_type(node):
            collection_names.append(path[-1])
    return collection_names