Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions ai-backend/backend/src/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
from datetime import datetime
from io import StringIO
from os import getenv
from typing import Annotated
from typing import Annotated, Union

from backend.src.database import SessionLocal, get_db
from backend.src.ir_pipeline.orchestrator import search, search_playground, search_rag
from backend.src.ir_pipeline.orchestrator import (
search,
search_playground,
search_rag,
search_rag_paper,
)
from backend.src.ir_pipeline.schema import Terms
from backend.src.ir_pipeline.tools.inspire import InspireOSFullTextSearchTool
from backend.src.models import Feedback, QueryIr, SearchFeedback
from backend.src.schemas.feedback import FeedbackRequest
from backend.src.schemas.query import QueryRequest, QueryResponse
from backend.src.schemas.query import QueryPaperResponse, QueryRequest, QueryResponse
from backend.src.schemas.search_feedback import SearchFeedbackRequest
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import StreamingResponse
Expand Down Expand Up @@ -278,15 +283,35 @@ async def export_feedback(
return feedbacks


@router.post("/query-rag", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
@router.post(
"/query-rag",
response_model=Union[QueryResponse, QueryPaperResponse],
)
async def query_rag(request: QueryRequest) -> Union[QueryResponse, QueryPaperResponse]:
"""
Process a query using the RAG pipeline and return the response with citations.
"""
try:
logger.info("[query_rag] Received RAG query: %s", request.query)
start = time.time()
response = await search_rag(request.query, request.model, request.user)

if request.control_number is not None:
chat_history = (
[msg.model_dump() for msg in request.history]
if request.history
else None
)

response = await search_rag_paper(
request.query,
request.model,
request.control_number,
request.user,
chat_history,
)
else:
response = await search_rag(request.query, request.model, request.user)

end = time.time()
logger.info("[query_rag] RAG query processed in %.2fs", end - start)
return response
Expand Down
12 changes: 11 additions & 1 deletion ai-backend/backend/src/ir_pipeline/chains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from backend.src.ir_pipeline.schema import LLMResponse, Terms
from backend.src.ir_pipeline.schema import LLMPaperResponse, LLMResponse, Terms
from backend.src.utils.langfuse import get_prompt
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import PydanticOutputParser
Expand Down Expand Up @@ -35,3 +35,13 @@ def create_rag_answer_generation_chain(llm: BaseLanguageModel):
)
chain = prompt_template | llm | output_parser
return chain.with_config(config)


def create_rag_paper_answer_generation_chain(llm: BaseLanguageModel):
prompt_template, langfuse_prompt = get_prompt("rag-paper-query")
output_parser = PydanticOutputParser(pydantic_object=LLMPaperResponse)
config = RunnableConfig(
run_name="rag-paper-query", metadata={"langfuse_prompt": langfuse_prompt}
)
chain = prompt_template | llm | output_parser
return chain.with_config(config)
89 changes: 81 additions & 8 deletions ai-backend/backend/src/ir_pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
create_answer_generation_chain,
create_query_expansion_chain,
create_rag_answer_generation_chain,
create_rag_paper_answer_generation_chain,
)
from backend.src.ir_pipeline.schema import LLMResponse, Terms
from backend.src.ir_pipeline.schema import LLMPaperResponse, LLMResponse, Terms
from backend.src.ir_pipeline.tools.inspire import (
InspireOSFullTextSearchTool,
InspireSearchTool,
Expand All @@ -19,9 +20,10 @@
format_refs,
)
from backend.src.ir_pipeline.utils.utils import timer
from backend.src.schemas.query import QueryResponse
from backend.src.schemas.query import QueryPaperResponse, QueryResponse
from backend.src.utils.embeddings import VLLMOpenAIEmbeddings
from backend.src.utils.reranker import CustomJinaRerank
from langchain.schema import Document
from langchain_community.llms import VLLMOpenAI
from langchain_community.vectorstores import OpenSearchVectorSearch
from langfuse.callback import CallbackHandler
Expand Down Expand Up @@ -121,6 +123,7 @@ def initialize_chains(model):
llm=llm, prompt_name="generate-answer-playground"
),
"answer_chain_rag": create_rag_answer_generation_chain(llm=llm),
"answer_chain_rag_paper": create_rag_paper_answer_generation_chain(llm=llm),
}


Expand Down Expand Up @@ -194,25 +197,53 @@ async def search_playground(query, model):
}


async def search_rag(query: str, model: str, user: str = None):
async def _rag_common(
query: str, model: str, user: str = None, control_number: int = None
):
initialize_rag_resources()
initialize_chains(model)

