diff --git a/src/litserve/server.py b/src/litserve/server.py index 51774b76..eefb797c 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -32,7 +32,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from contextlib import asynccontextmanager from queue import Queue -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import uvicorn import uvicorn.server @@ -296,6 +296,9 @@ def __init__(self, lit_api: LitAPI, server: "LitServer"): async def _prepare_request(self, request, request_type) -> dict: """Common request preparation logic.""" + # FastAPI parses JSON body to dict when endpoint uses dict type annotation + if isinstance(request, dict): + return request if request_type == Request: content_type = request.headers.get("Content-Type", "") if content_type == "application/x-www-form-urlencoded" or content_type.startswith("multipart/form-data"): @@ -1084,18 +1087,21 @@ def register_endpoints(self): encode_response_signature = inspect.signature(lit_api.encode_response) request_type = decode_request_signature.parameters["request"].annotation - if request_type == decode_request_signature.empty: + # Distinguish missing annotation from explicit `Request` annotation. + # Missing: use dict[str, Any] for Swagger. Explicit: keep Request for form/file uploads. + request_type_is_default = request_type == decode_request_signature.empty + if request_type_is_default: request_type = Request response_type = encode_response_signature.return_annotation if response_type == encode_response_signature.empty: response_type = Response - self._register_api_endpoints(lit_api, request_type, response_type) + self._register_api_endpoints(lit_api, request_type, response_type, request_type_is_default) def _get_request_queue(self, api_path: str): return self.litapi_request_queues[api_path] - def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type): + def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type, request_type_is_default=False): """Register endpoint routes for the FastAPI app.""" self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self) @@ -1103,8 +1109,11 @@ def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type): # Create handlers handler = StreamingRequestHandler(lit_api, self) if lit_api.stream else RegularRequestHandler(lit_api, self) + # When no type annotation is provided, use dict[str, Any] so Swagger renders a request body form. + endpoint_request_type = dict[str, Any] if request_type_is_default else request_type + # Create endpoint function - async def endpoint_handler(request: request_type) -> response_type: + async def endpoint_handler(request: endpoint_request_type) -> response_type: # type: ignore[reportInvalidTypeForm] return await handler.handle_request(request, request_type) # Register endpoint diff --git a/tests/unit/test_pydantic.py b/tests/unit/test_pydantic.py index 397cb3c5..643e5c06 100644 --- a/tests/unit/test_pydantic.py +++ b/tests/unit/test_pydantic.py @@ -45,3 +45,19 @@ def test_pydantic(): with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 16.0} + + +class NoAnnotationLitAPI(LitAPI): + def setup(self, device): + pass + + def predict(self, request): + return {"output": request["input"] ** 2} + + +def test_swagger_request_body_without_annotation(): + server = LitServer(NoAnnotationLitAPI(), accelerator="cpu", devices=1, timeout=5) + schema = server.app.openapi() + predict_post = schema["paths"]["/predict"]["post"] + assert "requestBody" in predict_post, "Swagger must expose a requestBody for /predict" + assert "application/json" in predict_post["requestBody"]["content"]