Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| import logging | |
| from typing import List, Dict, Optional | |
| from datetime import datetime | |
| from pymongo import MongoClient, UpdateOne | |
| from pymongo.operations import SearchIndexModel | |
| from sentence_transformers import SentenceTransformer | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| logger = logging.getLogger(__name__) | |
# Config
# Connection settings are read from the environment (a .env file is loaded above).
MONGO_URI = os.getenv("MONGO_URI")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DB_NAME = os.getenv("MONGO_DB_NAME", "ragdb")
COLLECTION_NAME = "rag_chunks"
VECTOR_INDEX = "vector_index"
# Used as numDimensions of the vector index; presumably matches the output
# size of the model loaded in get_embedder() -- confirm if the model changes.
EMBEDDING_DIM = 384
# Retrieval
TOP_K = 1  # default number of chunks requested by retrieve_chunks()
MIN_SCORE = 0.50  # vector hits scoring below this are discarded
MAX_CHUNK_CHARS = 3000  # vector hits whose text is longer than this are discarded
# Context / prompt
MAX_CONTEXT_CHARS = 1200  # default character budget of build_context()
CHUNK_TRUNCATE = 1000  # per-chunk text cap applied by build_context()
MAX_HISTORY_TURNS = 1  # generate_answer() keeps the last MAX_HISTORY_TURNS * 2 messages
HISTORY_MSG_CAP = 150  # each history message is cut to this many characters
# Generation
MAX_OUTPUT_TOKENS = 256  # passed to genai.GenerationConfig
TEMPERATURE = 0.2  # passed to genai.GenerationConfig
# Retry
MAX_RETRIES = 3  # total attempts made by generate_answer()
RETRY_BASE_WAIT = 1.0  # seconds; doubled after every failed attempt
# Lazily initialised module-level singletons (see get_embedder / get_collection).
_embedder: Optional[SentenceTransformer] = None
_mongo_client: Optional[MongoClient] = None
_collection = None
def get_embedder() -> SentenceTransformer:
    """Return the shared sentence-transformer model, loading it on first use.

    The instance is cached in the module-level ``_embedder`` so the model
    is only constructed once per process.
    """
    global _embedder
    if _embedder is not None:
        return _embedder

    logger.info("Loading sentence-transformer model...")
    _embedder = SentenceTransformer("all-MiniLM-L6-v2")
    logger.info("Model loaded.")
    return _embedder
def get_collection():
    """Return the shared chunk collection, connecting on first use.

    On the first call a ``MongoClient`` is created from ``MONGO_URI``, the
    ``COLLECTION_NAME`` collection of ``DB_NAME`` is cached in the module
    globals, and ``_ensure_vector_index`` is run against it. Later calls
    return the cached collection directly.
    """
    global _mongo_client, _collection
    if _collection is not None:
        return _collection

    _mongo_client = MongoClient(MONGO_URI)
    _collection = _mongo_client[DB_NAME][COLLECTION_NAME]
    _ensure_vector_index(_collection)
    return _collection
| # ── Vector Index ────────────────────────────────────────────────────────────── | |
def _ensure_vector_index(col):
    """Create the collection and its vector search index if they are missing.

    Best-effort: any exception is logged as a warning and swallowed, so a
    failure here never prevents the collection from being used.

    Args:
        col: the chunk collection to inspect.
    """
    try:
        database = col.database
        if COLLECTION_NAME not in database.list_collection_names():
            database.create_collection(COLLECTION_NAME)
            logger.info(f"Created {COLLECTION_NAME} collection.")

        index_names = [idx.get("name") for idx in list(col.list_search_indexes())]
        if VECTOR_INDEX in index_names:
            logger.info("Vector index already exists.")
            return

        logger.info("Creating Atlas Vector Search index...")
        # Cosine-similarity vector field holding the chunk embedding.
        vector_field = {
            "type": "vector",
            "path": "embedding",
            "numDimensions": EMBEDDING_DIM,
            "similarity": "cosine",
        }
        # Filter field so searches can be restricted to a single document.
        filter_field = {
            "type": "filter",
            "path": "document_id",
        }
        model = SearchIndexModel(
            definition={"fields": [vector_field, filter_field]},
            name=VECTOR_INDEX,
            type="vectorSearch",
        )
        col.create_search_index(model)
        logger.info("Vector index created. It may take ~1 min to become active.")
    except Exception as e:
        logger.warning(
            f"Could not verify/create vector index: {e}. "
            "Ensure your Atlas cluster supports Vector Search."
        )
