# 343 lines, 10 KiB, Python
import inspect
|
|
from abc import ABC, abstractmethod
|
|
from functools import partial, wraps
|
|
from math import inf
|
|
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple, Type, Union
|
|
|
|
from django.db.models import QuerySet
|
|
from django.http import HttpRequest
|
|
from django.utils.module_loading import import_string
|
|
from typing_extensions import get_args as get_collection_args
|
|
|
|
from ninja import Field, Query, Router, Schema
|
|
from ninja.conf import settings
|
|
from ninja.constants import NOT_SET
|
|
from ninja.errors import ConfigError
|
|
from ninja.operation import Operation
|
|
from ninja.signature.details import is_collection_type
|
|
from ninja.utils import (
|
|
contribute_operation_args,
|
|
contribute_operation_callback,
|
|
is_async_callable,
|
|
)
|
|
|
|
|
|
class PaginationBase(ABC):
    """Base class for synchronous paginators.

    Subclasses declare the parameters they accept in ``Input``, the paginated
    response shape in ``Output``, and implement ``paginate_queryset`` to cut a
    collection down according to the parsed ``Input``.
    """

    class Input(Schema):
        # Pagination parameters; empty here, concrete paginators add fields.
        pass

    # Declares that ``Input`` is read as ``Query`` parameters.
    InputSource = Query(...)

    class Output(Schema):
        # Default paginated response body: the page of items and a count.
        items: List[Any]
        count: int

    # Key of the paginated result dict that holds the page of items.
    items_attribute: str = "items"

    def __init__(self, *, pass_parameter: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            pass_parameter: when set, the parsed pagination ``Input`` is also
                handed to the decorated view as a keyword argument of this name.
            **kwargs: accepted and ignored, so subclasses can forward options.
        """
        self.pass_parameter = pass_parameter

    @abstractmethod
    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Any,
        **params: Any,
    ) -> Any:
        """Return the paginated result for ``queryset``.

        Args:
            queryset: the collection returned by the view (a QuerySet or a list).
            pagination: the parsed ``Input`` instance.
            **params: extra context (e.g. ``request`` and the view's kwargs).
        """
        pass  # pragma: no cover

    def _items_count(self, queryset: QuerySet) -> int:
        """Return the total number of items in ``queryset``.

        Lists are largely interchangeable with QuerySets and can be passed to a
        paginator, so ``.all().count()`` is tried first and ``len()`` is used
        when the object has no ``.all()``/``.count`` to call.
        """
        try:
            # Going through .all() finds QuerySet.count rather than list.count
            # (a plain list has no .all(), so it lands in the except branch).
            count = queryset.all().count
        except AttributeError:
            return len(queryset)
        # Called outside the try: an AttributeError raised *inside* count()
        # must surface instead of silently falling back to len(), which on a
        # QuerySet would fetch every row.
        return count()
|
|
|
|
|
|
class AsyncPaginationBase(PaginationBase):
    """Paginator base class that additionally supports ``async`` views."""

    @abstractmethod
    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Any,
        **params: Any,
    ) -> Any:
        """Async counterpart of ``paginate_queryset``.

        Args:
            queryset: the collection returned by the view (a QuerySet or a list).
            pagination: the parsed ``Input`` instance.
            **params: extra context (e.g. ``request`` and the view's kwargs).
        """
        pass  # pragma: no cover

    async def _aitems_count(self, queryset: QuerySet) -> int:
        """Return the total number of items, preferring ``.all().acount()``.

        Objects without ``.all()``/``.acount`` (e.g. plain lists) are counted
        with ``len()``.
        """
        try:
            acount = queryset.all().acount
        except AttributeError:
            return len(queryset)
        # Awaited outside the try so an AttributeError raised inside acount()
        # is not mistaken for "this is a plain list".
        return await acount()
|
|
|
|
|
|
class LimitOffsetPagination(AsyncPaginationBase):
    """Paginate with ``limit`` / ``offset`` parameters.

    The effective limit is additionally capped at ``settings.PAGINATION_MAX_LIMIT``.
    """

    class Input(Schema):
        limit: int = Field(
            settings.PAGINATION_PER_PAGE,
            ge=1,
            # No upper bound is declared when the configured maximum is infinite.
            le=(
                settings.PAGINATION_MAX_LIMIT
                if settings.PAGINATION_MAX_LIMIT != inf
                else None
            ),
        )
        offset: int = Field(0, ge=0)

    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        """Return ``{"items": <window>, "count": <total>}`` for the requested range."""
        start = pagination.offset
        stop = start + min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
        window = queryset[start:stop]
        return {"items": window, "count": self._items_count(queryset)}

    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        """Async variant; a QuerySet window is fetched with ``async for``."""
        start = pagination.offset
        stop = start + min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
        window = queryset[start:stop]
        if isinstance(queryset, QuerySet):
            # Materialise the lazy slice without blocking the event loop.
            window = [obj async for obj in window]
        return {"items": window, "count": await self._aitems_count(queryset)}
|
|
|
|
|
|
class PageNumberPagination(AsyncPaginationBase):
    """Paginate with a 1-based ``page`` number and an optional ``page_size``."""

    class Input(Schema):
        page: int = Field(1, ge=1)
        page_size: Optional[int] = Field(None, ge=1)

    def __init__(
        self,
        page_size: int = settings.PAGINATION_PER_PAGE,
        max_page_size: int = settings.PAGINATION_MAX_PER_PAGE_SIZE,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            page_size: size used when the client does not send ``page_size``.
            max_page_size: cap applied to a client-requested ``page_size``.
            **kwargs: forwarded to the base class (e.g. ``pass_parameter``).
        """
        self.page_size = page_size
        self.max_page_size = max_page_size
        super().__init__(**kwargs)

    def _get_page_size(self, requested_page_size: Optional[int]) -> int:
        """Effective page size for a request.

        A client-requested size is capped at ``max_page_size``; without one,
        the configured ``page_size`` is used as-is (not capped).
        """
        if requested_page_size is not None:
            return min(requested_page_size, self.max_page_size)
        return self.page_size

    def paginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        """Return ``{"items": <page>, "count": <total>}`` for the requested page."""
        size = self._get_page_size(pagination.page_size)
        first = (pagination.page - 1) * size
        last = first + size
        page_items = queryset[first:last]
        return {"items": page_items, "count": self._items_count(queryset)}

    async def apaginate_queryset(
        self,
        queryset: QuerySet,
        pagination: Input,
        **params: Any,
    ) -> Any:
        """Async variant; a QuerySet page is fetched with ``async for``."""
        size = self._get_page_size(pagination.page_size)
        first = (pagination.page - 1) * size
        last = first + size
        page_items = queryset[first:last]
        if isinstance(queryset, QuerySet):
            # Materialise the lazy slice without blocking the event loop.
            page_items = [obj async for obj in page_items]
        return {"items": page_items, "count": await self._aitems_count(queryset)}
|
|
|
|
|
|
def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
    """Decorator that paginates the collection returned by a view.

    Usage::

        @api.get(...)
        @paginate
        def my_view(request): ...

    or, choosing the paginator class (and its constructor options)::

        @api.get(...)
        @paginate(PageNumberPagination, page_size=20)
        def my_view(request): ...

    Without an explicit class, the class named by ``settings.PAGINATION_CLASS``
    is used.
    """
    # Bare ``@paginate``: the single positional argument is the view itself.
    is_bare_decorator = inspect.isfunction(func_or_pgn_class)

    pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]]
    if is_bare_decorator or func_or_pgn_class is NOT_SET:
        # Resolve the configured default only when it is actually needed, so an
        # explicitly passed class does not depend on that setting importing.
        pagination_class = import_string(settings.PAGINATION_CLASS)
    else:
        pagination_class = func_or_pgn_class

    if is_bare_decorator:
        # Forward options too (e.g. a direct ``paginate(view, page_size=5)`` call)
        # instead of silently dropping them.
        return _inject_pagination(func_or_pgn_class, pagination_class, **paginator_params)

    def wrapper(func: Callable) -> Any:
        return _inject_pagination(func, pagination_class, **paginator_params)

    return wrapper
|
|
|
|
|
|
def _inject_pagination(
    func: Callable,
    paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
    **paginator_params: Any,
) -> Callable:
    """Wrap ``func`` so that the collection it returns is paginated.

    The wrapper receives the parsed pagination ``Input`` as the
    ``ninja_pagination`` keyword (registered with ``contribute_operation_args``).
    It optionally forwards that value to ``func`` under
    ``paginator.pass_parameter``, calls ``func``, and hands the result to the
    paginator. When the paginator declares an ``Output`` schema, the page of
    items is materialised into a list, and ``make_response_paginated`` is
    registered as an operation callback.

    Raises:
        ConfigError: ``func`` is async but the paginator has no
            ``apaginate_queryset``.
    """
    paginator = paginator_class(**paginator_params)
    if is_async_callable(func):
        if not hasattr(paginator, "apaginate_queryset"):
            raise ConfigError("Pagination class not configured for async requests")

        @wraps(func)
        async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = await func(request, **kwargs)

            result = await paginator.apaginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )

            if paginator.Output:  # type: ignore
                page = result[paginator.items_attribute]
                if hasattr(page, "__aiter__"):
                    # e.g. a lazy QuerySet returned by a custom paginator: fetch
                    # it with async iteration, because plain iteration would hit
                    # the database synchronously from async code.
                    result[paginator.items_attribute] = [item async for item in page]
                else:
                    result[paginator.items_attribute] = list(page)
            return result

    else:

        @wraps(func)
        def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = func(request, **kwargs)

            result = paginator.paginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )
            if paginator.Output:  # type: ignore
                result[paginator.items_attribute] = list(
                    result[paginator.items_attribute]
                )
                # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
            return result

    # Expose the paginator's Input schema as an extra operation argument.
    contribute_operation_args(
        view_with_pagination,
        "ninja_pagination",
        paginator.Input,
        paginator.InputSource,
    )

    if paginator.Output:  # type: ignore
        # Later swaps the operation's List[...] response for the paged schema.
        contribute_operation_callback(
            view_with_pagination,
            partial(make_response_paginated, paginator),
        )

    return view_with_pagination
