297 lines
7.9 KiB
Python
297 lines
7.9 KiB
Python
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Pattern,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from django.conf import settings
from django.http import HttpRequest
from pydantic import BaseModel
from pydantic.fields import FieldInfo

from ninja.errors import HttpError
from ninja.types import DictStrAny

if TYPE_CHECKING:
    from ninja import NinjaAPI  # pragma: no cover
|
|
|
|
# Public names of this module: the abstract base model followed by one
# concrete model per request source.
__all__ = [
    "ParamModel",
    # concrete per-source models
    "QueryModel",
    "PathModel",
    "HeaderModel",
    "CookieModel",
    "BodyModel",
    "FormModel",
    "FileModel",
]
|
|
|
|
# Type variable constrained to ParamModel subclasses. The bound is a string
# forward reference because ParamModel is defined later in this module.
TModel = TypeVar("TModel", bound="ParamModel")

# Alias for a list of TModel instances.
TModels = List[TModel]
|
|
|
|
|
|
def NestedDict() -> DictStrAny:
    """Return an auto-vivifying dict.

    Looking up a missing key creates (and stores) another ``NestedDict``, so
    arbitrarily deep paths can be assigned without pre-creating each level.
    """
    auto_vivifying: DictStrAny = defaultdict(NestedDict)
    return auto_vivifying
|
|
|
|
|
|
class ParamModel(BaseModel, ABC):
    """Abstract pydantic model for one source of request parameters.

    Subclasses implement ``get_request_data`` to pull raw values out of the
    request; ``resolve`` then re-nests those values (when a flatten map is
    defined) and validates them into a model instance.
    """

    __ninja_param_source__ = None

    @classmethod
    @abstractmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        """Return the raw values for this model, or ``None`` if there are none."""
        pass  # pragma: no cover

    @classmethod
    def resolve(
        cls: Type[TModel],
        request: HttpRequest,
        api: "NinjaAPI",
        path_params: DictStrAny,
    ) -> TModel:
        """Build a validated instance of this model from ``request``.

        If ``get_request_data`` yields ``None`` the model is constructed with
        no arguments. Otherwise the data is re-nested by ``_map_data_paths``
        and validated with ``{"request": request}`` as the validation context.
        """
        raw = cls.get_request_data(request, api, path_params)
        if raw is None:
            return cls()
        return cls.model_validate(
            cls._map_data_paths(raw), context={"request": request}
        )

    @classmethod
    def _map_data_paths(cls, data: DictStrAny) -> DictStrAny:
        """Re-nest flat ``data`` according to ``__ninja_flatten_map__``.

        Each flatten-map entry maps a flat key to a path tuple, and the value
        for that key is stored at that path in a fresh ``NestedDict``. Keys
        that are not in the map are dropped. When the class has no flatten
        map (or an empty one), ``data`` is returned unchanged.
        """
        flatten_map = getattr(cls, "__ninja_flatten_map__", None)
        if not flatten_map:
            return data

        nested: DictStrAny = NestedDict()
        for flat_key in flatten_map:
            # Absent keys are still walked (with None) so the intermediate
            # dicts along their path are created.
            value = data[flat_key] if flat_key in data else None
            cls._map_data_path(nested, value, flatten_map[flat_key])
        return nested

    @classmethod
    def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None:
        """Store ``value`` in ``data`` under the key sequence ``path``.

        Intermediate levels are fetched with ``data[key]``. This relies on
        ``data`` being a ``NestedDict`` so that missing levels are created on
        demand. A ``None`` value is not written at the leaf.
        """
        head, rest = path[0], path[1:]
        if rest:
            cls._map_data_path(data[head], value, rest)
        elif value is not None:
            data[head] = value
|
|
|
|
|
|
class QueryModel(ParamModel):
    """Parameter model populated from the query string (``request.GET``)."""

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        # Names in __ninja_collection_fields__ (empty if unset) are handed to
        # the parser as its list-field argument.
        collection_fields = getattr(cls, "__ninja_collection_fields__", [])
        parser = api.parser
        return parser.parse_querydict(request.GET, collection_fields, request)
|
|
|
|
|
|
class PathModel(ParamModel):
    """Parameter model populated from the URL path parameters."""

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        # The caller supplies the path parameters; they are used unchanged.
        extracted = path_params
        return extracted
|
|
|
|
|
|
class HeaderModel(ParamModel):
    """Parameter model populated from the request's HTTP headers."""

    # Flat field name -> nested path. Its keys are the header names looked up.
    __ninja_flatten_map__: DictStrAny

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        headers = request.headers
        # Collect only the headers named in the flatten map; absent ones are skipped.
        return {
            header_name: headers[header_name]
            for header_name in cls.__ninja_flatten_map__
            if header_name in headers
        }
