"""HyDE-enhanced hybrid retrieval (BM25 + dense embeddings) with cross-encoder
reranking and Groq-based answer generation."""

import asyncio
import json
import os
import re
import time
from typing import Any, List, Tuple

import numpy as np
import torch
from groq import AsyncGroq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder  # Added CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity

# --- Configuration (can be overridden by the calling app) ---
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
TOP_K_CHUNKS = 10  # The final number of chunks to send to the LLM
# A larger number of initial candidates for reranking
INITIAL_K_CANDIDATES = 20
GROQ_MODEL_NAME = "openai/gpt-oss-20b"
HYDE_MODEL = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"


# --- Hypothetical Document Generation and EmbeddingClient ---
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a hypothetical answer passage (HyDE) for *query* via the Groq API.

    The prompt is generic and does not require prior knowledge of the
    document style.

    Args:
        query: The user question to write a hypothetical passage for.
        groq_api_key: API key for Groq. When falsy, generation is skipped.

    Returns:
        The generated passage, or an empty string when the key is missing,
        the API call fails, or the API returns no content.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping hypothetical document generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"You are a document writer. Your task is to write a brief passage as a section of a document "
        f"that could answer the following question. The passage should use specific terminology and "
        f"a formal tone, as if it were an excerpt from a larger document. Do not include the question, "
        f"and do not add any conversational text. The goal is to create a concise, semantically rich text "
        f"to guide a search engine to find similarly styled and detailed content.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Section:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        hyde_doc = chat_completion.choices[0].message.content
        print("Hypothetical document generated.")
        # The message content may be None; keep the declared ``str`` contract.
        return hyde_doc or ""
    except Exception as e:
        # Best-effort step: retrieval still works without a HyDE passage.
        print(f"An error occurred during HyDE generation: {e}")
        return ""


