import inspect
from abc import ABC, abstractmethod
from functools import partial, wraps
from math import inf
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple, Type, Union

from django.db.models import QuerySet
from django.http import HttpRequest
from django.utils.module_loading import import_string
from typing_extensions import get_args as get_collection_args

from ninja import Field, Query, Router, Schema
from ninja.conf import settings
from ninja.constants import NOT_SET
from ninja.errors import ConfigError
from ninja.operation import Operation
from ninja.signature.details import is_collection_type
from ninja.utils import (
    contribute_operation_args,
    contribute_operation_callback,
    is_async_callable,
)


class PaginationBase(ABC):
    """
    Base class for synchronous paginators.

    Subclasses declare the parameters they accept in ``Input`` and implement
    :meth:`paginate_queryset`, which returns a dict whose keys match the
    fields of ``Output`` (``items_attribute`` and ``count`` by default).
    """

    class Input(Schema):
        pass

    # Source the ``Input`` parameters are read from (query parameters).
    InputSource = Query(...)

    class Output(Schema):
        items: List[Any]
        count: int

    # Key in the paginated result (and field of ``Output``) holding the page items.
    items_attribute: str = "items"

    def __init__(self, *, pass_parameter: Optional[str] = None, **kwargs: Any) -> None:
        # When set, the parsed ``Input`` is also handed to the view under this
        # keyword name. Any other keyword arguments are accepted and ignored.
        self.pass_parameter = pass_parameter

    @abstractmethod
    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Any,
        **params: Any,
    ) -> Any:
        pass  # pragma: no cover

    def _items_count(self, queryset: QuerySet) -> int:
        """
        Return the total number of items in ``queryset``.

        Lists are mostly interchangeable with QuerySets and may be passed to a
        paginator, so we first try ``.all().count()`` and fall back to ``len()``.
        """
        try:
            # Going through .all() ensures we hit QuerySet.count() rather than
            # list.count() (which counts occurrences of a value).
            return queryset.all().count()
        except AttributeError:
            return len(queryset)


class AsyncPaginationBase(PaginationBase):
    """Paginator that additionally supports ``async`` views."""

    @abstractmethod
    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Any,
        **params: Any,
    ) -> Any:
        pass  # pragma: no cover

    async def _aitems_count(self, queryset: QuerySet) -> int:
        """Async counterpart of ``_items_count``: ``.acount()`` with a ``len()`` fallback."""
        try:
            return await queryset.all().acount()
        except AttributeError:
            return len(queryset)


class LimitOffsetPagination(AsyncPaginationBase):
    """Paginate with ``limit`` / ``offset`` query parameters."""

    class Input(Schema):
        limit: int = Field(
            settings.PAGINATION_PER_PAGE,
            ge=1,
            le=(
                settings.PAGINATION_MAX_LIMIT
                if settings.PAGINATION_MAX_LIMIT != inf
                else None
            ),
        )
        offset: int = Field(0, ge=0)

    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        offset = pagination.offset
        # Clamp again at runtime; min() with inf keeps the requested limit.
        limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
        return {
            "items": queryset[offset : offset + limit],  # noqa: E203
            "count": self._items_count(queryset),
        }

    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        offset = pagination.offset
        limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
        if isinstance(queryset, QuerySet):
            # Evaluate the slice with async iteration so the DB hit is awaited.
            items = [obj async for obj in queryset[offset : offset + limit]]  # noqa: E203
        else:
            items = queryset[offset : offset + limit]  # noqa: E203
        return {
            "items": items,
            "count": await self._aitems_count(queryset),
        }


class PageNumberPagination(AsyncPaginationBase):
    """Paginate with a 1-based ``page`` number and an optional ``page_size``."""

    class Input(Schema):
        page: int = Field(1, ge=1)
        page_size: Optional[int] = Field(None, ge=1)

    def __init__(
        self,
        page_size: int = settings.PAGINATION_PER_PAGE,
        max_page_size: int = settings.PAGINATION_MAX_PER_PAGE_SIZE,
        **kwargs: Any,
    ) -> None:
        self.page_size = page_size
        self.max_page_size = max_page_size
        super().__init__(**kwargs)

    def _get_page_size(self, requested_page_size: Optional[int]) -> int:
        """Return the default page size, or the requested one capped at ``max_page_size``."""
        if requested_page_size is None:
            return self.page_size
        return min(requested_page_size, self.max_page_size)

    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        page_size = self._get_page_size(pagination.page_size)
        offset = (pagination.page - 1) * page_size
        return {
            "items": queryset[offset : offset + page_size],  # noqa: E203
            "count": self._items_count(queryset),
        }

    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        page_size = self._get_page_size(pagination.page_size)
        offset = (pagination.page - 1) * page_size
        if isinstance(queryset, QuerySet):
            items = [
                obj async for obj in queryset[offset : offset + page_size]  # noqa: E203
            ]
        else:
            items = queryset[offset : offset + page_size]  # noqa: E203
        return {
            "items": items,
            "count": await self._aitems_count(queryset),
        }


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
    """
    Decorate a view so its returned collection is paginated.

    Usage::

        @api.get(...)
        @paginate
        def my_view(request): ...

    or::

        @api.get(...)
        @paginate(PageNumberPagination, page_size=20)
        def my_view(request): ...

    Without an explicit class, the one named by ``settings.PAGINATION_CLASS``
    is used. ``paginator_params`` are passed to the paginator constructor.
    """
    isfunction = inspect.isfunction(func_or_pgn_class)
    isnotset = func_or_pgn_class is NOT_SET

    pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
        settings.PAGINATION_CLASS
    )

    if isfunction:
        # Applied directly as ``@paginate`` (or ``paginate(view, ...)``).
        return _inject_pagination(
            func_or_pgn_class, pagination_class, **paginator_params
        )

    if not isnotset:
        pagination_class = func_or_pgn_class

    def wrapper(func: Callable) -> Any:
        return _inject_pagination(func, pagination_class, **paginator_params)

    return wrapper