|
|
|
|
|
|
class CookieModel(ParamModel):
    """Parameter model populated from the request cookies."""

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        cookies = request.COOKIES
        return cookies
|
|
|
|
|
|
class BodyModel(ParamModel):
    """Parameter model populated from the parsed request body."""

    # When set, the whole parsed body is wrapped as {<this name>: body}.
    __read_from_single_attr__: str

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        """Parse the request body with ``api.parser``.

        Returns ``None`` for an empty body.

        Raises:
            HttpError: 400 if the parser raises. The exception text is
                appended to the message only when ``settings.DEBUG`` is on.
        """
        if not request.body:
            return None

        try:
            parsed = api.parser.parse_body(request)
        except Exception as exc:
            msg = "Cannot parse request body"
            if settings.DEBUG:
                msg += f" ({exc})"
            raise HttpError(400, msg) from exc

        single_attr = getattr(cls, "__read_from_single_attr__", None)
        return {single_attr: parsed} if single_attr else parsed
|
|
|
|
|
|
class FormModel(ParamModel):
    """Parameter model populated from form data (``request.POST``)."""

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        # Names in __ninja_collection_fields__ (empty if unset) are handed to
        # the parser as its list-field argument.
        collection_fields = getattr(cls, "__ninja_collection_fields__", [])
        parser = api.parser
        return parser.parse_querydict(request.POST, collection_fields, request)
|
|
|
|
|
|
class FileModel(ParamModel):
    """Parameter model populated from uploaded files (``request.FILES``)."""

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        # Names in __ninja_collection_fields__ (empty if unset) are handed to
        # the parser as its list-field argument.
        collection_fields = getattr(cls, "__ninja_collection_fields__", [])
        parser = api.parser
        return parser.parse_querydict(request.FILES, collection_fields, request)
|
|
|
|
|
|
class _HttpRequest(HttpRequest):
    """Request stand-in whose ``body`` is a plain, assignable class attribute.

    It shadows ``HttpRequest.body`` so arbitrary bytes can be assigned to it
    and then fed through the body parser.
    """

    body: bytes = bytes()
|
|
|
|
|
|
class _MultiPartBodyModel(BodyModel):
    """Body model whose parameters arrive as separate multipart form fields.

    Each declared body parameter is looked up in ``request.POST`` and passed,
    one at a time, through ``BodyModel.get_request_data`` as if it were a
    standalone request body.
    """

    # Body parameter name -> its type annotation.
    __ninja_body_params__: DictStrAny

    @classmethod
    def get_request_data(
        cls, request: HttpRequest, api: "NinjaAPI", path_params: DictStrAny
    ) -> Optional[DictStrAny]:
        """Parse every declared body parameter present in ``request.POST``.

        Returns:
            A dict that maps each parameter present in ``request.POST`` to
            the result of the regular body parsing. Absent parameters are
            omitted.

        Raises:
            HttpError: propagated from ``BodyModel.get_request_data`` when a
                value cannot be parsed.
        """
        req = _HttpRequest()
        get_request_data = super().get_request_data
        results: DictStrAny = {}
        for name, annotation in cls.__ninja_body_params__.items():
            if name not in request.POST:
                continue
            data = request.POST[name]
            if (
                annotation is str
                and not data.startswith('"')
                and not data.endswith('"')
            ):
                # Raw text for a ``str`` parameter is turned into a JSON
                # string literal before it is parsed as a body.
                # json.dumps escapes embedded quotes, backslashes and control
                # characters, which plain quote-wrapping does not.
                # startswith/endswith (instead of data[0]/data[-1]) avoid an
                # IndexError when the form value is empty.
                data = json.dumps(data)
            req.body = data.encode()
            results[name] = get_request_data(req, api, path_params)
        return results
