389 lines
14 KiB
Python

import itertools
import re
from http.client import responses
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Tuple
from django.utils.termcolors import make_style
from ninja.constants import NOT_SET
from ninja.operation import Operation
from ninja.params.models import TModel, TModels
from ninja.schema import NinjaGenerateJsonSchema
from ninja.types import DictStrAny
from ninja.utils import normalize_path
if TYPE_CHECKING:
from ninja import NinjaAPI # pragma: no cover
# Template handed to pydantic so generated "$ref" pointers target
# the document's components/schemas section.
REF_TEMPLATE: str = "#/components/schemas/{model}"

# Parameter sources (``__ninja_param_source__``) that travel in the request
# body, mapped to the media type advertised for them in ``requestBody``.
BODY_CONTENT_TYPES: Dict[str, str] = dict(
    body="application/json",
    form="application/x-www-form-urlencoded",
    file="multipart/form-data",
)
def get_schema(api: "NinjaAPI", path_prefix: str = "") -> "OpenAPISchema":
    """Build and return the OpenAPI document describing *api*.

    *path_prefix* is passed through unchanged to :class:`OpenAPISchema`.
    """
    return OpenAPISchema(api, path_prefix)
# Terminal style for console warnings, such as the duplicate operation_id notice.
bold_red_style = make_style(fg="red", opts=("bold",))
class OpenAPISchema(dict):
    """OpenAPI 3.1 document for a ``NinjaAPI``, assembled eagerly as a dict.

    Generating ``paths`` registers component schemas and security schemes on
    ``self.schemas`` / ``self.securitySchemes`` as a side effect. That is why
    ``paths`` is produced before ``components`` in ``__init__``.
    """

    def __init__(self, api: "NinjaAPI", path_prefix: str) -> None:
        """Build the whole document for *api*.

        Args:
            api: the API whose routers and operations are described.
            path_prefix: text inserted right after the leading "/" of every
                generated path.
        """
        self.api = api
        self.path_prefix = path_prefix
        # Filled while paths are generated; emitted by get_components().
        self.schemas: DictStrAny = {}
        self.securitySchemes: DictStrAny = {}
        # Operation ids seen so far, used only to warn about duplicates.
        self.all_operation_ids: Set = set()
        extra_info = api.openapi_extra.get("info", {})
        # List items are evaluated left to right: get_paths() must run before
        # get_components() so the collected schemas/security schemes are included.
        super().__init__([
            ("openapi", "3.1.0"),
            (
                "info",
                {
                    "title": api.title,
                    "version": api.version,
                    "description": api.description,
                    **extra_info,
                },
            ),
            ("paths", self.get_paths()),
            ("components", self.get_components()),
            ("servers", api.servers),
        ])
        # Remaining top-level keys from openapi_extra are copied verbatim, but
        # never override keys generated above ("info" was merged already).
        for k, v in api.openapi_extra.items():
            if k not in self:
                self[k] = v

    def get_paths(self) -> DictStrAny:
        """Return the OpenAPI ``paths`` object covering every router of the API.

        Operations that different routers register under the same final path
        are merged into a single path item.
        """
        result: DictStrAny = {}
        for prefix, router in self.api._routers:
            for path, path_view in router.path_operations.items():
                full_path = "/".join([i for i in (prefix, path) if i])
                full_path = "/" + self.path_prefix + full_path
                full_path = normalize_path(full_path)
                # Drop path converters, e.g. "{int:item_id}" -> "{item_id}".
                full_path = re.sub(r"{[^}:]+:", "{", full_path)
                path_methods = self.methods(path_view.operations)
                if path_methods:
                    result.setdefault(full_path, {}).update(path_methods)
        return result

    def methods(self, operations: list) -> DictStrAny:
        """Map lower-cased HTTP method names to OpenAPI Operation objects.

        Operations whose ``include_in_schema`` is false are skipped.
        """
        result = {}
        for op in operations:
            if op.include_in_schema:
                operation_details = self.operation_details(op)
                # NOTE(review): all methods of one operation share this same
                # dict, and therefore the same operationId.
                for method in op.methods:
                    result[method.lower()] = operation_details
        return result

    def deep_dict_update(
        self, main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]
    ) -> None:
        """Recursively merge *update_dict* into *main_dict* in place.

        Keys holding a dict on both sides are merged recursively. Any other
        value from *update_dict* overwrites the one in *main_dict*.
        """
        for key in update_dict:
            if (
                key in main_dict
                and isinstance(main_dict[key], dict)
                and isinstance(update_dict[key], dict)
            ):
                self.deep_dict_update(
                    main_dict[key], update_dict[key]
                )  # pragma: no cover
            else:
                main_dict[key] = update_dict[key]

    def operation_details(self, operation: Operation) -> DictStrAny:
        """Build the OpenAPI Operation object for *operation*.

        The operation id falls back to ``api.get_openapi_operation_id``. A
        duplicate id only prints a warning. ``operation.openapi_extra`` is
        deep-merged last, so it can override any generated field.
        """
        op_id = operation.operation_id or self.api.get_openapi_operation_id(operation)
        if op_id in self.all_operation_ids:
            print(
                bold_red_style(
                    f'Warning: operation_id "{op_id}" is already used (Try giving a different name to: {operation.view_func.__module__}.{operation.view_func.__name__})'
                )
            )
        self.all_operation_ids.add(op_id)
        result = {
            "operationId": op_id,
            "summary": operation.summary,
            "parameters": self.operation_parameters(operation),
            "responses": self.responses(operation),
        }
        if operation.description:
            result["description"] = operation.description
        if operation.tags:
            result["tags"] = operation.tags
        if operation.deprecated:
            result["deprecated"] = operation.deprecated  # type: ignore
        body = self.request_body(operation)
        if body:
            result["requestBody"] = body
        security = self.operation_security(operation)
        if security:
            result["security"] = security
        if operation.openapi_extra:
            self.deep_dict_update(result, operation.openapi_extra)
        return result

    def operation_parameters(self, operation: Operation) -> List[DictStrAny]:
        """Return Parameter objects for every non-body param model of *operation*."""
        result = []
        for model in operation.models:
            if model.__ninja_param_source__ not in BODY_CONTENT_TYPES:
                result.extend(self._extract_parameters(model))
        return result

    def _extract_parameters(self, model: TModel) -> List[DictStrAny]:
        """Turn the (flattened) fields of a param model into Parameter objects.

        Each parameter's ``in`` is the model's ``__ninja_param_source__``.
        Nested models are expanded via ``flatten_properties``. The model's
        ``$defs`` are registered as component schemas. Fields whose schema
        has a false ``include_in_schema`` are skipped.
        """
        result = []
        schema = model.model_json_schema(
            ref_template=REF_TEMPLATE,
            schema_generator=NinjaGenerateJsonSchema,
        )
        required = set(schema.get("required", []))
        properties = schema["properties"]
        definitions = schema.get("$defs", {})
        if "$defs" in schema:
            self.add_schema_definitions(schema["$defs"])

        for name, details in properties.items():
            is_required = name in required
            p_name: str
            p_schema: DictStrAny
            p_required: bool
            for p_name, p_schema, p_required in flatten_properties(
                name, details, is_required, definitions
            ):
                if not p_schema.get("include_in_schema", True):
                    continue
                param = {
                    "in": model.__ninja_param_source__,
                    "name": p_name,
                    "schema": p_schema,
                    "required": p_required,
                }
                # Copy selected schema keywords up to the Parameter object.
                if "description" in p_schema:
                    param["description"] = p_schema["description"]
                # NOTE(review): OpenAPI Parameter "examples" is a map of Example
                # Objects, whereas JSON Schema "examples" is an array; copying it
                # verbatim may not validate against the spec — confirm.
                if "examples" in p_schema:
                    param["examples"] = p_schema["examples"]
                elif "example" in p_schema:
                    param["example"] = p_schema["example"]
                if "deprecated" in p_schema:
                    param["deprecated"] = p_schema["deprecated"]
                result.append(param)
        return result

    def _flatten_schema(self, model: TModel) -> DictStrAny:
        """Build an object schema whose properties are the model's flattened parameters."""
        params = self._extract_parameters(model)
        flattened = {
            "title": model.__name__,  # type: ignore
            "type": "object",
            "properties": {p["name"]: p["schema"] for p in params},
        }
        required = [p["name"] for p in params if p["required"]]
        if required:
            flattened["required"] = required
        return flattened

    def _create_schema_from_model(
        self,
        model: TModel,
        by_alias: bool = True,
        remove_level: bool = True,
    ) -> Tuple[DictStrAny, bool]:
        """Return ``(schema, required)`` for *model*.

        Models defining ``__ninja_flatten_map__`` are flattened first. Any
        ``$defs`` are moved into the document's component schemas. When
        *remove_level* is true and the schema has exactly one property, that
        property's schema is returned instead, together with whether it is
        required. Otherwise the whole schema is returned with ``True``.
        """
        if hasattr(model, "__ninja_flatten_map__"):
            schema = self._flatten_schema(model)
        else:
            schema = model.model_json_schema(
                ref_template=REF_TEMPLATE,
                by_alias=by_alias,
                schema_generator=NinjaGenerateJsonSchema,
            ).copy()

        # move Schemas from definitions
        if schema.get("$defs"):
            self.add_schema_definitions(schema.pop("$defs"))

        if remove_level and len(schema["properties"]) == 1:
            name, details = next(iter(schema["properties"].items()))
            required = name in schema.get("required", [])
            return details, required
        return schema, True

    def _create_multipart_schema_from_models(
        self, models: TModels
    ) -> Tuple[DictStrAny, str]:
        """Merge several body models into a single multipart/form-data schema."""
        # We have File and Form or Body, so we need to use multipart (File)
        content_type = BODY_CONTENT_TYPES["file"]

        # get the various schemas
        result = merge_schemas([
            self._create_schema_from_model(model, remove_level=False)[0]
            for model in models
        ])
        result["title"] = "MultiPartBodyParams"

        return result, content_type

    def request_body(self, operation: Operation) -> DictStrAny:
        """Return the OpenAPI ``requestBody`` for *operation*, or ``{}`` if none.

        A single body model uses the content type of its param source; a
        "body" model is unwrapped one level. Several body models are merged
        into one multipart schema that is always required.
        """
        models = [
            m
            for m in operation.models
            if m.__ninja_param_source__ in BODY_CONTENT_TYPES
        ]
        if not models:
            return {}

        if len(models) == 1:
            model = models[0]
            content_type = BODY_CONTENT_TYPES[model.__ninja_param_source__]
            schema, required = self._create_schema_from_model(
                model, remove_level=model.__ninja_param_source__ == "body"
            )
        else:
            schema, content_type = self._create_multipart_schema_from_models(models)
            required = True

        return {
            "content": {content_type: {"schema": schema}},
            "required": required,
        }

    def responses(self, operation: Operation) -> Dict[int, DictStrAny]:
        """Return the OpenAPI ``responses`` object keyed by status code.

        Descriptions come from ``http.client.responses``. A catch-all
        ``Ellipsis`` status is skipped. A response gets ``content`` only when
        its model is neither ``None`` nor ``NOT_SET``.
        """
        # NOTE: assert is stripped under `python -O`.
        assert bool(operation.response_models), f"{operation.response_models} empty"

        result: Dict[int, DictStrAny] = {}
        for status, model in operation.response_models.items():
            if status is Ellipsis:
                continue  # it's not yet clear what it means if user wants to output any other code

            response: DictStrAny = {
                "description": responses.get(status, "Unknown Status Code")
            }
            if model is not None and model is not NOT_SET:
                # ::TODO:: test this: by_alias == True
                schema = self._create_schema_from_model(
                    model, by_alias=operation.by_alias
                )[0]
                response["content"] = {
                    self.api.renderer.media_type: {"schema": schema}
                }
            result[status] = response
        return result

    def operation_security(self, operation: Operation) -> Optional[List[DictStrAny]]:
        """Return security requirements for *operation*, or ``None`` without auth.

        Every auth callback exposing ``openapi_security_schema`` is registered
        as a security scheme named after its class.
        """
        if not operation.auth_callbacks:
            return None
        result = []
        for auth in operation.auth_callbacks:
            if hasattr(auth, "openapi_security_schema"):
                scopes: List[str] = []  # TODO: scopes
                name = auth.__class__.__name__
                result.append({name: scopes})  # TODO: check if unique
                self.securitySchemes[name] = auth.openapi_security_schema
        return result

    def get_components(self) -> DictStrAny:
        """Return the ``components`` object from the schemas collected so far."""
        result = {"schemas": self.schemas}
        if self.securitySchemes:
            result["securitySchemes"] = self.securitySchemes
        return result

    def add_schema_definitions(self, definitions: dict) -> None:
        """Register *definitions* as component schemas (later names overwrite)."""
        # TODO: check if schema["definitions"] are unique
        # if not - workaround (maybe use pydantic.schema.schema(models)) to process list of models
        # assert set(definitions.keys()) - set(self.schemas.keys()) == set()
        # ::TODO:: this is broken in interesting ways for by_alias,
        #     because same schema (name) can have different values
        self.schemas.update(definitions)