def _inject_pagination(
    func: Callable,
    paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
    **paginator_params: Any,
) -> Callable:
    """
    Wrap ``func`` so that its result is paginated by ``paginator_class``.

    The wrapper receives the parsed pagination input as the ``ninja_pagination``
    keyword (registered below via ``contribute_operation_args``). It also
    registers ``make_response_paginated`` as an operation callback so the
    response schema is switched to the paginated ``Output``.

    Raises:
        ConfigError: if ``func`` is async and the paginator has no
            ``apaginate_queryset``.
    """
    paginator = paginator_class(**paginator_params)
    if is_async_callable(func):
        if not hasattr(paginator, "apaginate_queryset"):
            raise ConfigError("Pagination class not configured for async requests")

        async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator:
            # Re-yield the page items so they can be collected with an async
            # comprehension below.
            for item in results:
                yield item

        @wraps(func)
        async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = await func(request, **kwargs)

            result = await paginator.apaginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )

            if paginator.Output:  # type: ignore
                page_items = result[paginator.items_attribute]
                result[paginator.items_attribute] = [
                    item async for item in evaluate(page_items)
                ]
            return result

    else:

        @wraps(func)
        def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = func(request, **kwargs)

            result = paginator.paginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )
            if paginator.Output:  # type: ignore
                # Force queryset evaluation here.
                # TODO: check why pydantic did not do it here
                result[paginator.items_attribute] = list(
                    result[paginator.items_attribute]
                )
            return result

    contribute_operation_args(
        view_with_pagination,
        "ninja_pagination",
        paginator.Input,
        paginator.InputSource,
    )

    if paginator.Output:  # type: ignore
        contribute_operation_callback(
            view_with_pagination,
            partial(make_response_paginated, paginator),
        )

    return view_with_pagination


class RouterPaginated(Router):
    """Router that paginates every operation whose ``response`` is a collection type."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.pagination_class = import_string(settings.PAGINATION_CLASS)

    def add_api_operation(
        self, path: str, methods: List[str], view_func: Callable, **kwargs: Any
    ) -> None:
        # ``response`` may be omitted by callers; treat that as "not a collection"
        # instead of raising KeyError.
        response = kwargs.get("response", NOT_SET)
        if is_collection_type(response):
            view_func = _inject_pagination(view_func, self.pagination_class)
        return super().add_api_operation(path, methods, view_func, **kwargs)


def make_response_paginated(paginator: PaginationBase, op: Operation) -> None:
    """
    Replace the operation's collection response with a paginated one.

    For example ``response=List[Some]`` becomes ``response=PagedSome``, where
    ``PagedSome`` is a dynamically created subclass of ``paginator.Output``
    whose ``paginator.items_attribute`` field is ``List[Some]``::

        class PagedSome:
            items: List[Some]
            count: int
    """
    status_code, item_schema = _find_collection_response(op)

    # Build the name of the new Output schema.
    try:
        new_name = f"Paged{item_schema.__name__}"
    except AttributeError:  # pragma: no cover
        # special case for `typing.Any`, only raised for Python < 3.10
        new_name = f"Paged{str(item_schema).replace('.', '_')}"  # pragma: no cover

    new_schema = type(
        new_name,
        (paginator.Output,),
        {
            "__annotations__": {paginator.items_attribute: List[item_schema]},  # type: ignore
        },
    )  # typing: ignore

    response = op._create_response_model(new_schema)

    # Swap in the newly created response model.
    op.response_models[status_code] = response


def _find_collection_response(op: Operation) -> Tuple[int, Any]:
    """
    Walk the operation's response models and return ``(status_code, item_schema)``
    for the first one whose annotated type is a collection (e.g. ``List[SomeSchema]``).

    Raises:
        ConfigError: if no response model is a collection type.
    """
    for code, resp_model in op.response_models.items():
        if resp_model is None or resp_model is NOT_SET:
            continue

        model = resp_model.__annotations__["response"]
        if is_collection_type(model):
            item_schema = get_collection_args(model)[0]
            return code, item_schema

    raise ConfigError(
        f'"{op.view_func}" has no collection response (e.g. response=List[SomeSchema])'
    )