| |
| import json |
| import os |
| import re |
| import time |
| from typing import Any |
|
|
| from dotenv import load_dotenv |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import StreamingResponse |
| from huggingface_hub import InferenceClient |
| from pydantic import BaseModel, Field |
|
|
| from config_loader import cfg |
| from vector_db import get_index_by_name, load_chunks_with_local_cache |
| from retriever.retriever import HybridRetriever |
| from retriever.generator import RAGGenerator |
| from retriever.processor import ChunkProcessor |
|
|
| from models.llama_3_8b import Llama3_8B |
| from models.mistral_7b import Mistral_7b |
| from models.qwen_2_5 import Qwen2_5 |
| from models.deepseek_v3 import DeepSeek_V3 |
| from models.tiny_aya import TinyAya |
|
|
|
|
| |
|
|
|
|
class PredictRequest(BaseModel):
    """Request body for the /predict and /predict/stream endpoints."""

    # Non-empty user question; whitespace-only input is rejected again in the endpoints.
    query: str = Field(..., min_length=1, description="User query text")
    # Key resolved by _resolve_model; short aliases like "llama" are also accepted there.
    model: str = Field(default="Llama-3-8B", description="Model name key")
    # Number of candidates fetched by the retriever before reranking.
    top_k: int = Field(default=10, ge=1, le=50)
    # Number of contexts kept after reranking and passed to the generator.
    final_k: int = Field(default=5, ge=1, le=20)
    # Retrieval mode forwarded to HybridRetriever.search.
    mode: str = Field(default="hybrid", description="semantic | bm25 | hybrid")
    # Reranking strategy forwarded to HybridRetriever.search.
    rerank_strategy: str = Field(default="cross-encoder", description="cross-encoder | rrf | none")
|
|
|
|
class PredictResponse(BaseModel):
    """Response body for /predict: answer plus the contexts it was grounded on."""

    # Canonical model key that actually served the request.
    model: str
    # Generated answer text.
    answer: str
    # Raw context strings used for generation, in retrieval order.
    contexts: list[str]
    # Per-context metadata dicts built by _build_retrieved_chunks
    # (rank, text, source_title, source_url, chunk_index).
    retrieved_chunks: list[dict[str, Any]]
|
|
|
|
class TitleRequest(BaseModel):
    """Request body for /predict/title."""

    # First user message of a chat; the title is derived from it.
    query: str = Field(..., min_length=1, description="First user message")
|
|
|
|
class TitleResponse(BaseModel):
    """Response body for /predict/title."""

    # Short chat title (at most 8 words / 80 chars by construction).
    title: str
    # Provenance: "hf:<model_id>" when an HF model produced it, else "rule-based".
    source: str
|
|
|
|
| def _to_ndjson(payload: dict[str, Any]) -> str: |
| return json.dumps(payload, ensure_ascii=False) + "\n" |
|
|
|
|
|
|
| |
| |
|
|
| def _title_from_query(query: str) -> str: |
| stop_words = { |
| "a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "for", "from", "how", |
| "i", "in", "is", "it", "me", "my", "of", "on", "or", "please", "show", "tell", "that", |
| "the", "this", "to", "we", "what", "when", "where", "which", "why", "with", "you", "your", |
| } |
|
|
| words = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query) |
| if not words: |
| return "New Chat" |
|
|
| filtered: list[str] = [] |
| for word in words: |
| cleaned = word.strip("-_/+") |
| if not cleaned: |
| continue |
| if cleaned.lower() in stop_words: |
| continue |
| filtered.append(cleaned) |
| if len(filtered) >= 6: |
| break |
|
|
| chosen = filtered if filtered else words[:6] |
| normalized = [w.capitalize() if w.islower() else w for w in chosen] |
| title = " ".join(normalized).strip() |
| return title[:80] if title else "New Chat" |
|
|
|
|
| |
| |
|
|
| def _clean_title_text(raw: str) -> str: |
| text = (raw or "").strip() |
| text = text.replace("\n", " ").replace("\r", " ") |
| text = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", text) |
| text = re.sub(r"\s+", " ", text).strip() |
| words = text.split() |
| if len(words) > 8: |
| text = " ".join(words[:8]) |
| return text[:80] |
|
|
|
|
def _title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
    """Ask a Hugging Face chat model for a short title.

    Returns the cleaned title, or None when the model yields nothing usable
    (empty response, empty cleaned text, or the generic "new chat").
    """
    messages = [
        {
            "role": "system",
            "content": "You generate short chat titles. Return only a title, no punctuation at the end, no quotes.",
        },
        {
            "role": "user",
            "content": f"Create a concise 3-7 word title for this user request:\n{query}",
        },
    ]

    response = client.chat_completion(
        model=model_id,
        messages=messages,
        max_tokens=24,
        temperature=0.3,
    )
    if not response or not response.choices:
        return None

    candidate = _clean_title_text(response.choices[0].message.content or "")
    if not candidate or candidate.lower() == "new chat":
        return None
    return candidate
