-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
104 lines (92 loc) · 3.21 KB
/
Copy pathapp.py
File metadata and controls
104 lines (92 loc) · 3.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List, Dict, Optional
from src.rag import (
rag,
rag_stream,
rag_with_session,
rag_stream_with_session,
SessionState,
)
import logging
import os
# Log verbosity comes from the LOG_LEVEL environment variable (case-insensitive,
# default "INFO").
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# getattr alone is not enough: some upper-case names (e.g. "BASIC_FORMAT") are
# attributes of the logging module but not levels, and basicConfig would reject
# them with ValueError at import time. Anything that does not resolve to an int
# level falls back to INFO.
_log_level = getattr(logging, LOG_LEVEL, logging.INFO)
if not isinstance(_log_level, int):
    _log_level = logging.INFO
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)
# FastAPI application object; the /process and /process_stream endpoints below
# are registered on it via @app.post.
app = FastAPI(title="DocuRAG")
class In(BaseModel):
    """Request body shared by the /process and /process_stream endpoints.

    Besides the query text, every request carries the backend connection
    settings (Fireworks API key and Solr URL/credentials), which the endpoints
    forward unchanged to the src.rag functions.
    """

    # Query text; passed as the first argument to the RAG functions.
    text: str
    # Optional session key. When set (and non-empty), the endpoints fetch or
    # create a SessionState stored under this key in the module-level SESSIONS
    # dict and use the session-aware RAG functions.
    session_id: Optional[str] = None
    # Per-request backend credentials, forwarded as keyword arguments.
    # NOTE(review): these are plain `str`; consider pydantic's SecretStr so they
    # are masked if a payload is ever printed or logged (callers would then
    # need .get_secret_value()).
    fireworks_api_key: str
    solr_url: str
    solr_username: str
    solr_password: str
class Out(BaseModel):
    """Response body of /process: the answer and sources from the RAG call."""

    # Copied from result["answer"] of the RAG call.
    answer: str
    # Copied from result["sources"]; the keys of each entry are defined by
    # src.rag (not visible here) — presumably document/chunk metadata, verify.
    sources: List[Dict]
@app.post("/process", response_model=Out)
def process_endpoint(payload: In):
    """Answer ``payload.text`` with the RAG pipeline and return the full result.

    When ``payload.session_id`` is set, the request goes through
    ``rag_with_session`` using the ``SessionState`` stored under that id in
    ``SESSIONS`` (created on first use); otherwise the stateless ``rag`` is used.

    Returns:
        Out: the ``answer`` and ``sources`` entries of the RAG result.

    Raises:
        HTTPException: status 400 with the error text as ``detail`` for any
            failure, so clients always see the same error shape.
    """
    # Connection/credential settings are identical for both code paths.
    backend = {
        "solr_url": payload.solr_url,
        "solr_username": payload.solr_username,
        "solr_password": payload.solr_password,
        "fireworks_api_key": payload.fireworks_api_key,
    }
    try:
        if payload.session_id:
            state = SESSIONS.setdefault(payload.session_id, SessionState())
            # Lazy %-formatting: the message is only built if INFO is enabled.
            logger.info("Session state: %s", state.last_query)
            result = rag_with_session(payload.text, state, **backend)
        else:
            result = rag(payload.text, **backend)
        return Out(answer=result["answer"], sources=result["sources"])
    except Exception as e:
        logger.exception("/process failed: %s", e)
        # make failures predictable for clients
        raise HTTPException(status_code=400, detail=str(e)) from e
def _log_stream_errors(stream):
    """Wrap a sync or async chunk iterable so mid-stream failures get logged.

    StreamingResponse consumes its iterable only after the endpoint has
    returned, so the endpoint's own try/except never sees exceptions raised
    while chunks are produced. This wrapper logs such exceptions and re-raises
    them, leaving the response's failure behavior otherwise unchanged.
    """
    if hasattr(stream, "__aiter__"):

        async def _async_chunks():
            try:
                async for chunk in stream:
                    yield chunk
            except Exception:
                logger.exception("/process_stream failed while streaming")
                raise

        return _async_chunks()

    def _sync_chunks():
        try:
            yield from stream
        except Exception:
            logger.exception("/process_stream failed while streaming")
            raise

    return _sync_chunks()


@app.post("/process_stream")
async def process_stream_endpoint(payload: In):
    """Stream the RAG answer for ``payload.text`` as NDJSON.

    Uses ``rag_stream_with_session`` with the ``SessionState`` stored under
    ``payload.session_id`` in ``SESSIONS`` (created on first use) when a
    session id is given, and the stateless ``rag_stream`` otherwise.

    Raises:
        HTTPException: status 400 if building the stream fails. Exceptions
            raised while the stream is being consumed occur after the response
            has started; they are logged and re-raised instead.
    """
    # Connection/credential settings are identical for both code paths.
    backend = {
        "solr_url": payload.solr_url,
        "solr_username": payload.solr_username,
        "solr_password": payload.solr_password,
        "fireworks_api_key": payload.fireworks_api_key,
    }
    try:
        if payload.session_id:
            state = SESSIONS.setdefault(payload.session_id, SessionState())
            # Lazy %-formatting: the message is only built if INFO is enabled.
            logger.info("Session state: %s", state.last_query)
            gen = rag_stream_with_session(payload.text, state, **backend)
        else:
            gen = rag_stream(payload.text, **backend)
        return StreamingResponse(
            _log_stream_errors(gen),
            media_type="application/x-ndjson",
            headers={
                # Ask clients/proxies not to cache, and nginx-style proxies not
                # to buffer, so chunks are delivered as they are produced.
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
            },
        )
    except Exception as e:
        logger.exception("/process_stream failed: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
# Process-local session store mapping session_id -> SessionState; the endpoints
# populate it lazily via setdefault. Nothing in this module evicts entries and
# the dict is not shared between worker processes, so it only suits a
# single-process deployment.
SESSIONS: dict[str, SessionState] = dict()