| # Embedding | |
def embed_texts(texts: List[str]) -> List[List[float]]:
    """Embed a batch of texts with the shared sentence-transformer model.

    Args:
        texts: strings to embed.

    Returns:
        One embedding (a list of floats) per input text, in input order.
        Embeddings are requested with ``normalize_embeddings=True``.
        An empty input yields an empty list.
    """
    # Nothing to embed: skip the (expensive) model load and avoid relying on
    # how the encoder handles an empty batch.
    if not texts:
        return []

    model = get_embedder()
    vecs = model.encode(
        texts, batch_size=32, show_progress_bar=False, normalize_embeddings=True
    )
    return vecs.tolist()
def embed_query(text: str) -> List[float]:
    """Embed a single query string and return its vector."""
    vectors = embed_texts([text])
    return vectors[0]
| # Ingest | |
def ingest_chunks(chunks: List[Dict], document_id: str) -> int:
    """Embed each chunk and upsert it into the chunk collection.

    Args:
        chunks: list of dicts from reportCleaning.process_pdf_to_chunks().
            Required keys: ``chunk_id``, ``text``. Optional keys:
            ``section``, ``chunk_index``, ``total_chunks``, ``pages``,
            ``token_count``, ``metadata`` (``metadata["source_file"]`` is
            stored when present).
        document_id: unique identifier of the source document
            (e.g. file._id from MongoDB).

    Returns:
        Number of chunks written (0 for an empty input).

    Raises:
        KeyError: if a chunk lacks ``chunk_id`` or ``text``.
    """
    # Empty input: do not open a DB connection or load the embedding model.
    if not chunks:
        return 0

    # Function-scope import: the module only imports `datetime` itself.
    from datetime import timezone

    col = get_collection()
    texts = [c["text"] for c in chunks]
    logger.info(f"Embedding {len(texts)} chunks for document '{document_id}'...")
    embeddings = embed_texts(texts)
    logger.info("Embeddings done.")

    # One timezone-aware UTC timestamp for the whole batch
    # (datetime.utcnow() is deprecated and returns a naive value).
    created_at = datetime.now(timezone.utc)

    ops = []
    for chunk, emb in zip(chunks, embeddings):
        # (document_id, chunk_id) identifies a chunk, so re-ingesting the
        # same document overwrites rather than duplicates.
        key = {"document_id": document_id, "chunk_id": chunk["chunk_id"]}
        # `metadata` may be present but None; treat that like "missing".
        metadata = chunk.get("metadata") or {}
        doc = {
            **key,
            "section": chunk.get("section", ""),
            "chunk_index": chunk.get("chunk_index", 0),
            "total_chunks": chunk.get("total_chunks", 1),
            "pages": chunk.get("pages", []),
            "token_count": chunk.get("token_count", 0),
            "text": chunk["text"],
            "embedding": emb,
            "source_file": metadata.get("source_file", ""),
            "created_at": created_at,
        }
        ops.append(UpdateOne(key, {"$set": doc}, upsert=True))

    if ops:
        result = col.bulk_write(ops, ordered=False)
        logger.info(
            f"Upserted {result.upserted_count + result.modified_count} chunks."
        )
    logger.debug("ingested chunks in ragService")
    return len(ops)
def delete_document_chunks(document_id: str) -> int:
    """Delete every stored chunk of ``document_id``; return how many were removed."""
    outcome = get_collection().delete_many({"document_id": document_id})
    return outcome.deleted_count