def flatten_properties(
    prop_name: str,
    prop_details: DictStrAny,
    prop_required: bool,
    definitions: DictStrAny,
) -> Generator[Tuple[str, DictStrAny, bool], None, None]:
    """Yield ``(name, schema, required)`` for every leaf under *prop_details*.

    Nested models, given either inline as ``properties`` or as ``$ref``
    pointers into *definitions*, are expanded recursively into flat
    properties. This is used, for example, for GET params built from several
    arguments and models. A nested property's ``required`` flag comes from
    its own parent's ``required`` list.

    ``$ref`` entries found in ``allOf`` or ``items`` are inlined by mutating
    *prop_details* in place.
    """
    if "allOf" in prop_details:
        resolve_allOf(prop_details, definitions)
        parts = prop_details["allOf"]
        if len(parts) == 1 and "enum" in parts[0]:
            # A single (referenced) enum stays one leaf property.
            yield prop_name, prop_details, prop_required
        else:  # pragma: no cover
            # TODO: this code was for pydantic 1.7+ ... <2.9 - check if this is still needed
            for part in parts:
                yield from flatten_properties("", part, True, definitions)
        return

    if "items" in prop_details and "$ref" in prop_details["items"]:
        item_schema = prop_details["items"]
        target = item_schema["$ref"].rsplit("/", 1)[-1]
        item_schema.update(definitions[target])
        # The referenced data (e.g. enum values) is now inline; drop the pointer.
        del item_schema["$ref"]
        yield prop_name, prop_details, prop_required
        return

    if "$ref" in prop_details:
        target = prop_details["$ref"].rsplit("/", 1)[-1]
        yield from flatten_properties(
            prop_name, definitions[target], prop_required, definitions
        )
        return

    if "properties" in prop_details:
        required_names = set(prop_details.get("required", []))
        for child_name, child_details in prop_details["properties"].items():
            yield from flatten_properties(
                child_name,
                child_details,
                child_name in required_names,
                definitions,
            )
        return

    yield prop_name, prop_details, prop_required