|
|
|
|
|
|
class RouterPaginated(Router):
    """Router that automatically paginates operations with a collection response.

    The paginator class is the one named by ``settings.PAGINATION_CLASS``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.pagination_class = import_string(settings.PAGINATION_CLASS)

    def add_api_operation(
        self, path: str, methods: List[str], view_func: Callable, **kwargs: Any
    ) -> None:
        """Register an operation.

        When its ``response`` is a collection type (e.g. ``List[SomeSchema]``),
        ``view_func`` is wrapped with pagination first.
        """
        # Callers may omit ``response``; treat that as "not set" instead of
        # failing with KeyError.
        response = kwargs.get("response", NOT_SET)
        if is_collection_type(response):
            view_func = _inject_pagination(view_func, self.pagination_class)
        return super().add_api_operation(path, methods, view_func, **kwargs)
|
|
|
|
|
|
def make_response_paginated(paginator: PaginationBase, op: Operation) -> None:
    """Replace the operation's collection response with a paged schema.

    For example, ``response=List[Some]`` becomes ``response=PagedSome``, a
    subclass of ``paginator.Output`` whose items field is retyped::

        class PagedSome(paginator.Output):
            items: List[Some]
            # remaining Output fields (e.g. ``count``) are inherited

    The new response model is stored under the status code of the collection
    response that was found.
    """
    status_code, item_schema = _find_collection_response(op)

    try:
        base_name = item_schema.__name__
    except AttributeError:  # pragma: no cover
        # special case for `typing.Any`, only raised for Python < 3.10
        base_name = str(item_schema).replace(".", "_")  # pragma: no cover
    new_name = f"Paged{base_name}"

    # Subclass Output, overriding only the annotation of the items field.
    field_annotations = {paginator.items_attribute: List[item_schema]}  # type: ignore
    paged_schema = type(
        new_name,
        (paginator.Output,),
        {"__annotations__": field_annotations},
    )

    # Swap the old response model for one built from the paged schema.
    op.response_models[status_code] = op._create_response_model(paged_schema)
|
|
|
|
|
|
def _find_collection_response(op: Operation) -> Tuple[int, Any]:
    """Return ``(status_code, item_schema)`` for the first collection response.

    Response models are scanned in ``op.response_models`` order. Entries that
    are ``None`` or ``NOT_SET`` are skipped, and the first one whose
    ``response`` annotation is a collection (e.g. ``List[SomeSchema]``) wins.

    Raises:
        ConfigError: no response model is a collection type.
    """
    for status_code, response_model in op.response_models.items():
        if response_model is not None and response_model is not NOT_SET:
            annotation = response_model.__annotations__["response"]
            if is_collection_type(annotation):
                return status_code, get_collection_args(annotation)[0]

    raise ConfigError(
        f'"{op.view_func}" has no collection response (e.g. response=List[SomeSchema])'
    )
|