514 lines
19 KiB
Python

import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
import pydantic
from asgiref.sync import async_to_sync
from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBase
from ninja.compatibility.files import FIX_MIDDLEWARE_PATH, need_to_fix_request_files
from ninja.constants import NOT_SET, NOT_SET_TYPE
from ninja.errors import (
AuthenticationError,
ConfigError,
Throttled,
ValidationErrorContext,
)
from ninja.params.models import TModels
from ninja.schema import Schema, pydantic_version
from ninja.signature import ViewSignature, is_async
from ninja.throttling import BaseThrottle
from ninja.types import DictStrAny
from ninja.utils import check_csrf, is_async_callable
if TYPE_CHECKING:
from ninja import NinjaAPI, Router # pragma: no cover
# Names exported by ``from <this module> import *``; AsyncOperation is used
# internally by PathView.add_operation and is intentionally not listed.
__all__ = ["Operation", "PathView", "ResponseObject"]
class Operation:
    """A single API endpoint: one view function bound to one or more HTTP methods.

    An ``Operation`` holds everything needed to serve a request for its view:
    the parameter models parsed from the view signature, the response schemas
    keyed by status code, the auth callbacks, the throttles and the flags passed
    to ``model_dump`` when serializing the response body.
    """

    def __init__(
        self,
        path: str,
        methods: List[str],
        view_func: Callable,
        *,
        auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET,
        throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
        response: Any = NOT_SET,
        operation_id: Optional[str] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        deprecated: Optional[bool] = None,
        by_alias: Optional[bool] = None,
        exclude_unset: Optional[bool] = None,
        exclude_defaults: Optional[bool] = None,
        exclude_none: Optional[bool] = None,
        include_in_schema: bool = True,
        url_name: Optional[str] = None,
        openapi_extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the operation.

        Raises:
            ConfigError: if the operation needs the request.FILES fix middleware
                and it is not configured.
            AssertionError: if a throttle is not a ``BaseThrottle`` instance.
        """
        self.is_async = False
        self.path: str = path
        self.methods: List[str] = methods
        self.view_func: Callable = view_func
        # Bound later by set_api_instance(); None until then.
        self.api: "NinjaAPI" = cast("NinjaAPI", None)
        if url_name is not None:
            self.url_name = url_name

        # Keep the raw parameter so set_api_instance() can tell whether auth was
        # set explicitly on the operation (anything other than NOT_SET).
        self.auth_param: Optional[Union[Sequence[Callable], Callable, object]] = auth
        self.auth_callbacks: Sequence[Callable] = []
        self._set_auth(auth)

        if isinstance(throttle, BaseThrottle):
            throttle = [throttle]
        self.throttle_param = throttle
        self.throttle_objects: List[BaseThrottle] = []
        if throttle is not NOT_SET:
            for th in throttle:  # type: ignore
                assert isinstance(
                    th, BaseThrottle
                ), "Throttle should be an instance of BaseThrottle"
                self.throttle_objects.append(th)

        self.signature = ViewSignature(self.path, self.view_func)
        self.models: TModels = self.signature.models

        # Maps status code (or Ellipsis as a catch-all) -> response schema,
        # None (empty body) or NOT_SET (no validation, body passed through).
        self.response_models: Dict[Any, Any]
        if response is NOT_SET:
            self.response_models = {200: NOT_SET}
        elif isinstance(response, dict):
            self.response_models = self._create_response_model_multiple(response)
        else:
            self.response_models = {200: self._create_response_model(response)}

        if need_to_fix_request_files(methods, self.models):
            raise ConfigError(
                f"Router '{path}' has method(s) {methods} that require fixing request.FILES. "
                f"Please add '{FIX_MIDDLEWARE_PATH}' to settings.MIDDLEWARE"
            )

        self.operation_id = operation_id
        self.summary = summary or self.view_func.__name__.title().replace("_", " ")
        self.description = description or self.signature.docstring
        self.tags = tags
        self.deprecated = deprecated
        self.include_in_schema = include_in_schema
        self.openapi_extra = openapi_extra

        # Exporting models params
        self.by_alias = by_alias or False
        self.exclude_unset = exclude_unset or False
        self.exclude_defaults = exclude_defaults or False
        self.exclude_none = exclude_none or False

        if hasattr(view_func, "_ninja_contribute_to_operation"):
            # Allow 3rd party code to contribute to the operation behavior
            callbacks: List[Callable] = view_func._ninja_contribute_to_operation
            for callback in callbacks:
                callback(self)

    def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
        """Handle one request: run checks, resolve params, call the view, render.

        Any exception raised while resolving values, running the view or
        rendering the result is routed through ``self.api.on_exception``.
        """
        error = self._run_checks(request)
        if error:
            return error
        try:
            temporal_response = self.api.create_temporal_response(request)
            values = self._get_values(request, kw, temporal_response)
            result = self.view_func(request, **values)
            return self._result_to_response(request, result, temporal_response)
        except Exception as e:
            if isinstance(e, TypeError) and "required positional argument" in str(e):
                # A decorator without functools.wraps() hides the real signature,
                # so params are not passed; add a hint to the error message.
                msg = "Did you fail to use functools.wraps() in a decorator?"
                msg = f"{e.args[0]}: {msg}" if e.args else msg
                e.args = (msg,) + e.args[1:]
            return self.api.on_exception(request, e)

    def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
        """Bind the operation to ``api`` and inherit unset auth/throttle/tags.

        Operation-level settings win. Otherwise router settings are used, then
        api settings. For throttles the router value overrides the api value.
        """
        self.api = api
        if self.auth_param is NOT_SET:
            if router.auth is not NOT_SET:
                # If the router auth was explicitly set, use it.
                self._set_auth(router.auth)
            elif api.auth is not NOT_SET:
                # Otherwise fall back to the api auth. Since this is in an else branch,
                # it will only be used if the router auth was not explicitly set (i.e.
                # setting the router's auth to None explicitly allows "resetting" the
                # default auth that its operations will use).
                self._set_auth(self.api.auth)

        if self.throttle_param is NOT_SET:
            if api.throttle is not NOT_SET:
                self.throttle_objects = self._as_throttle_list(api.throttle)
            if router.throttle is not NOT_SET:
                self.throttle_objects = self._as_throttle_list(router.throttle)
            assert all(
                isinstance(th, BaseThrottle) for th in self.throttle_objects
            ), "Throttle should be an instance of BaseThrottle"

        if self.tags is None and router.tags is not None:
            self.tags = router.tags

    @staticmethod
    def _as_throttle_list(throttle: Any) -> List[BaseThrottle]:
        """Wrap a single throttle in a list; copy a sequence of throttles."""
        if isinstance(throttle, BaseThrottle):
            return [throttle]
        return list(throttle)

    def _set_auth(
        self, auth: Optional[Union[Sequence[Callable], Callable, object]]
    ) -> None:
        """Replace the auth callbacks unless ``auth`` is None or NOT_SET.

        A single callable becomes a one-item list. A sequence (including an
        empty one, meaning "no authenticators") is copied as-is.
        """
        if auth is None or auth is NOT_SET:
            return
        # A conditional expression rather than `and/or`: with `auth=[]` the
        # `and/or` form produced `[[]]`, i.e. a non-callable "callback".
        self.auth_callbacks = list(auth) if isinstance(auth, Sequence) else [auth]

    def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]:
        "Runs security/throttle checks for each operation"
        # NOTE: if you change anything in this function - do this also in AsyncOperation

        # csrf:
        if self.api.csrf:
            error = check_csrf(request, self.view_func)
            if error:
                return error

        # auth:
        if self.auth_callbacks:
            error = self._run_authentication(request)  # type: ignore
            if error:
                return error

        # Throttling:
        if self.throttle_objects:
            error = self._check_throttles(request)  # type: ignore
            if error:
                return error

        return None

    def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Try each auth callback in order; the first truthy result wins.

        The winning result is stored on ``request.auth``. An exception from a
        callback is passed to ``api.on_exception``. If no callback succeeds, an
        ``AuthenticationError`` is passed to ``api.on_exception``.
        """
        for callback in self.auth_callbacks:
            try:
                result = callback(request)
                if inspect.isawaitable(result):
                    # Async authenticator: drive the awaitable already created
                    # instead of calling the callback again. Calling it again
                    # left the first coroutine un-awaited. A coroutine from an
                    # unflagged callback must never be taken as a truthy success.
                    result = async_to_sync(self._await_result)(result)
            except Exception as exc:
                return self.api.on_exception(request, exc)
            if result:
                request.auth = result  # type: ignore
                return None
        return self.api.on_exception(request, AuthenticationError())

    @staticmethod
    async def _await_result(awaitable: Any) -> Any:
        """Await ``awaitable``; a coroutine function usable with async_to_sync."""
        return await awaitable

    def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Ask every throttle; if any refuses, report the longest known wait."""
        throttle_durations = [
            throttle.wait()
            for throttle in self.throttle_objects
            if not throttle.allow_request(request)
        ]
        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            durations = [
                duration for duration in throttle_durations if duration is not None
            ]
            duration = max(durations, default=None)
            return self.api.on_exception(request, Throttled(wait=duration))  # type: ignore
        return None

    def _result_to_response(
        self, request: HttpRequest, result: Any, temporal_response: HttpResponse
    ) -> HttpResponseBase:
        """
        The protocol for results
          - if HttpResponse - returns as is
          - if tuple with 2 elements - means http_code + body
          - otherwise it's a body

        Raises:
            ConfigError: if no schema (and no Ellipsis catch-all) matches the status.
        """
        if isinstance(result, HttpResponseBase):
            return result

        status: int = 200
        if len(self.response_models) == 1:
            (only_code,) = self.response_models
            # An Ellipsis key is a catch-all schema, not a status code; using it
            # as the implicit status would set `status_code = ...`.
            if only_code is not Ellipsis:
                status = only_code

        if isinstance(result, tuple) and len(result) == 2:
            status = result[0]
            result = result[1]

        if status in self.response_models:
            response_model = self.response_models[status]
        elif Ellipsis in self.response_models:
            response_model = self.response_models[Ellipsis]
        else:
            raise ConfigError(
                f"Schema for status {status} is not set in response"
                f" {self.response_models.keys()}"
            )

        temporal_response.status_code = status

        if response_model is NOT_SET:
            return self.api.create_response(
                request, result, temporal_response=temporal_response
            )

        if response_model is None:
            # Empty response.
            return temporal_response

        resp_object = ResponseObject(result)
        # ^ we need object because getter_dict seems work only with model_validate
        validated_object = response_model.model_validate(
            resp_object, context={"request": request, "response_status": status}
        )

        model_dump_kwargs: Dict[str, Any] = {}
        if pydantic_version >= [2, 7]:
            # pydantic added support for serialization context at 2.7
            model_dump_kwargs.update(
                context={"request": request, "response_status": status}
            )

        result = validated_object.model_dump(
            by_alias=self.by_alias,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
            exclude_none=self.exclude_none,
            **model_dump_kwargs,
        )["response"]
        return self.api.create_response(
            request, result, temporal_response=temporal_response
        )

    def _get_values(
        self, request: HttpRequest, path_params: Any, temporal_response: HttpResponse
    ) -> DictStrAny:
        """Resolve every parameter model into view kwargs.

        Validation errors from all models are collected first and raised together
        as one error built by ``api.validation_error_from_error_contexts``.
        """
        values = {}
        error_contexts: List[ValidationErrorContext] = []
        for model in self.models:
            try:
                data = model.resolve(request, self.api, path_params)
                values.update(data)
            except pydantic.ValidationError as e:
                error_contexts.append(
                    ValidationErrorContext(pydantic_validation_error=e, model=model)
                )
        if error_contexts:
            validation_error = self.api.validation_error_from_error_contexts(
                error_contexts
            )
            raise validation_error
        if self.signature.response_arg:
            values[self.signature.response_arg] = temporal_response
        return values

    def _create_response_model_multiple(
        self, response_param: DictStrAny
    ) -> Dict[Any, Optional[Type[Schema]]]:
        """Expand ``{code_or_codes: model}`` into ``{code: schema}``.

        A key may be a single code (or Ellipsis) or an iterable of codes. Strings
        and bytes are treated as single keys, not iterated character by character.
        """
        result: Dict[Any, Optional[Type[Schema]]] = {}
        for key, model in response_param.items():
            if isinstance(key, Iterable) and not isinstance(key, (str, bytes)):
                status_codes = key
            else:
                status_codes = [key]
            for code in status_codes:
                result[code] = self._create_response_model(model)
        return result

    def _create_response_model(self, response_param: Any) -> Optional[Type[Schema]]:
        """Wrap ``response_param`` in a Schema with a single ``response`` field.

        Returns None when ``response_param`` is None (empty response body).
        """
        if response_param is None:
            return None
        attrs = {"__annotations__": {"response": response_param}}
        return type("NinjaResponseSchema", (Schema,), attrs)
class AsyncOperation(Operation):
    """Operation whose view function is awaited; checks and auth run async."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.is_async = True

    async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:  # type: ignore
        """Async counterpart of ``Operation.run``; awaits the view's result."""
        error = await self._run_checks(request)
        if error:
            return error
        try:
            temporal_response = self.api.create_temporal_response(request)
            values = self._get_values(request, kw, temporal_response)
            result = await self.view_func(request, **values)
            return self._result_to_response(request, result, temporal_response)
        except Exception as e:
            if isinstance(e, TypeError) and "required positional argument" in str(e):
                # Same hint as the sync Operation.run: a decorator without
                # functools.wraps() hides the real signature from ninja.
                msg = "Did you fail to use functools.wraps() in a decorator?"
                msg = f"{e.args[0]}: {msg}" if e.args else msg
                e.args = (msg,) + e.args[1:]
            return self.api.on_exception(request, e)

    async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]:  # type: ignore
        "Runs security checks for each operation"
        # NOTE: if you change anything in this function - do this also in Sync Operation
        # Order matches the sync Operation._run_checks: csrf, auth, throttling.

        # csrf:
        if self.api.csrf:
            error = check_csrf(request, self.view_func)
            if error:
                return error

        # auth:
        if self.auth_callbacks:
            error = await self._run_authentication(request)
            if error:
                return error

        # Throttling:
        if self.throttle_objects:
            error = self._check_throttles(request)
            if error:
                return error

        return None

    async def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]:  # type: ignore
        """Try each auth callback in order; the first truthy result wins.

        Awaitable results (from async callbacks, or sync callbacks that return a
        coroutine) are awaited before being checked. A callback returning None
        counts as a failure.
        """
        for callback in self.auth_callbacks:
            try:
                result = callback(request)
                if inspect.isawaitable(result):
                    # Never treat an un-awaited coroutine (always truthy) as success.
                    result = await result
            except Exception as exc:
                return self.api.on_exception(request, exc)
            if result:
                request.auth = result  # type: ignore
                return None
        return self.api.on_exception(request, AuthenticationError())
