Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from contextlib import asynccontextmanager
from queue import Queue
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import uvicorn
import uvicorn.server
Expand Down Expand Up @@ -296,6 +296,9 @@ def __init__(self, lit_api: LitAPI, server: "LitServer"):

async def _prepare_request(self, request, request_type) -> dict:
"""Common request preparation logic."""
# FastAPI parses JSON body to dict when endpoint uses dict type annotation
if isinstance(request, dict):
return request
if request_type == Request:
content_type = request.headers.get("Content-Type", "")
if content_type == "application/x-www-form-urlencoded" or content_type.startswith("multipart/form-data"):
Expand Down Expand Up @@ -1084,27 +1087,33 @@ def register_endpoints(self):
encode_response_signature = inspect.signature(lit_api.encode_response)

request_type = decode_request_signature.parameters["request"].annotation
if request_type == decode_request_signature.empty:
# Distinguish missing annotation from explicit `Request` annotation.
# Missing: use dict[str, Any] for Swagger. Explicit: keep Request for form/file uploads.
request_type_is_default = request_type == decode_request_signature.empty
if request_type_is_default:
request_type = Request

response_type = encode_response_signature.return_annotation
if response_type == encode_response_signature.empty:
response_type = Response
self._register_api_endpoints(lit_api, request_type, response_type)
self._register_api_endpoints(lit_api, request_type, response_type, request_type_is_default)

def _get_request_queue(self, api_path: str):
return self.litapi_request_queues[api_path]

def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type):
def _register_api_endpoints(self, lit_api: LitAPI, request_type, response_type, request_type_is_default=False):
"""Register endpoint routes for the FastAPI app."""

self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self)

# Create handlers
handler = StreamingRequestHandler(lit_api, self) if lit_api.stream else RegularRequestHandler(lit_api, self)

# When no type annotation is provided, use dict[str, Any] so Swagger renders a request body form.
endpoint_request_type = dict[str, Any] if request_type_is_default else request_type

# Create endpoint function
async def endpoint_handler(request: request_type) -> response_type:
async def endpoint_handler(request: endpoint_request_type) -> response_type: # type: ignore[reportInvalidTypeForm]
return await handler.handle_request(request, request_type)

# Register endpoint
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,19 @@ def test_pydantic():
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}


class NoAnnotationLitAPI(LitAPI):
def setup(self, device):
pass

def predict(self, request):
return {"output": request["input"] ** 2}


def test_swagger_request_body_without_annotation():
server = LitServer(NoAnnotationLitAPI(), accelerator="cpu", devices=1, timeout=5)
schema = server.app.openapi()
predict_post = schema["paths"]["/predict"]["post"]
assert "requestBody" in predict_post, "Swagger must expose a requestBody for /predict"
assert "application/json" in predict_post["requestBody"]["content"]
Loading