embedding_model = RESOURCE_CACHE["embedding_model"]
vector_store = RESOURCE_CACHE["vector_store"]
reranker = RESOURCE_CACHE["reranker"]
answer_chain = CHAIN_CACHE[model]["answer_chain_rag"]

config = create_langfuse_config(user)

with timer("RAG Embedding"):
query_embedding = embedding_model.embed_query(query)

with timer("RAG Retrieval"):
docs = vector_store.similarity_search_by_vector(
embedding=query_embedding,
k=25,
)
if control_number:
os_client = vector_store.client
index_name = vector_store.index_name

# We need a direct OpenSearch query as LangChain methods don't support
# filtering with empty queries
search_body = {
"query": {"term": {"metadata.control_number": control_number}},
"size": 25,
"_source": ["text", "metadata"],
}

try:
response = os_client.search(index=index_name, body=search_body)

docs = [
Document(
page_content=hit["_source"]["text"],
metadata=hit["_source"]["metadata"],
)
for hit in response["hits"]["hits"]
]

except Exception as e:
print(f"OpenSearch query failed: {e}")
docs = []
else:
docs = vector_store.similarity_search_by_vector(
embedding=query_embedding,
k=25,
)

with timer("RAG Reranking"):
ranked_docs = reranker.compress_documents(
Expand All @@ -222,6 +253,14 @@ async def search_rag(query: str, model: str, user: str = None):

context = format_docs(ranked_docs)

return ranked_docs, context, config


async def search_rag(query: str, model: str, user: str = None):
ranked_docs, context, config = await _rag_common(query, model, user)

answer_chain = CHAIN_CACHE[model]["answer_chain_rag"]

with timer("RAG LLM"):
response: LLMResponse = await answer_chain.ainvoke(
{"question": query, "context": context}, config=config
Expand All @@ -234,3 +273,37 @@ async def search_rag(query: str, model: str, user: str = None):
long_answer=formatted_response,
citations=citations,
)


async def search_rag_paper(
query: str,
model: str,
control_number: int,
user: str = None,
chat_history: list = None,
):
ranked_docs, context, config = await _rag_common(query, model, user, control_number)

answer_chain = CHAIN_CACHE[model]["answer_chain_rag_paper"]

chat_messages = []
if chat_history:
for msg in chat_history:
role = "user" if msg["type"] == "user" else "assistant"
chat_messages.append({"role": role, "content": msg["content"]})

with timer("RAG LLM"):
response: LLMPaperResponse = await answer_chain.ainvoke(
{
"question": query,
"context": context,
"history": chat_messages,
},
config=config,
)

formatted_response, _ = format_refs(response.response, ranked_docs)

return QueryPaperResponse(
long_answer=formatted_response,
)
4 changes: 4 additions & 0 deletions ai-backend/backend/src/ir_pipeline/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,9 @@ class LLMResponse(BaseModel):
brief: str


class LLMPaperResponse(BaseModel):
response: str


class Terms(BaseModel):
terms: list[str]
11 changes: 11 additions & 0 deletions ai-backend/backend/src/schemas/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
from pydantic import UUID4, BaseModel


class ChatMessage(BaseModel):
type: str # "user" or "assistant"
content: str


class QueryRequest(BaseModel):
query: str
model: str = getenv("LLM_MODEL")
user: Optional[str] = None
matomo_client_id: Optional[UUID4] = None
control_number: Optional[int] = None
history: Optional[List[ChatMessage]] = None


class Citation(BaseModel):
Expand All @@ -21,3 +28,7 @@ class QueryResponse(BaseModel):
brief_answer: str
long_answer: str
citations: List[Citation]


class QueryPaperResponse(BaseModel):
long_answer: str
8 changes: 5 additions & 3 deletions ai-backend/backend/src/utils/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
langfuse = Langfuse()


# Locally langfuse.environment should be unset (None)
def get_prompt(prompt_name: str):
def fetch_prompt(label: str = None):
return langfuse.get_prompt(
prompt_name,
cache_ttl_seconds=0 if langfuse.environment == "local" else None,
cache_ttl_seconds=0 if langfuse.environment is None else None,
label=label,
)

try:
langfuse_prompt = fetch_prompt(label=langfuse.environment)
label = "latest" if langfuse.environment is None else langfuse.environment
langfuse_prompt = fetch_prompt(label=label)
except NotFoundError:
logger.warning(
f"Prompt '{prompt_name}' or label '{langfuse.environment}' "
f"Prompt '{prompt_name}' or label '{label}' "
f"not found in Langfuse, trying label 'production'"
)
langfuse_prompt = fetch_prompt()
Expand Down
Loading