class PathView:
    """All operations registered for one URL path, dispatched by HTTP method."""

    def __init__(self) -> None:
        self.operations: List[Operation] = []
        self.is_async = False  # if at least one operation is async - will become True
        self.url_name: Optional[str] = None

    def add_operation(
        self,
        path: str,
        methods: List[str],
        view_func: Callable,
        *,
        auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET,
        throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
        response: Any = NOT_SET,
        operation_id: Optional[str] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        deprecated: Optional[bool] = None,
        by_alias: Optional[bool] = None,
        exclude_unset: Optional[bool] = None,
        exclude_defaults: Optional[bool] = None,
        exclude_none: Optional[bool] = None,
        url_name: Optional[str] = None,
        include_in_schema: bool = True,
        openapi_extra: Optional[Dict[str, Any]] = None,
    ) -> Operation:
        """Create an Operation (async if ``view_func`` is async) and register it.

        The created operation is also attached to ``view_func._ninja_operation``.
        """
        if url_name:
            self.url_name = url_name

        OperationClass = Operation
        if is_async(view_func):
            # One async operation makes the whole path view async.
            self.is_async = True
            OperationClass = AsyncOperation

        operation = OperationClass(
            path,
            methods,
            view_func,
            auth=auth,
            throttle=throttle,
            response=response,
            operation_id=operation_id,
            summary=summary,
            description=description,
            tags=tags,
            deprecated=deprecated,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            include_in_schema=include_in_schema,
            url_name=url_name,
            openapi_extra=openapi_extra,
        )

        self.operations.append(operation)
        view_func._ninja_operation = operation  # type: ignore
        return operation

    def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
        """Bind this path view and each of its operations to ``api``."""
        self.api = api
        for op in self.operations:
            op.set_api_instance(api, router)

    def get_view(self) -> Callable:
        """Return the bound sync or async dispatcher for this path.

        ``csrf_exempt`` is set on the underlying function object.
        """
        view: Callable
        if self.is_async:
            view = self._async_view
        else:
            view = self._sync_view

        # Set on __func__ because bound methods do not accept new attributes.
        # Note this is the class-level function, shared by every PathView.
        view.__func__.csrf_exempt = True  # type: ignore
        return view

    def _sync_view(self, request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase:
        operation = self._find_operation(request)
        if operation is None:
            return self._not_allowed()
        return operation.run(request, *a, **kw)

    async def _async_view(
        self, request: HttpRequest, *a: Any, **kw: Any
    ) -> HttpResponseBase:
        from asgiref.sync import sync_to_async

        operation = self._find_operation(request)
        if operation is None:
            return self._not_allowed()
        if operation.is_async:
            return await cast(AsyncOperation, operation).run(request, *a, **kw)
        # Sync operations on an async path run in a worker thread.
        return await sync_to_async(operation.run)(request, *a, **kw)

    def _find_operation(self, request: HttpRequest) -> Optional[Operation]:
        """Return the first registered operation handling ``request.method``."""
        for op in self.operations:
            if request.method in op.methods:
                return op
        return None

    def _not_allowed(self) -> HttpResponse:
        """Build a 405 response listing every method registered on this path."""
        # Deduplicate while keeping registration order so the ``Allow`` header is
        # deterministic. A set would order methods by string hash, which is
        # randomized per process.
        allowed_methods = list(
            dict.fromkeys(method for op in self.operations for method in op.methods)
        )
        return HttpResponseNotAllowed(allowed_methods, content=b"Method not allowed")
class ResponseObject:
    """Attribute holder exposing a view's return value as ``.response``.

    Response schemas have a single ``response`` field, and ``model_validate`` is
    given this object so the value is read from the ``response`` attribute.
    The value is the view's body (any type), not an ``HttpResponse``; those are
    returned directly before a ResponseObject is built.
    """

    def __init__(self, response: Any) -> None:
        self.response = response