import inspect
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)

import pydantic
from asgiref.sync import async_to_sync, sync_to_async
from django.http import HttpRequest, HttpResponse, HttpResponseNotAllowed
from django.http.response import HttpResponseBase

from ninja.compatibility.files import FIX_MIDDLEWARE_PATH, need_to_fix_request_files
from ninja.constants import NOT_SET, NOT_SET_TYPE
from ninja.errors import (
    AuthenticationError,
    ConfigError,
    Throttled,
    ValidationErrorContext,
)
from ninja.params.models import TModels
from ninja.schema import Schema, pydantic_version
from ninja.signature import ViewSignature, is_async
from ninja.throttling import BaseThrottle
from ninja.types import DictStrAny
from ninja.utils import check_csrf, is_async_callable

if TYPE_CHECKING:
    from ninja import NinjaAPI, Router  # pragma: no cover

__all__ = ["Operation", "PathView", "ResponseObject"]


async def _await_coroutine(coro: Coroutine) -> Any:
    """Await an already-created coroutine so ``async_to_sync`` can drive it."""
    return await coro


class Operation:
    """A view function bound to a path and a list of HTTP methods.

    Holds the view's parsed signature, its response models, auth callbacks and
    throttles, and runs the request/response cycle for a synchronous view.
    """

    def __init__(
        self,
        path: str,
        methods: List[str],
        view_func: Callable,
        *,
        auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET,
        throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
        response: Any = NOT_SET,
        operation_id: Optional[str] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        deprecated: Optional[bool] = None,
        by_alias: Optional[bool] = None,
        exclude_unset: Optional[bool] = None,
        exclude_defaults: Optional[bool] = None,
        exclude_none: Optional[bool] = None,
        include_in_schema: bool = True,
        url_name: Optional[str] = None,
        openapi_extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Parse the view signature and normalise auth/throttle/response settings.

        Raises:
            ConfigError: if the methods need the request.FILES fix middleware and
                it is not installed, or if ``auth`` is an empty sequence.
        """
        self.is_async = False
        self.path: str = path
        self.methods: List[str] = methods
        self.view_func: Callable = view_func
        # Filled in later by set_api_instance().
        self.api: NinjaAPI = cast("NinjaAPI", None)
        if url_name is not None:
            self.url_name = url_name

        self.auth_param: Optional[Union[Sequence[Callable], Callable, object]] = auth
        self.auth_callbacks: Sequence[Callable] = []
        self._set_auth(auth)

        if isinstance(throttle, BaseThrottle):
            throttle = [throttle]
        self.throttle_param = throttle
        self.throttle_objects: List[BaseThrottle] = []
        if throttle is not NOT_SET:
            for th in throttle:  # type: ignore
                assert isinstance(
                    th, BaseThrottle
                ), "Throttle should be an instance of BaseThrottle"
                self.throttle_objects.append(th)

        self.signature = ViewSignature(self.path, self.view_func)
        self.models: TModels = self.signature.models

        # Maps status code (or Ellipsis as catch-all) -> response schema,
        # None (empty body) or NOT_SET (no validation).
        self.response_models: Dict[Any, Any]
        if response is NOT_SET:
            self.response_models = {200: NOT_SET}
        elif isinstance(response, dict):
            self.response_models = self._create_response_model_multiple(response)
        else:
            self.response_models = {200: self._create_response_model(response)}

        if need_to_fix_request_files(methods, self.models):
            raise ConfigError(
                f"Router '{path}' has method(s) {methods} that require fixing request.FILES. "
                f"Please add '{FIX_MIDDLEWARE_PATH}' to settings.MIDDLEWARE"
            )

        self.operation_id = operation_id
        self.summary = summary or self.view_func.__name__.title().replace("_", " ")
        self.description = description or self.signature.docstring
        self.tags = tags
        self.deprecated = deprecated
        self.include_in_schema = include_in_schema
        self.openapi_extra = openapi_extra

        # Exporting models params
        self.by_alias = by_alias or False
        self.exclude_unset = exclude_unset or False
        self.exclude_defaults = exclude_defaults or False
        self.exclude_none = exclude_none or False

        if hasattr(view_func, "_ninja_contribute_to_operation"):
            # Allow 3rd party code to contribute to the operation behavior
            callbacks: List[Callable] = view_func._ninja_contribute_to_operation
            for callback in callbacks:
                callback(self)

    def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
        """Run checks, resolve params, call the view and build the response.

        A failed check returns its response directly; any exception raised
        afterwards is handed to ``api.on_exception``.
        """
        error = self._run_checks(request)
        if error:
            return error
        try:
            temporal_response = self.api.create_temporal_response(request)
            values = self._get_values(request, kw, temporal_response)
            result = self.view_func(request, **values)
            return self._result_to_response(request, result, temporal_response)
        except Exception as e:
            self._add_wraps_hint(e)
            return self.api.on_exception(request, e)

    @staticmethod
    def _add_wraps_hint(e: Exception) -> None:
        """Append a functools.wraps() hint to "missing positional argument" TypeErrors."""
        if isinstance(e, TypeError) and "required positional argument" in str(e):
            msg = "Did you fail to use functools.wraps() in a decorator?"
            msg = f"{e.args[0]}: {msg}" if e.args else msg
            e.args = (msg,) + e.args[1:]

    def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
        """Attach the API and inherit auth/throttle/tags from router or api."""
        self.api = api
        if self.auth_param == NOT_SET:
            if router.auth != NOT_SET:
                # If the router auth was explicitly set, use it.
                self._set_auth(router.auth)
            elif api.auth != NOT_SET:
                # Otherwise fall back to the api auth. Since this is in an else branch,
                # it will only be used if the router auth was not explicitly set (i.e.
                # setting the router's auth to None explicitly allows "resetting" the
                # default auth that its operations will use).
                self._set_auth(self.api.auth)

        if self.throttle_param == NOT_SET:
            # Router-level throttles take precedence over api-level ones.
            if api.throttle != NOT_SET:
                api_t = api.throttle
                self.throttle_objects = (
                    [api_t] if isinstance(api_t, BaseThrottle) else api_t  # type: ignore
                )
            if router.throttle != NOT_SET:
                router_t = router.throttle
                self.throttle_objects = (
                    [router_t] if isinstance(router_t, BaseThrottle) else router_t  # type: ignore
                )
            assert all(
                isinstance(th, BaseThrottle) for th in self.throttle_objects
            ), "Throttle should be an instance of BaseThrottle"

        if self.tags is None:
            if router.tags is not None:
                self.tags = router.tags

    def _set_auth(
        self, auth: Optional[Union[Sequence[Callable], Callable, object]]
    ) -> None:
        """Normalise ``auth`` into ``self.auth_callbacks``.

        ``None`` and ``NOT_SET`` leave the current callbacks untouched, a sequence
        is used as-is and any other value is wrapped in a one-element list.

        Raises:
            ConfigError: if ``auth`` is an empty sequence.
        """
        if auth is None or auth is NOT_SET:
            return
        if isinstance(auth, Sequence):
            if not auth:
                # Previously an empty list became ``[[]]`` and every request then
                # failed with a TypeError when the inner list was "called".
                raise ConfigError(
                    "auth must not be an empty sequence; "
                    "use auth=None to disable authentication"
                )
            self.auth_callbacks = auth
        else:
            self.auth_callbacks = [cast(Callable, auth)]

    def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]:
        "Runs security/throttle checks for each operation"
        # NOTE: if you change anything in this function - do this also in AsyncOperation

        # csrf:
        if self.api.csrf:
            error = check_csrf(request, self.view_func)
            if error:
                return error

        # auth:
        if self.auth_callbacks:
            error = self._run_authentication(request)  # type: ignore
            if error:
                return error

        # Throttling:
        if self.throttle_objects:
            error = self._check_throttles(request)  # type: ignore
            if error:
                return error

        return None

    def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Try each auth callback in order; the first truthy result wins.

        The winning result is stored on ``request.auth`` and None is returned.
        An exception from a callback, or no truthy result at all, is passed to
        ``api.on_exception`` and that return value is returned.
        """
        for callback in self.auth_callbacks:
            try:
                if is_async_callable(callback) or getattr(callback, "is_async", False):
                    result = callback(request)
                    if inspect.iscoroutine(result):
                        # Drive the coroutine we already hold instead of calling the
                        # callback a second time (which ran it twice and left this
                        # coroutine un-awaited).
                        result = async_to_sync(_await_coroutine)(result)
                else:
                    result = callback(request)
            except Exception as exc:
                return self.api.on_exception(request, exc)

            if result:
                request.auth = result  # type: ignore
                return None

        return self.api.on_exception(request, AuthenticationError())

    def _check_throttles(self, request: HttpRequest) -> Optional[HttpResponse]:
        """Return a Throttled error response if any throttle rejects the request."""
        throttle_durations = [
            throttle.wait()
            for throttle in self.throttle_objects
            if not throttle.allow_request(request)
        ]

        if throttle_durations:
            # Filter out `None` values, which may happen when the throttle
            # config / rate has changed.
            durations = [
                duration for duration in throttle_durations if duration is not None
            ]
            duration = max(durations, default=None)
            return self.api.on_exception(request, Throttled(wait=duration))  # type: ignore
        return None

    def _result_to_response(
        self, request: HttpRequest, result: Any, temporal_response: HttpResponse
    ) -> HttpResponseBase:
        """
        The protocol for results
         - if HttpResponse - returns as is
         - if tuple with 2 elements - means http_code + body
         - otherwise it's a body
        """
        if isinstance(result, HttpResponseBase):
            return result

        status: int = 200
        if len(self.response_models) == 1:
            only_status = next(iter(self.response_models))
            # A lone ``...`` key is a catch-all, not a status code; adopting it
            # would set ``status_code = Ellipsis`` on the response.
            if only_status is not Ellipsis:
                status = only_status

        if isinstance(result, tuple) and len(result) == 2:
            status = result[0]
            result = result[1]

        if status in self.response_models:
            response_model = self.response_models[status]
        elif Ellipsis in self.response_models:
            response_model = self.response_models[Ellipsis]
        else:
            raise ConfigError(
                f"Schema for status {status} is not set in response"
                f" {self.response_models.keys()}"
            )

        temporal_response.status_code = status

        if response_model is NOT_SET:
            return self.api.create_response(
                request, result, temporal_response=temporal_response
            )

        if response_model is None:
            # Empty response.
            return temporal_response

        resp_object = ResponseObject(result)
        # ^ we need object because getter_dict seems work only with model_validate
        validated_object = response_model.model_validate(
            resp_object, context={"request": request, "response_status": status}
        )

        model_dump_kwargs: Dict[str, Any] = {}
        if pydantic_version >= [2, 7]:
            # pydantic added support for serialization context at 2.7
            model_dump_kwargs.update(
                context={"request": request, "response_status": status}
            )

        result = validated_object.model_dump(
            by_alias=self.by_alias,
            exclude_unset=self.exclude_unset,
            exclude_defaults=self.exclude_defaults,
            exclude_none=self.exclude_none,
            **model_dump_kwargs,
        )["response"]
        return self.api.create_response(
            request, result, temporal_response=temporal_response
        )

    def _get_values(
        self, request: HttpRequest, path_params: Any, temporal_response: HttpResponse
    ) -> DictStrAny:
        """Resolve every param model; collect all validation errors before raising."""
        values = {}
        error_contexts: List[ValidationErrorContext] = []
        for model in self.models:
            try:
                data = model.resolve(request, self.api, path_params)
                values.update(data)
            except pydantic.ValidationError as e:
                error_contexts.append(
                    ValidationErrorContext(pydantic_validation_error=e, model=model)
                )
        if error_contexts:
            validation_error = self.api.validation_error_from_error_contexts(
                error_contexts
            )
            raise validation_error
        if self.signature.response_arg:
            values[self.signature.response_arg] = temporal_response
        return values

    def _create_response_model_multiple(
        self, response_param: DictStrAny
    ) -> Dict[Any, Optional[Type[Schema]]]:
        """Expand ``{code_or_codes: model}`` into one entry per status code."""
        result = {}
        for key, model in response_param.items():
            status_codes = key if isinstance(key, Iterable) else [key]
            for code in status_codes:
                result[code] = self._create_response_model(model)
        return result

    def _create_response_model(self, response_param: Any) -> Optional[Type[Schema]]:
        """Wrap ``response_param`` in a Schema with a single ``response`` field."""
        if response_param is None:
            return None
        attrs = {"__annotations__": {"response": response_param}}
        return type("NinjaResponseSchema", (Schema,), attrs)


class AsyncOperation(Operation):
    """Operation variant whose view (and auth checks) are awaited."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.is_async = True

    async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:  # type: ignore
        """Async counterpart of Operation.run."""
        error = await self._run_checks(request)
        if error:
            return error
        try:
            temporal_response = self.api.create_temporal_response(request)
            values = self._get_values(request, kw, temporal_response)
            result = await self.view_func(request, **values)
            return self._result_to_response(request, result, temporal_response)
        except Exception as e:
            self._add_wraps_hint(e)
            return self.api.on_exception(request, e)

    async def _run_checks(self, request: HttpRequest) -> Optional[HttpResponse]:  # type: ignore
        "Runs security checks for each operation"
        # NOTE: if you change anything in this function - do this also in Sync Operation
        # NOTE(review): auth runs before csrf here, while the sync Operation checks
        # csrf first — confirm which order is intended before aligning them.

        # auth:
        if self.auth_callbacks:
            error = await self._run_authentication(request)
            if error:
                return error

        # csrf:
        if self.api.csrf:
            error = check_csrf(request, self.view_func)
            if error:
                return error

        # Throttling:
        if self.throttle_objects:
            error = self._check_throttles(request)
            if error:
                return error

        return None

    async def _run_authentication(self, request: HttpRequest) -> Optional[HttpResponse]:  # type: ignore
        """Async counterpart of Operation._run_authentication."""
        for callback in self.auth_callbacks:
            try:
                if is_async_callable(callback) or getattr(callback, "is_async", False):
                    result = callback(request)
                    # Only await real awaitables; a None (or plain value) from a
                    # callback flagged as async is used as-is.
                    if inspect.isawaitable(result):
                        result = await result
                else:
                    result = callback(request)
            except Exception as exc:
                return self.api.on_exception(request, exc)

            if result:
                request.auth = result  # type: ignore
                return None

        return self.api.on_exception(request, AuthenticationError())


class PathView:
    """All operations registered for one URL path, dispatched by HTTP method."""

    def __init__(self) -> None:
        self.operations: List[Operation] = []
        self.is_async = False  # if at least one operation is async - will become True
        self.url_name: Optional[str] = None

    def add_operation(
        self,
        path: str,
        methods: List[str],
        view_func: Callable,
        *,
        auth: Optional[Union[Sequence[Callable], Callable, NOT_SET_TYPE]] = NOT_SET,
        throttle: Union[BaseThrottle, List[BaseThrottle], NOT_SET_TYPE] = NOT_SET,
        response: Any = NOT_SET,
        operation_id: Optional[str] = None,
        summary: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[List[str]] = None,
        deprecated: Optional[bool] = None,
        by_alias: Optional[bool] = None,
        exclude_unset: Optional[bool] = None,
        exclude_defaults: Optional[bool] = None,
        exclude_none: Optional[bool] = None,
        url_name: Optional[str] = None,
        include_in_schema: bool = True,
        openapi_extra: Optional[Dict[str, Any]] = None,
    ) -> Operation:
        """Create an (Async)Operation for ``view_func`` and register it here."""
        if url_name:
            self.url_name = url_name

        OperationClass = Operation
        if is_async(view_func):
            self.is_async = True
            OperationClass = AsyncOperation

        operation = OperationClass(
            path,
            methods,
            view_func,
            auth=auth,
            throttle=throttle,
            response=response,
            operation_id=operation_id,
            summary=summary,
            description=description,
            tags=tags,
            deprecated=deprecated,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            include_in_schema=include_in_schema,
            url_name=url_name,
            openapi_extra=openapi_extra,
        )

        self.operations.append(operation)
        view_func._ninja_operation = operation  # type: ignore
        return operation

    def set_api_instance(self, api: "NinjaAPI", router: "Router") -> None:
        """Propagate the API/router to every registered operation."""
        self.api = api
        for op in self.operations:
            op.set_api_instance(api, router)

    def get_view(self) -> Callable:
        """Return the sync or async Django view, marked csrf_exempt."""
        view: Callable
        if self.is_async:
            view = self._async_view
        else:
            view = self._sync_view

        view.__func__.csrf_exempt = True  # type: ignore
        return view

    def _sync_view(self, request: HttpRequest, *a: Any, **kw: Any) -> HttpResponseBase:
        operation = self._find_operation(request)
        if operation is None:
            return self._not_allowed()
        return operation.run(request, *a, **kw)

    async def _async_view(
        self, request: HttpRequest, *a: Any, **kw: Any
    ) -> HttpResponseBase:
        operation = self._find_operation(request)
        if operation is None:
            return self._not_allowed()
        if operation.is_async:
            return await cast(AsyncOperation, operation).run(request, *a, **kw)
        return await sync_to_async(operation.run)(request, *a, **kw)

    def _find_operation(self, request: HttpRequest) -> Optional[Operation]:
        """Return the first operation that accepts ``request.method``, if any."""
        for op in self.operations:
            if request.method in op.methods:
                return op
        return None

    def _not_allowed(self) -> HttpResponse:
        """Build a 405 response listing every method this path accepts."""
        allowed_methods = set()
        for op in self.operations:
            allowed_methods.update(op.methods)
        # Sorted so the Allow header is deterministic.
        return HttpResponseNotAllowed(
            sorted(allowed_methods), content=b"Method not allowed"
        )


class ResponseObject:
    "Basically this is just a helper to be able to pass response to pydantic's model_validate"

    def __init__(self, response: Any) -> None:
        # The view's return value (body), not necessarily an HttpResponse.
        self.response = response