|
|
|
|
| def _parse_title_model_candidates() -> list[str]: |
| raw = os.getenv( |
| "TITLE_MODEL_IDS", |
| "Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct", |
| ) |
| models = [m.strip() for m in raw.split(",") if m.strip()] |
| return models or ["meta-llama/Meta-Llama-3-8B-Instruct"] |
|
|
|
|
| def _build_retrieved_chunks( |
| contexts: list[str], |
| chunk_lookup: dict[str, dict[str, Any]], |
| ) -> list[dict[str, Any]]: |
| if not contexts: |
| return [] |
|
|
| retrieved_chunks: list[dict[str, Any]] = [] |
|
|
| for idx, text in enumerate(contexts, start=1): |
| meta = chunk_lookup.get(text, {}) |
| title = meta.get("title") or "Untitled" |
| url = meta.get("url") or "" |
| chunk_index = meta.get("chunk_index") |
|
|
| retrieved_chunks.append( |
| { |
| "rank": idx, |
| "text": text, |
| "source_title": title, |
| "source_url": url, |
| "chunk_index": chunk_index, |
| } |
| ) |
|
|
| return retrieved_chunks |
|
|
|
|
|
|
| |
| |
| |
|
|
# FastAPI application instance; all routes below attach to it.
app = FastAPI(title="RAG-AS3 API", version="0.1.0")


# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests require an
# explicit origin) — confirm whether credentials are actually needed, or
# restrict origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)




# Process-wide mutable state populated by startup_event(): index, retriever,
# rag_engine, models, chunk_lookup, title_model_ids, title_client.
state: dict[str, Any] = {}
|
|
|
|
def _build_models(hf_token: str) -> dict[str, Any]:
    """Instantiate every supported chat model, keyed by its canonical name."""
    model_classes = {
        "Llama-3-8B": Llama3_8B,
        "Mistral-7B": Mistral_7b,
        "Qwen-2.5": Qwen2_5,
        "DeepSeek-V3": DeepSeek_V3,
        "TinyAya": TinyAya,
    }
    return {name: model_cls(token=hf_token) for name, model_cls in model_classes.items()}
