|
""" |
|
This module contains all query-related routes for the LightRAG API. |
|
""" |
|
|
|
import json |
|
import logging |
|
from typing import Any, Dict, List, Literal, Optional |
|
|
|
from fastapi import APIRouter, Depends, HTTPException |
|
from lightrag.base import QueryParam |
|
from ..utils_api import get_combined_auth_dependency |
|
from pydantic import BaseModel, Field, field_validator |
|
|
|
from ascii_colors import trace_exception |
|
|
|
router = APIRouter(tags=["query"]) |
|
|
|
|
|
class QueryRequest(BaseModel): |
|
query: str = Field( |
|
min_length=1, |
|
description="The query text", |
|
) |
|
|
|
mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( |
|
default="hybrid", |
|
description="Query mode", |
|
) |
|
|
|
only_need_context: Optional[bool] = Field( |
|
default=None, |
|
description="If True, only returns the retrieved context without generating a response.", |
|
) |
|
|
|
only_need_prompt: Optional[bool] = Field( |
|
default=None, |
|
description="If True, only returns the generated prompt without producing a response.", |
|
) |
|
|
|
response_type: Optional[str] = Field( |
|
min_length=1, |
|
default=None, |
|
description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", |
|
) |
|
|
|
top_k: Optional[int] = Field( |
|
ge=1, |
|
default=None, |
|
description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", |
|
) |
|
|
|
max_token_for_text_unit: Optional[int] = Field( |
|
gt=1, |
|
default=None, |
|
description="Maximum number of tokens allowed for each retrieved text chunk.", |
|
) |
|
|
|
max_token_for_global_context: Optional[int] = Field( |
|
gt=1, |
|
default=None, |
|
description="Maximum number of tokens allocated for relationship descriptions in global retrieval.", |
|
) |
|
|
|
max_token_for_local_context: Optional[int] = Field( |
|
gt=1, |
|
default=None, |
|
description="Maximum number of tokens allocated for entity descriptions in local retrieval.", |
|
) |
|
|
|
hl_keywords: Optional[List[str]] = Field( |
|
default=None, |
|
description="List of high-level keywords to prioritize in retrieval.", |
|
) |
|
|
|
ll_keywords: Optional[List[str]] = Field( |
|
default=None, |
|
description="List of low-level keywords to refine retrieval focus.", |
|
) |
|
|
|
conversation_history: Optional[List[Dict[str, Any]]] = Field( |
|
default=None, |
|
description="Stores past conversation history to maintain context. Format: [{'role': 'user/assistant', 'content': 'message'}].", |
|
) |
|
|
|
history_turns: Optional[int] = Field( |
|
ge=0, |
|
default=None, |
|
description="Number of complete conversation turns (user-assistant pairs) to consider in the response context.", |
|
) |
|
|
|
@field_validator("query", mode="after") |
|
@classmethod |
|
def query_strip_after(cls, query: str) -> str: |
|
return query.strip() |
|
|
|
@field_validator("hl_keywords", mode="after") |
|
@classmethod |
|
def hl_keywords_strip_after(cls, hl_keywords: List[str] | None) -> List[str] | None: |
|
if hl_keywords is None: |
|
return None |
|
return [keyword.strip() for keyword in hl_keywords] |
|
|
|
@field_validator("ll_keywords", mode="after") |
|
@classmethod |
|
def ll_keywords_strip_after(cls, ll_keywords: List[str] | None) -> List[str] | None: |
|
if ll_keywords is None: |
|
return None |
|
return [keyword.strip() for keyword in ll_keywords] |
|
|
|
@field_validator("conversation_history", mode="after") |
|
@classmethod |
|
def conversation_history_role_check( |
|
cls, conversation_history: List[Dict[str, Any]] | None |
|
) -> List[Dict[str, Any]] | None: |
|
if conversation_history is None: |
|
return None |
|
for msg in conversation_history: |
|
if "role" not in msg or msg["role"] not in {"user", "assistant"}: |
|
raise ValueError( |
|
"Each message must have a 'role' key with value 'user' or 'assistant'." |
|
) |
|
return conversation_history |
|
|
|
def to_query_params(self, is_stream: bool) -> "QueryParam": |
|
"""Converts a QueryRequest instance into a QueryParam instance.""" |
|
|
|
request_data = self.model_dump(exclude_none=True, exclude={"query"}) |
|
|
|
|
|
param = QueryParam(**request_data) |
|
param.stream = is_stream |
|
return param |
|
|
|
|
|
class QueryResponse(BaseModel): |
|
response: str = Field( |
|
description="The generated response", |
|
) |
|
|
|
|
|
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): |
|
combined_auth = get_combined_auth_dependency(api_key) |
|
|
|
@router.post( |
|
"/query", response_model=QueryResponse, dependencies=[Depends(combined_auth)] |
|
) |
|
async def query_text(request: QueryRequest): |
|
""" |
|
Handle a POST request at the /query endpoint to process user queries using RAG capabilities. |
|
|
|
Parameters: |
|
request (QueryRequest): The request object containing the query parameters. |
|
Returns: |
|
QueryResponse: A Pydantic model containing the result of the query processing. |
|
If a string is returned (e.g., cache hit), it's directly returned. |
|
Otherwise, an async generator may be used to build the response. |
|
|
|
Raises: |
|
HTTPException: Raised when an error occurs during the request handling process, |
|
with status code 500 and detail containing the exception message. |
|
""" |
|
try: |
|
param = request.to_query_params(False) |
|
response = await rag.aquery(request.query, param=param) |
|
|
|
|
|
if isinstance(response, str): |
|
return QueryResponse(response=response) |
|
|
|
if isinstance(response, dict): |
|
result = json.dumps(response, indent=2) |
|
return QueryResponse(response=result) |
|
else: |
|
return QueryResponse(response=str(response)) |
|
except Exception as e: |
|
trace_exception(e) |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
@router.post("/query/stream", dependencies=[Depends(combined_auth)]) |
|
async def query_text_stream(request: QueryRequest): |
|
""" |
|
This endpoint performs a retrieval-augmented generation (RAG) query and streams the response. |
|
|
|
Args: |
|
request (QueryRequest): The request object containing the query parameters. |
|
optional_api_key (Optional[str], optional): An optional API key for authentication. Defaults to None. |
|
|
|
Returns: |
|
StreamingResponse: A streaming response containing the RAG query results. |
|
""" |
|
try: |
|
param = request.to_query_params(True) |
|
response = await rag.aquery(request.query, param=param) |
|
|
|
from fastapi.responses import StreamingResponse |
|
|
|
async def stream_generator(): |
|
if isinstance(response, str): |
|
|
|
yield f"{json.dumps({'response': response})}\n" |
|
else: |
|
|
|
try: |
|
async for chunk in response: |
|
if chunk: |
|
yield f"{json.dumps({'response': chunk})}\n" |
|
except Exception as e: |
|
logging.error(f"Streaming error: {str(e)}") |
|
yield f"{json.dumps({'error': str(e)})}\n" |
|
|
|
return StreamingResponse( |
|
stream_generator(), |
|
media_type="application/x-ndjson", |
|
headers={ |
|
"Cache-Control": "no-cache", |
|
"Connection": "keep-alive", |
|
"Content-Type": "application/x-ndjson", |
|
"X-Accel-Buffering": "no", |
|
}, |
|
) |
|
except Exception as e: |
|
trace_exception(e) |
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
return router |
|
|