| # Retrieval | |
def retrieve_chunks(
    query: str, document_id: str, top_k: int = TOP_K
) -> List[Dict]:
    """Retrieve the most relevant chunks of one document for a query.

    Runs an Atlas ``$vectorSearch`` restricted to ``document_id``, then drops
    hits scoring below ``MIN_SCORE`` or whose text is longer than
    ``MAX_CHUNK_CHARS``. If nothing survives, or the vector search raises,
    the keyword fallback ``_fallback_text_search`` is used instead.

    Args:
        query: user question.
        document_id: identifier used to filter chunks.
        top_k: maximum number of chunks to return.

    Returns:
        List of dicts with ``text``, ``section``, ``pages``, ``chunk_index``
        and (for vector hits only) ``score``. Empty when ``top_k`` < 1.
    """
    # A non-positive limit is rejected by $vectorSearch and would reach the
    # fallback as an unlimited query; there is nothing to return.
    if top_k < 1:
        return []

    col = get_collection()
    q_emb = embed_query(query)

    # Atlas recommends numCandidates of at least 20x the limit for good
    # recall, and caps it at 10000.
    num_candidates = min(top_k * 20, 10_000)

    pipeline = [
        {
            "$vectorSearch": {
                "index": VECTOR_INDEX,
                "path": "embedding",
                "queryVector": q_emb,
                "numCandidates": num_candidates,
                "limit": top_k,
                "filter": {"document_id": {"$eq": document_id}},
            }
        },
        {
            "$project": {
                "_id": 0,
                "text": 1,
                "section": 1,
                "pages": 1,
                "chunk_index": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]
    try:
        hits = col.aggregate(pipeline)
        # Quality gate: drop low-confidence and oversized chunks.
        # `or` guards against a stored None score/text.
        results = [
            r
            for r in hits
            if (r.get("score") or 0) >= MIN_SCORE
            and len(r.get("text") or "") <= MAX_CHUNK_CHARS
        ]
        if not results:
            logger.warning(
                "All retrieved chunks failed quality gate "
                "(low score or oversized). Falling back to text search."
            )
            return _fallback_text_search(query, document_id, top_k)
        logger.info(
            f"Retrieved {len(results)} chunk(s) "
            f"(score ≥ {MIN_SCORE}, size ≤ {MAX_CHUNK_CHARS} chars)"
        )
        return results
    except Exception as e:
        # Deliberate best-effort: e.g. the index may not be ready yet.
        logger.error(f"Vector search failed: {e}")
        return _fallback_text_search(query, document_id, top_k)
def _fallback_text_search(
    query: str, document_id: str, top_k: int
) -> List[Dict]:
    """Keyword fallback when the vector index isn't ready yet.

    Matches chunks of ``document_id`` whose text contains, case-insensitively,
    any of the first five whitespace-separated words of ``query``.

    Args:
        query: user question; treated as literal text, never as a regex.
        document_id: identifier used to filter chunks.
        top_k: maximum number of chunks to return.

    Returns:
        Up to ``top_k`` dicts with ``text``, ``section``, ``pages`` and
        ``chunk_index`` (no ``score``). Empty when the query has no words
        or ``top_k`` < 1.
    """
    import re  # function-scope: not imported at module level

    words = query.split()[:5]
    # An empty pattern would match every chunk, and limit(0) means "no limit"
    # in PyMongo, so both cases must short-circuit.
    if not words or top_k < 1:
        logger.info("Fallback text search returned 0 result(s).")
        return []

    # Escape each word: the query is user input and may contain regex
    # metacharacters ("what?", "(a", "c++") that would otherwise produce an
    # invalid or unintended pattern.
    regex_part = "|".join(re.escape(w) for w in words)

    col = get_collection()
    docs = list(
        col.find(
            {
                "document_id": document_id,
                "text": {"$regex": regex_part, "$options": "i"},
            },
            {"_id": 0, "text": 1, "section": 1, "pages": 1, "chunk_index": 1},
        ).limit(top_k)
    )
    logger.info(f"Fallback text search returned {len(docs)} result(s).")
    return docs
| # Context Builder | |
def build_context(
    chunks: List[Dict], max_chars: int = MAX_CONTEXT_CHARS
) -> str:
    """Build a compact context string from retrieved chunks.

    Each chunk's text is cut to ``CHUNK_TRUNCATE`` characters and rendered as
    ``[index | section page]`` followed by the text. Entries are joined with
    a ``---`` separator until the next one would push the total (separators
    included) past ``max_chars``.

    Args:
        chunks: retrieved chunk dicts; ``text`` is used, ``section`` and
            ``pages`` are optional.
        max_chars: upper bound on the length of the returned string.

    Returns:
        The assembled context, or ``"No relevant context found."`` when no
        chunk contributes any text.
    """
    if not chunks:
        return "No relevant context found."

    separator = "\n\n---\n\n"
    parts: List[str] = []
    total_chars = 0

    for i, c in enumerate(chunks, 1):
        text = str(c.get("text") or "").strip()
        if not text:
            # Nothing to show for this chunk; do not emit an empty entry.
            continue

        # `or` (not a .get default): stored chunks carry section="" when the
        # section is unknown, which a default value would not replace.
        section = c.get("section") or "Unknown Section"
        pages = c.get("pages") or []
        page_str = f"p.{pages[0]}" if pages else ""

        # Truncate large chunks
        if len(text) > CHUNK_TRUNCATE:
            text = text[:CHUNK_TRUNCATE] + "…[truncated]"

        entry = f"[{i} | {section} {page_str}]\n{text}"
        # The separator is part of the final string, so it counts against
        # the budget for every entry after the first.
        cost = len(entry) + (len(separator) if parts else 0)

        if total_chars + cost > max_chars:
            if not parts and max_chars > 0:
                # The best chunk alone is over budget: keep a clipped version
                # instead of handing the model an empty context.
                clipped = entry[:max_chars]
                parts.append(clipped)
                total_chars = len(clipped)
            logger.info(f"Context budget reached at chunk {i}, stopping.")
            break

        parts.append(entry)
        total_chars += cost

    if not parts:
        return "No relevant context found."

    logger.info(f"Built context: {total_chars} chars from {len(parts)} chunk(s)")
    return separator.join(parts)
# Gemini model name; overridable through the GEMINI_MODEL environment variable.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-lite")


def _get_gemini() -> genai.GenerativeModel:
    """Configure the Gemini SDK and return a model handle for ``GEMINI_MODEL``.

    Raises:
        ValueError: if ``GEMINI_API_KEY`` is not set.
    """
    if GEMINI_API_KEY:
        genai.configure(api_key=GEMINI_API_KEY)
        return genai.GenerativeModel(GEMINI_MODEL)
    raise ValueError("GEMINI_API_KEY not set in environment.")
# Instruction block placed at the top of every prompt sent to Gemini.
SYSTEM_PROMPT = (
    "Answer using ONLY the document context below. "
    "If the answer isn't there, say: \"I couldn't find that in this document.\" "
    "Cite section name or page number when available. "
    "Be concise; use bullet points for multi-part answers."
)


def generate_answer(
    query: str,
    context: str,
    chat_history: Optional[List[Dict]] = None,
) -> str:
    """Call Gemini with retrieved context plus optional chat history.

    The prompt is ``SYSTEM_PROMPT``, the context, the last
    ``MAX_HISTORY_TURNS * 2`` history messages (each cut to
    ``HISTORY_MSG_CAP`` characters) and the question.

    Rate-limit errors (429 / "quota") and transient server errors
    (500, 502, 503, 504) are retried with exponential backoff, for at most
    ``MAX_RETRIES`` attempts in total.

    Args:
        query: user question.
        context: context string, e.g. from ``build_context``.
        chat_history: list of ``{"role": "user"|"assistant", "content": str}``.

    Returns:
        The model's answer text, stripped of surrounding whitespace.

    Raises:
        ValueError: if ``GEMINI_API_KEY`` is not set (from ``_get_gemini``).
        Exception: the last Gemini error when it is not retryable or the
            attempts are exhausted.
        RuntimeError: if ``MAX_RETRIES`` is configured as 0, so no attempt
            is made.
    """
    import re  # function-scope: not imported at module level

    model = _get_gemini()

    # list[-0:] is the WHOLE list, so a zero history budget has to be
    # handled explicitly rather than through the slice.
    keep = MAX_HISTORY_TURNS * 2
    recent = chat_history[-keep:] if chat_history and keep > 0 else []
    history_lines = []
    for turn in recent:
        role = "User" if turn.get("role") == "user" else "Assistant"
        # Tolerate a missing or None content instead of raising.
        content = str(turn.get("content") or "")[:HISTORY_MSG_CAP]
        history_lines.append(f"{role}: {content}")
    history_text = "\n".join(history_lines)

    sections = [f"CONTEXT:\n{context}"]
    if history_text:
        sections.append(f"RECENT CHAT:\n{history_text.strip()}")
    sections.append(f"QUESTION: {query}")
    prompt = SYSTEM_PROMPT + "\n\n" + "\n\n".join(sections) + "\n\nANSWER:"

    # Rough heuristic (~4 characters per token), used for logging only.
    estimated_tokens = len(prompt) // 4
    logger.info(f"Prompt estimate: ~{estimated_tokens} tokens | model: {GEMINI_MODEL}")

    generation_config = genai.GenerationConfig(
        max_output_tokens=MAX_OUTPUT_TOKENS,
        temperature=TEMPERATURE,
    )

    # Match status codes as whole tokens so that e.g. "1500" or "4290" in an
    # error message is not mistaken for a 500 / 429.
    rate_limit_re = re.compile(r"\b429\b")
    transient_re = re.compile(r"\b(?:500|502|503|504)\b")

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            response = model.generate_content(prompt, generation_config=generation_config)
            return response.text.strip()
        except Exception as e:
            last_error = e
            err_str = str(e)
            is_rate_limit = (
                rate_limit_re.search(err_str) is not None
                or "quota" in err_str.lower()
            )
            is_transient = transient_re.search(err_str) is not None
            if (is_rate_limit or is_transient) and attempt < MAX_RETRIES - 1:
                wait = RETRY_BASE_WAIT * (2 ** attempt)  # 1s → 2s → 4s
                logger.warning(
                    f"Gemini {'rate limit' if is_rate_limit else 'transient error'} "
                    f"(attempt {attempt + 1}/{MAX_RETRIES}). "
                    f"Retrying in {wait:.0f}s… [{err_str[:80]}]"
                )
                time.sleep(wait)
            else:
                # Non-retryable error or out of attempts
                logger.error(f"Gemini error after {attempt + 1} attempt(s): {e}")
                raise

    # Only reachable when MAX_RETRIES is 0 and the loop body never runs.
    raise RuntimeError(f"Gemini call failed after {MAX_RETRIES} attempts: {last_error}")
| # Main RAG Pipeline | |
def rag_query(
    query: str,
    document_id: str,
    chat_history: Optional[List[Dict]] = None,
) -> Dict:
    """Run the full RAG pipeline for one question.

    Steps: retrieve chunks (quality gate included) -> build context ->
    generate the answer.

    Args:
        query: user question.
        document_id: MongoDB document identifier (used to filter chunks).
        chat_history: list of {"role": "user"|"assistant", "content": str}.

    Returns:
        {
            "answer": str,
            "sources": [{"section": str, "pages": list, "score": float}],
            "chunks_used": int,
        }
        ``score`` is 0 for chunks that carry no score.
    """
    retrieved = retrieve_chunks(query, document_id)
    prompt_context = build_context(retrieved)
    reply = generate_answer(query, prompt_context, chat_history)

    cited = []
    for hit in retrieved:
        cited.append(
            {
                "section": hit.get("section", ""),
                "pages": hit.get("pages", []),
                "score": round(hit.get("score", 0), 3),
            }
        )

    return {
        "answer": reply,
        "sources": cited,
        "chunks_used": len(retrieved),
    }