|
|
|
|
|
|
class Param(FieldInfo):
    """Pydantic ``FieldInfo`` base for request parameter declarations.

    Concrete subclasses set ``_model`` to the ``ParamModel`` subclass that
    extracts their values from the request. ``_param_source()`` reports
    where the value comes from.
    """

    def __init__(
        self,
        default: Any,
        *,
        alias: Optional[str] = None,
        title: Optional[str] = None,
        description: Optional[str] = None,
        gt: Optional[float] = None,
        ge: Optional[float] = None,
        lt: Optional[float] = None,
        le: Optional[float] = None,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        example: Optional[Any] = None,
        examples: Optional[Dict[str, Any]] = None,
        deprecated: Optional[bool] = None,
        include_in_schema: Optional[bool] = True,
        pattern: Union[str, Pattern[str], None] = None,
        **extra: Any,
    ) -> None:
        """Build the parameter field.

        Args:
            default: Default value passed to ``FieldInfo``.
            alias: Field alias. It is also used as ``validation_alias`` and
                ``serialization_alias`` unless ``extra`` supplies a truthy
                value for them.
            title, description, gt, ge, lt, le, min_length, max_length,
            pattern: Forwarded unchanged to ``FieldInfo``.
            example: Stored in ``json_schema_extra["example"]`` whenever it
                is not ``None``. Falsy values such as ``0``, ``False`` and
                ``""`` are kept.
            examples: Stored in ``json_schema_extra["examples"]`` when
                non-empty.
            deprecated: Kept on ``self.deprecated``. A truthy value is also
                stored in ``json_schema_extra``.
            include_in_schema: A falsy value is recorded in
                ``json_schema_extra["include_in_schema"]``.
            **extra: Additional keyword arguments forwarded to ``FieldInfo``.
        """
        # NOTE(review): newer pydantic versions may assign ``deprecated``
        # inside FieldInfo.__init__; verify this value survives super().
        self.deprecated = deprecated
        # Initialised to None here; presumably assigned by code outside
        # this class (TODO confirm).
        self.model_field: Optional[FieldInfo] = None

        json_schema_extra: Dict[str, Any] = {}
        # `is not None` rather than truthiness, so meaningful falsy
        # examples (0, False, "") are not dropped from the schema.
        if example is not None:
            json_schema_extra["example"] = example
        if examples:
            json_schema_extra["examples"] = examples
        if deprecated:
            json_schema_extra["deprecated"] = deprecated
        if not include_in_schema:
            json_schema_extra["include_in_schema"] = include_in_schema

        # `alias` doubles as the validation/serialization alias unless the
        # caller set those explicitly.
        if alias and not extra.get("validation_alias"):
            extra["validation_alias"] = alias
        if alias and not extra.get("serialization_alias"):
            extra["serialization_alias"] = alias

        super().__init__(
            default=default,
            alias=alias,
            title=title,
            description=description,
            gt=gt,
            ge=ge,
            lt=lt,
            le=le,
            min_length=min_length,
            max_length=max_length,
            pattern=pattern,
            json_schema_extra=json_schema_extra,
            **extra,
        )

    @classmethod
    def _param_source(cls) -> str:
        """Return the OpenAPI ``param.in`` value (or body type): the lowercased class name."""
        return cls.__name__.lower()
|
|
|
|
|
|
class Path(Param):
    """Parameter field whose values are taken from the path parameters via ``PathModel``.

    ``_param_source()`` yields ``"path"`` (the lowercased class name).
    """

    _model = PathModel
|
|
|
|
|
|
class Query(Param):
    """Parameter field read from the query string via ``QueryModel``.

    ``_param_source()`` yields ``"query"`` (the lowercased class name).
    """

    _model = QueryModel
|
|
|
|
|
|
class Header(Param):
    """Parameter field read from HTTP headers via ``HeaderModel``.

    ``_param_source()`` yields ``"header"`` (the lowercased class name).
    """

    _model = HeaderModel
|
|
|
|
|
|
class Cookie(Param):
    """Parameter field read from request cookies via ``CookieModel``.

    ``_param_source()`` yields ``"cookie"`` (the lowercased class name).
    """

    _model = CookieModel
|
|
|
|
|
|
class Body(Param):
    """Parameter field read from the parsed request body via ``BodyModel``.

    ``_param_source()`` yields ``"body"`` (the lowercased class name).
    """

    _model = BodyModel
|
|
|
|
|
|
class Form(Param):
    """Parameter field read from form data (``request.POST``) via ``FormModel``.

    ``_param_source()`` yields ``"form"`` (the lowercased class name).
    """

    _model = FormModel
|
|
|
|
|
|
class File(Param):
    """Parameter field read from uploaded files (``request.FILES``) via ``FileModel``.

    ``_param_source()`` yields ``"file"`` (the lowercased class name).
    """

    _model = FileModel
|
|
|
|
|
|
class _MultiPartBody(Param):
    """Parameter field for body values sent as individual multipart form fields.

    Values are read by ``_MultiPartBodyModel``. The reported source is
    ``"body"`` rather than the lowercased class name.
    """

    _model = _MultiPartBodyModel

    @classmethod
    def _param_source(cls) -> str:
        # Report as a regular body parameter instead of "_multipartbody".
        source = "body"
        return source
|