# app/api.py
from __future__ import annotations

import os
import re
from collections import deque
from datetime import datetime, timezone
from time import perf_counter
from typing import List, Optional, Dict, Any

import faiss
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel, Field

from .rag_system import SimpleRAG, UPLOAD_DIR, INDEX_DIR

__version__ = "1.3.1"

app = FastAPI(title="RAG API", version=__version__)
# CORS (for the Streamlit UI)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

rag = SimpleRAG()
# -------------------- Schemas --------------------
class UploadResponse(BaseModel):
    filename: str
    chunks_added: int


class AskRequest(BaseModel):
    question: str = Field(..., min_length=1)
    top_k: int = Field(5, ge=1, le=20)


class AskResponse(BaseModel):
    answer: str
    contexts: List[str]


class HistoryItem(BaseModel):
    question: str
    timestamp: str


class HistoryResponse(BaseModel):
    total_chunks: int
    history: List[HistoryItem] = []
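
# Example /ask payload shapes (illustrative values; the route paths are
# reconstructed in the Routes section below):
#   request:  {"question": "What was repaired?", "top_k": 5}
#   response: {"answer": "...", "contexts": ["..."]}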
# -------------------- Stats (in-memory) --------------------
class StatsStore:
    def __init__(self):
        self.documents_indexed = 0
        self.questions_answered = 0
        self.latencies_ms = deque(maxlen=500)
        self.last7_questions = deque([0] * 7, maxlen=7)  # simple daily counter
        self.history = deque(maxlen=50)

    def add_docs(self, n: int):
        if n > 0:
            self.documents_indexed += int(n)

    def add_question(self, latency_ms: Optional[int] = None, q: Optional[str] = None):
        self.questions_answered += 1
        if latency_ms is not None:
            self.latencies_ms.append(int(latency_ms))
        if len(self.last7_questions) == 7:
            self.last7_questions[0] += 1
        if q:
            self.history.appendleft(
                {"question": q, "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds")}
            )

    def avg_ms(self) -> int:
        return int(sum(self.latencies_ms) / len(self.latencies_ms)) if self.latencies_ms else 0


stats = StatsStore()
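
# Usage sketch (hypothetical values):
#   stats.add_docs(12)
#   stats.add_question(latency_ms=340, q="What is covered?")
#   stats.avg_ms()  # -> 340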
# -------------------- Helpers --------------------
_STOPWORDS = {
    "the","a","an","of","for","and","or","in","on","to","from","with","by","is","are",
    "was","were","be","been","being","at","as","that","this","these","those","it","its",
    "into","than","then","so","such","about","over","per","via","vs","within"
}


def _tokenize(s: str) -> List[str]:
    return [w for w in re.findall(r"[a-zA-Z0-9]+", s.lower()) if w and w not in _STOPWORDS and len(w) > 2]


def _is_generic_answer(text: str) -> bool:
    if not text:
        return True
    low = text.strip().lower()
    if len(low) < 15:
        return True
    # typical generic patterns
    if "based on document context" in low or "appears to be" in low:
        return True
    return False
def _extractive_fallback(question: str, contexts: List[str], max_chars: int = 600) -> str:
    """Select sentences from the contexts based on the question's keywords."""
    if not contexts:
        return "I couldn't find relevant information in the indexed documents for this question."
    qtok = set(_tokenize(question))
    if not qtok:
        return (contexts[0] or "")[:max_chars]
    # Split into sentences and score each by keyword overlap with the question.
    sentences: List[str] = []
    for c in contexts:
        for s in re.split(r"(?<=[\.!\?])\s+|\n+", (c or "").strip()):
            s = s.strip()
            if s:
                sentences.append(s)
    scored: List[tuple[int, str]] = []
    for s in sentences:
        st = set(_tokenize(s))
        scored.append((len(qtok & st), s))
    scored.sort(key=lambda x: (x[0], len(x[1])), reverse=True)
    picked: List[str] = []
    for sc, s in scored:
        if sc <= 0 and picked:
            break
        if len((" ".join(picked) + " " + s).strip()) > max_chars:
            break
        picked.append(s)
    if not picked:
        return (contexts[0] or "")[:max_chars]
    bullets = "\n".join(f"- {p}" for p in picked)
    return f"Answer (based on document context):\n{bullets}"
# -------------------- Routes --------------------
# NOTE: the decorators below use conventional paths; adjust them if your
# deployment expects different ones.
@app.get("/")
def root():
    return RedirectResponse(url="/docs")


@app.get("/health")
def health():
    return {
        "status": "ok",
        "version": app.version,
        "summarizer": "extractive_en + translate + keyword_fallback",
        "faiss_ntotal": int(getattr(rag.index, "ntotal", 0)),
        "model_dim": int(getattr(rag.index, "d", rag.embed_dim)),
    }
@app.get("/debug/translate")
def debug_translate():
    try:
        from transformers import pipeline
        tr = pipeline("translation", model="Helsinki-NLP/opus-mt-az-en", cache_dir=str(rag.cache_dir), device=-1)
        out = tr("Sənəd təmiri və quraşdırılması ilə bağlı işlər görülüb.", max_length=80)[0]["translation_text"]
        return {"ok": True, "example_out": out}
    except Exception as e:
        return JSONResponse(status_code=500, content={"ok": False, "error": str(e)})
@app.post("/upload", response_model=UploadResponse)
async def upload_pdf(file: UploadFile = File(...)):
    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed.")
    dest = UPLOAD_DIR / file.filename
    # Stream the upload to disk in 1 MiB chunks to avoid holding it all in memory.
    with open(dest, "wb") as f:
        while True:
            chunk = await file.read(1024 * 1024)
            if not chunk:
                break
            f.write(chunk)
    added = rag.add_pdf(dest)
    if added == 0:
        raise HTTPException(status_code=400, detail="No extractable text found (likely a scanned image PDF).")
    stats.add_docs(added)
    return UploadResponse(filename=file.filename, chunks_added=added)
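
# Example call (assuming the /upload path above and a local server on :8000):
#   curl -F "file=@report.pdf" http://localhost:8000/upload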
@app.post("/ask", response_model=AskResponse)
def ask_question(payload: AskRequest):
    q = (payload.question or "").strip()
    if not q:
        raise HTTPException(status_code=400, detail="Missing 'question'.")
    k = max(1, int(payload.top_k))
    t0 = perf_counter()
    # 1) Always search with the question embedding.
    try:
        hits = rag.search(q, k=k)  # List[Tuple[text, score]]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {e}")
    contexts = [c for c, _ in (hits or []) if c] or (getattr(rag, "last_added", [])[:k] if getattr(rag, "last_added", None) else [])
    if not contexts:
        latency_ms = int((perf_counter() - t0) * 1000)
        stats.add_question(latency_ms, q=q)
        return AskResponse(
            answer="I couldn't find relevant information in the indexed documents for this question.",
            contexts=[]
        )
    # 2) Synthesize the answer (rag may use an LLM or rules internally).
    try:
        synthesized = (rag.synthesize_answer(q, contexts) or "").strip()
    except Exception:
        synthesized = ""
    # 3) If the answer looks generic, use the extractive fallback.
    if _is_generic_answer(synthesized):
        synthesized = _extractive_fallback(q, contexts, max_chars=600)
    latency_ms = int((perf_counter() - t0) * 1000)
    stats.add_question(latency_ms, q=q)
    return AskResponse(answer=synthesized, contexts=contexts)
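
# Example call (assuming the /ask path above):
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What was repaired?", "top_k": 5}'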
@app.get("/history", response_model=HistoryResponse)
def get_history():
    return HistoryResponse(
        total_chunks=len(rag.chunks),
        history=[HistoryItem(**h) for h in list(stats.history)]
    )
@app.get("/stats")
def stats_endpoint():
    return {
        "documents_indexed": stats.documents_indexed,
        "questions_answered": stats.questions_answered,
        "avg_ms": stats.avg_ms(),
        "last7_questions": list(stats.last7_questions),
        "total_chunks": len(rag.chunks),
        "faiss_ntotal": int(getattr(rag.index, "ntotal", 0)),
        "model_dim": int(getattr(rag.index, "d", rag.embed_dim)),
        "last_added_chunks": len(getattr(rag, "last_added", [])),
        "version": app.version,
    }
@app.post("/reset")
def reset_index():
    try:
        # Rebuild an empty FAISS index and clear all persisted and in-memory state.
        rag.index = faiss.IndexFlatIP(rag.embed_dim)
        rag.chunks = []
        rag.last_added = []
        for p in [INDEX_DIR / "faiss.index", INDEX_DIR / "meta.npy"]:
            try:
                os.remove(p)
            except FileNotFoundError:
                pass
        stats.documents_indexed = 0
        stats.questions_answered = 0
        stats.latencies_ms.clear()
        stats.last7_questions = deque([0] * 7, maxlen=7)
        stats.history.clear()
        return {"ok": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
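
# Minimal local launcher (a sketch; assumes uvicorn is installed). Because of
# the relative import above, run it as a module: `python -m app.api`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)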