class EmbeddingClient:
    """A client for generating text embeddings using a local, open-source model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        """Load the SentenceTransformer *model_name* on CUDA if available, else CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"Sentence Transformer embedding client initialized ({model_name}) on {self.device}.")

    def get_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Encode *texts* into a tensor of embeddings.

        Returns an empty tensor when *texts* is empty.
        """
        if not texts:
            return torch.tensor([])

        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        embeddings = self.model.encode(texts, convert_to_tensor=True, show_progress_bar=False)
        print("Embeddings generated successfully.")
        return embeddings


# --- Hybrid Search Manager Class ---
class HybridSearchManager:
    """
    Manages the initialization and execution of a hybrid search system
    combining BM25, dense vector search, and a fast reranker.
    """

    def __init__(self, embedding_model_name: str = EMBEDDING_MODEL_NAME):
        """Create the embedding client and the cross-encoder reranker.

        The BM25 model and document embeddings stay unset until
        ``initialize_models`` is called.
        """
        self.bm25_model = None
        self.embedding_client = EmbeddingClient(model_name=embedding_model_name)
        self.document_chunks = []
        self.document_embeddings = None

        # Initialize the MS MARCO MiniLM cross-encoder reranker.
        reranker_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2', device=reranker_device)
        print("ms-marco-MiniLM-L6-v2 Reranker initialized.")

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Split *text* on single spaces; shared by corpus and query so both match."""
        return text.split(" ")

    async def initialize_models(self, documents: list[Document]):
        """Build the BM25 index and dense embeddings for *documents*.

        When *documents* is empty, any previously built index is discarded so
        that the stored chunks, BM25 model and embeddings never disagree.
        Note that the body contains no awaits: the BM25 and embedding
        computation runs inline in the calling task.
        """
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            # Drop stale state from an earlier initialization; otherwise a later
            # search would index old scores into the now-empty chunk list.
            self.bm25_model = None
            self.document_embeddings = None
            print("No documents to initialize. Skipping model setup.")
            return

        print("Initializing BM25 model...")
        tokenized_corpus = [self._tokenize(text) for text in corpus]
        self.bm25_model = BM25Okapi(tokenized_corpus)
        print("BM25 model initialized.")

        print(f"Computing and storing document embeddings on {self.embedding_client.device}...")
        self.document_embeddings = self.embedding_client.get_embeddings(corpus)
        print("Document embeddings computed.")

    async def retrieve_candidates(self, query: str, hyde_doc: str) -> List[dict]:
        """
        Performs a HyDE-enhanced hybrid search to retrieve initial candidates without reranking.

        BM25 scores are computed from the raw *query*; dense scores are computed
        from the query concatenated with *hyde_doc* (when non-empty). Both score
        sets are min-max normalized and averaged with equal weight.

        Args:
            query: The user question.
            hyde_doc: Hypothetical passage for the query; may be empty.

        Returns:
            Up to ``INITIAL_K_CANDIDATES`` dicts with keys ``content``,
            ``document_metadata`` and ``initial_score``, best first.

        Raises:
            ValueError: If ``initialize_models`` has not built the indexes.
        """
        if self.bm25_model is None or self.document_embeddings is None:
            raise ValueError("Hybrid search models are not initialized. Call initialize_models first.")

        print(f"Performing hybrid search for candidate retrieval for query: '{query}'...")
        hyde_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query

        tokenized_query = self._tokenize(query)
        bm25_scores = self.bm25_model.get_scores(tokenized_query)

        query_embedding = self.embedding_client.get_embeddings([hyde_query])
        dense_scores = cosine_similarity(query_embedding, self.document_embeddings)
        dense_scores = dense_scores.cpu().numpy()

        if len(bm25_scores) == 0 or len(dense_scores) == 0:
            return []

        # Bring both score sets onto [0, 1] before mixing them.
        scaler = MinMaxScaler()
        normalized_bm25_scores = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        normalized_dense_scores = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * normalized_bm25_scores + 0.5 * normalized_dense_scores

        ranked_indices = np.argsort(combined_scores)[::-1]
        top_initial_indices = ranked_indices[:INITIAL_K_CANDIDATES]

        retrieved_results = []
        for idx in top_initial_indices:
            doc = self.document_chunks[idx]
            retrieved_results.append({
                "content": doc.page_content,
                "document_metadata": doc.metadata,
                # Built-in float keeps the result JSON-serializable.
                "initial_score": float(combined_scores[idx]),
            })

        print(f"Retrieved {len(retrieved_results)} initial candidates for reranking.")
        return retrieved_results

    async def rerank_results(
        self, query: str, retrieved_results: List[dict], top_k: int
    ) -> Tuple[List[dict], float]:
        """
        Performs reranking on a list of retrieved candidate documents.

        Args:
            query: The user question paired with each candidate.
            retrieved_results: Candidate dicts carrying ``content`` and
                ``document_metadata`` keys.
            top_k: Number of best-scoring chunks to keep.

        Returns:
            A ``(final_chunks, rerank_time)`` tuple: the kept chunks (each with
            ``content``, ``document_metadata`` and ``rerank_score``), best
            first, and the seconds spent in the reranker. Empty input yields
            ``([], 0.0)`` so callers can always unpack two values.
        """
        if not retrieved_results:
            return [], 0.0

        print(f"Reranking {len(retrieved_results)} candidates for query: '{query}'...")
        start_time_rerank = time.perf_counter()

        rerank_input = [[query, chunk["content"]] for chunk in retrieved_results]
        # The cross-encoder is blocking; run it off the event loop thread.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict,
            rerank_input,
            show_progress_bar=False
        )

        end_time_rerank = time.perf_counter()
        rerank_time = end_time_rerank - start_time_rerank

        scored_results = list(zip(retrieved_results, rerank_scores))
        scored_results.sort(key=lambda x: x[1], reverse=True)

        final_chunks = [
            {
                "content": res["content"],
                "document_metadata": res["document_metadata"],
                # Built-in float keeps the result JSON-serializable.
                "rerank_score": float(score),
            }
            for res, score in scored_results[:top_k]
        ]

        print(f"Reranking completed in {rerank_time:.4f} seconds. Returning top {len(final_chunks)} chunks.")
        return final_chunks, rerank_time