|
|
|
|
| def _resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]: |
| aliases = { |
| "llama": "Llama-3-8B", |
| "mistral": "Mistral-7B", |
| "qwen": "Qwen-2.5", |
| "deepseek": "DeepSeek-V3", |
| "tinyaya": "TinyAya", |
| } |
| model_key = aliases.get(name.lower(), name) |
| if model_key not in models: |
| allowed = ", ".join(models.keys()) |
| raise HTTPException(status_code=400, detail=f"Unknown model '{name}'. Use one of: {allowed}") |
| return model_key, models[model_key] |
|
|
|
|
| |
| |
| |
|
|
| |
| |
|
|
@app.on_event("startup")
def startup_event() -> None:
    """Initialize all shared services and populate the module-level ``state``.

    Phases, in order: load .env, read credentials, connect to the Pinecone
    index, load chunks (local cache or Pinecone metadata), build the chunk
    processor / hybrid retriever / RAG generator / model instances, then
    build a text->metadata lookup. Each phase is timed and a summary line
    is printed at the end.

    Raises:
        RuntimeError: when PINECONE_API_KEY or HF_TOKEN is missing, or no
            chunks are found in the index metadata.
    """
    startup_start = time.perf_counter()

    # Load .env so local development picks up tokens without exporting them.
    dotenv_start = time.perf_counter()
    load_dotenv()
    dotenv_time = time.perf_counter() - dotenv_start

    env_start = time.perf_counter()
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    env_time = time.perf_counter() - env_start

    # Fail fast: every later phase needs both credentials.
    if not pinecone_api_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")

    index_name = "cbt-book-recursive"

    embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
    project_root = os.path.dirname(os.path.abspath(__file__))
    # BM25 chunk-cache location and refresh toggle are overridable via env vars.
    cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
    force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}

    index_start = time.perf_counter()
    index = get_index_by_name(
        api_key=pinecone_api_key,
        index_name=index_name
    )
    index_time = time.perf_counter() - index_start

    # Prefer the local cache; otherwise pull chunk metadata from Pinecone in
    # batches of 100. chunk_source reports which path was taken.
    chunks_start = time.perf_counter()
    final_chunks, chunk_source = load_chunks_with_local_cache(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        batch_size=100,
        force_refresh=force_cache_refresh,
    )
    chunk_load_time = time.perf_counter() - chunks_start

    if not final_chunks:
        raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")

    # load_hf_embeddings=False: only the local encoder is needed for serving.
    processor_start = time.perf_counter()
    proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
    processor_time = time.perf_counter() - processor_start

    retriever_start = time.perf_counter()
    retriever = HybridRetriever(final_chunks, proc.encoder, verbose=False)
    retriever_time = time.perf_counter() - retriever_start

    rag_start = time.perf_counter()
    rag_engine = RAGGenerator()
    rag_time = time.perf_counter() - rag_start

    models_start = time.perf_counter()
    models = _build_models(hf_token)
    models_time = time.perf_counter() - models_start

    # Build a text -> source-metadata map so endpoints can annotate retrieved
    # contexts; on duplicate texts the first occurrence wins.
    state_start = time.perf_counter()
    chunk_lookup: dict[str, dict[str, Any]] = {}
    for chunk in final_chunks:
        metadata = chunk.get("metadata", {})
        text = metadata.get("text")
        if not text or text in chunk_lookup:
            continue
        chunk_lookup[text] = {
            "title": metadata.get("title", "Untitled"),
            "url": metadata.get("url", ""),
            "chunk_index": metadata.get("chunk_index"),
        }

    state["index"] = index
    state["retriever"] = retriever
    state["rag_engine"] = rag_engine
    state["models"] = models
    state["chunk_lookup"] = chunk_lookup
    state["title_model_ids"] = _parse_title_model_candidates()
    state["title_client"] = InferenceClient(token=hf_token)
    state_time = time.perf_counter() - state_start

    startup_time = time.perf_counter() - startup_start
    print(
        f"API startup complete | chunks={len(final_chunks)} | "
        f"dotenv={dotenv_time:.3f}s | "
        f"env={env_time:.3f}s | "
        f"index={index_time:.3f}s | "
        f"cache_dir={cache_dir} | "
        f"force_cache_refresh={force_cache_refresh} | "
        f"chunk_source={chunk_source} | "
        f"chunk_load={chunk_load_time:.3f}s | "
        f"processor={processor_time:.3f}s | "
        f"retriever={retriever_time:.3f}s | "
        f"rag={rag_time:.3f}s | "
        f"models={models_time:.3f}s | "
        f"state={state_time:.3f}s | "
        f"total={startup_time:.3f}s"
    )
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Readiness probe: "ok" once startup_event has populated the core state keys."""
    required_keys = ("index", "retriever", "rag_engine", "models")
    is_ready = all(key in state for key in required_keys)
    return {"status": "ok" if is_ready else "starting"}
|
|
|
|
| |
| |
@app.post("/predict/title", response_model=TitleResponse)
def suggest_title(payload: TitleRequest) -> TitleResponse:
    """Suggest a chat title for the first user message.

    Tries each configured HF title model in order and returns the first
    usable result; on total failure falls back to the rule-based title.

    Raises:
        HTTPException: 400 when the query is blank.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    client: InferenceClient | None = state.get("title_client")
    candidate_ids: list[str] = state.get("title_model_ids", _parse_title_model_candidates())

    if client is not None:
        for candidate_id in candidate_ids:
            try:
                generated = _title_from_hf(query, client, candidate_id)
                if generated:
                    return TitleResponse(title=generated, source=f"hf:{candidate_id}")
            except Exception as exc:
                message = str(exc)
                # Quietly skip models no provider serves; log any other failure.
                if "model_not_supported" not in message and "not supported by any provider" not in message:
                    print(f"Title generation model failed ({candidate_id}): {exc}")
                continue

    print("Title generation fallback triggered: no title model available/successful")
    return TitleResponse(title=_title_from_query(query), source="rule-based")