def resolve_allOf(details: DictStrAny, definitions: DictStrAny) -> None:
    """Inline, in place, every ``$ref`` that appears directly in ``details["allOf"]``.

    Each referencing entry is updated with the definition named by the last
    path segment of its pointer and then loses its ``$ref`` key. Entries
    without ``$ref`` are left untouched.
    """
    referencing = (entry for entry in details["allOf"] if "$ref" in entry)
    for entry in referencing:
        target = entry["$ref"].split("/")[-1]
        entry.update(definitions[target])
        del entry["$ref"]
def merge_schemas(schemas: List[DictStrAny]) -> DictStrAny:
    """Merge several object schemas into one (used for multipart request bodies).

    The result is a shallow copy of the first schema, with two fields replaced:

    * ``properties`` is the union of every schema's ``properties``; later
      schemas win on name clashes.
    * ``required`` is the union of every schema's ``required`` list,
      de-duplicated in first-seen order. It is only set when non-empty.

    The input schemas are not modified.

    Raises:
        IndexError: if *schemas* is empty.
    """
    result = dict(schemas[0])

    properties: DictStrAny = {}
    for scm in schemas:
        properties.update(scm.get("properties", {}))
    result["properties"] = properties

    # JSON Schema requires the entries of "required" to be unique, so a field
    # declared required by more than one model is listed only once.
    required = list(
        dict.fromkeys(
            itertools.chain.from_iterable(scm.get("required", ()) for scm in schemas)
        )
    )
    if required:
        result["required"] = required
    return result