# --- Other helper functions (process_markdown_with_recursive_chunking, generate_answer_with_groq) ---
def process_markdown_with_recursive_chunking(
    md_file_path: str, chunk_size: int, chunk_overlap: int
) -> List[Document]:
    """Read a markdown file and split it into overlapping ``Document`` chunks.

    Args:
        md_file_path: Path of the file to read (a non-``.md`` extension only
            triggers a warning).
        chunk_size: Maximum chunk length in characters.
        chunk_overlap: Number of characters shared between adjacent chunks.

    Returns:
        One ``Document`` per chunk, each tagged with
        ``{"document_part": "Whole Document"}``. An empty list is returned
        when the path is missing, is not a file, cannot be read, or is empty.
    """
    if not os.path.exists(md_file_path):
        print(f"Error: File not found at '{md_file_path}'")
        return []
    if not os.path.isfile(md_file_path):
        print(f"Error: Path '{md_file_path}' is not a file.")
        return []
    if not md_file_path.lower().endswith(".md"):
        print(f"Warning: File '{md_file_path}' does not have a .md extension.")

    try:
        with open(md_file_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
    except Exception as e:
        print(f"Error reading file '{md_file_path}': {e}")
        return []

    if not full_text:
        print("Input markdown file is empty.")
        return []

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    all_chunks = [
        Document(page_content=chunk, metadata={"document_part": "Whole Document"})
        for chunk in text_splitter.split_text(full_text)
    ]

    print(f"Created {len(all_chunks)} chunks from the entire document.")
    return all_chunks


async def generate_answer_with_groq(query: str, retrieved_results: list[dict], groq_api_key: str) -> str:
    """
    Generates an answer using the Groq API based on the query and retrieved chunks' content.

    Args:
        query: The user question.
        retrieved_results: Chunk dicts; ``content`` and ``document_metadata``
            are read from each, with missing values rendered as empty / "N/A".
        groq_api_key: API key for Groq.

    Returns:
        The model's answer, or a fixed error string when the key is missing
        or the API call fails.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    print("Generating answer with Groq API...")
    client = AsyncGroq(api_key=groq_api_key)

    context_parts = []
    for i, res in enumerate(retrieved_results):
        content = res.get("content", "")
        # Tolerate an explicit None as well as a missing key.
        metadata = res.get("document_metadata") or {}
        section_heading = metadata.get("section_heading", "N/A")
        document_part = metadata.get("document_part", "N/A")
        context_parts.append(
            f"--- Context Chunk {i+1} ---\n"
            f"Document Part: {document_part}\n"
            f"Section Heading: {section_heading}\n"
            f"Content: {content}\n"
            f"-------------------------"
        )
    context = "\n\n".join(context_parts)

    prompt = (
        f"You are an expert on the provided document. Your task is to answer the user's question "
        f"based only on the information given. Your answers should be brief, concise, and in a similar style to these examples:\n"
        f"- Yes, outpatient consultations and diagnostic tests are covered, provided they are medically necessary and fall within the specified sub-limits under the plan.\n"
        f"- The policy does not cover any expenses incurred during the first 30 days from the inception of the policy, except in the case of accidents.\n"
        f"- Room rent is covered up to a single private AC room per day unless otherwise specified in the policy schedule.\n"
        f"- Yes, the policy allows for mid-term inclusion of newly married spouses and newborn children, subject to notification and payment of additional premium within the stipulated time frame.\n"
        f"Do not mention or refer to the document or the context in your final answer. If the information required to answer the question is not available in the provided context, state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        f"Answer:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."