|
|
|
|
|
|
| |
| |
| |
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run the full RAG pipeline once and return the complete answer.

    Retrieves contexts with the hybrid retriever, generates an answer with
    the requested model, and maps each context back to its source metadata.
    Per-phase timings are printed for observability.

    Raises:
        HTTPException: 503 before startup populates ``state``; 400 on a
            blank query (or unknown model, via _resolve_model); 404 when
            retrieval yields no contexts.
    """
    req_start = time.perf_counter()

    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = _resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval: request-controlled mode/rerank/k values; MMR always on.
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=True,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # Generation: low temperature for grounded, deterministic-ish answers.
    inference_start = time.perf_counter()
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=0.1)
    inference_time = time.perf_counter() - inference_start

    mapping_start = time.perf_counter()
    retrieved_chunks = _build_retrieved_chunks(
        contexts=contexts,
        chunk_lookup=chunk_lookup,
    )
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start

    print(
        f"Predict timing | model={model_name} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | precheck={precheck_time:.3f}s | "
        f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
        f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )

    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )
|
|
| |
@app.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Run the RAG pipeline and stream the answer as NDJSON events.

    Validation, model resolution, and retrieval happen eagerly (so their
    failures surface as HTTP errors); generation is streamed. Event types:
    {"type": "token"} per generated token, one final {"type": "done"} with
    the full answer and source metadata, or {"type": "error"} if generation
    fails after the response has started (status is already sent by then).

    Raises:
        HTTPException: 503 before startup populates ``state``; 400 on a
            blank query (or unknown model, via _resolve_model); 404 when
            retrieval yields no contexts.
    """
    req_start = time.perf_counter()

    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = _resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval runs before streaming so a 404 can still be returned cleanly.
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=True,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    def stream_events():
        # Generator closure: emits one NDJSON line per token, then a "done"
        # event; converts any generation failure into an "error" event.
        inference_start = time.perf_counter()
        answer_parts: list[str] = []
        try:
            for token in rag_engine.get_answer_stream(model_instance, query, contexts, temperature=0.1):
                answer_parts.append(token)
                yield _to_ndjson({"type": "token", "token": token})

            inference_time = time.perf_counter() - inference_start
            answer = "".join(answer_parts)
            retrieved_chunks = _build_retrieved_chunks(
                contexts=contexts,
                chunk_lookup=chunk_lookup,
            )

            yield _to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                }
            )
        except Exception as exc:
            yield _to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    # X-Accel-Buffering: no — tells nginx-style proxies not to buffer the stream.
    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
|