# advanced_retrieval.py
"""Advanced RAG retrieval strategies over a single web page.

Every strategy subclasses :class:`AdvancedRetriever`, which downloads the
page, splits it at three granularities (small / medium / large chunks), and
indexes each granularity in its own Chroma collection.  The module-level
``get_answer_using_*`` functions are the integration points for the main app:
each builds the matching retriever, fetches documents, and asks the LLM to
answer from that context.

NOTE(review): instantiating any retriever triggers network I/O (page download,
OpenAI embedding calls) and a CrossEncoder model load.
"""

import asyncio
import os
from operator import itemgetter
from typing import Any, Dict, List, Tuple

import numpy as np
# SoupStrainer is re-exported from the bs4 package root in every bs4 version;
# the ``bs4.filter`` submodule only exists in bs4 >= 4.13, so import from bs4.
from bs4 import SoupStrainer
from dotenv import load_dotenv
from langchain.load import dumps, loads
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from sentence_transformers import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity

load_dotenv()


class AdvancedRetriever:
    """Shared document pipeline for all retrieval strategies.

    Loads one web page, chunks it at three sizes, and builds one Chroma
    vector store per chunk size.

    Args:
        link: URL of the page to index.
    """

    def __init__(self, link: str):
        self.link = link
        self.llm = ChatOpenAI(temperature=0)
        self.embeddings = OpenAIEmbeddings()
        self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

        # Load and process documents
        self._load_documents()
        self._create_vector_stores()

    def _load_documents(self) -> None:
        """Load the page and chunk it with three different strategies."""
        loader = WebBaseLoader(
            web_path=(self.link,),
            bs_kwargs=dict(
                # Only parse the article body/title/header elements.
                parse_only=SoupStrainer(
                    class_=("post-content", "post-title", "post-header")
                )
            ),
        )
        docs = loader.load()

        # Small chunks for precise retrieval
        small_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=50,
        )
        self.small_chunks = small_splitter.split_documents(docs)

        # Large chunks for context
        large_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
        )
        self.large_chunks = large_splitter.split_documents(docs)

        # Medium chunks (default granularity used by most strategies)
        medium_splitter = RecursiveCharacterTextSplitter(
            chunk_size=300,
            chunk_overlap=50,
        )
        self.medium_chunks = medium_splitter.split_documents(docs)

    def _create_vector_stores(self) -> None:
        """Create one vector store per chunk granularity."""
        self.small_vectorstore = Chroma.from_documents(
            documents=self.small_chunks,
            embedding=self.embeddings,
            collection_name="small_chunks",
        )
        self.large_vectorstore = Chroma.from_documents(
            documents=self.large_chunks,
            embedding=self.embeddings,
            collection_name="large_chunks",
        )
        self.medium_vectorstore = Chroma.from_documents(
            documents=self.medium_chunks,
            embedding=self.embeddings,
            collection_name="medium_chunks",
        )

    def _deduplicate_documents(self, docs: List[Document]) -> List[Document]:
        """Remove exact-content duplicates, preserving first-seen order.

        Shared by all subclasses (previously duplicated with an O(n^2)
        variant); semantics are identical: documents with byte-equal
        ``page_content`` are considered duplicates.
        """
        seen_content = set()
        unique_docs: List[Document] = []
        for doc in docs:
            if doc.page_content not in seen_content:
                seen_content.add(doc.page_content)
                unique_docs.append(doc)
        return unique_docs


class MultiQueryRetrieval(AdvancedRetriever):
    """Generate multiple diverse queries and merge results."""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Expand the question into several queries, search each, dedupe."""
        # Generate multiple query perspectives
        query_generation_prompt = ChatPromptTemplate.from_template("""
        You are an AI assistant that generates multiple search queries from different perspectives.
        Generate 4 diverse search queries that would help answer this question: {question}
        
        Focus on different aspects and use varied vocabulary. Each query should be on a separate line.
        """)

        generate_queries = (
            query_generation_prompt
            | self.llm
            | StrOutputParser()
            # Drop blank lines the LLM may emit; searching "" wastes a call.
            | (lambda x: [q for q in x.strip().split('\n') if q.strip()])
        )

        queries = generate_queries.invoke({"question": question})
        queries.append(question)  # Include original query

        # Retrieve documents for each query
        all_docs: List[Document] = []
        for query in queries:
            docs = self.medium_vectorstore.similarity_search(query, k=k)
            all_docs.extend(docs)

        # Remove duplicates and return top k
        return self._deduplicate_documents(all_docs)[:k]


class ParentChildRetrieval(AdvancedRetriever):
    """Retrieve small chunks but return larger parent context."""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Search precisely with small chunks, return their large parents."""
        # Search with small chunks for precision (over-fetch to survive
        # parent-level deduplication below).
        small_docs = self.small_vectorstore.similarity_search(question, k=k * 2)

        # Find corresponding large chunks (parents)
        parent_docs: List[Document] = []
        for small_doc in small_docs:
            # Find the large chunk that contains this small chunk
            parent = self._find_parent_chunk(small_doc)
            if parent and parent not in parent_docs:
                parent_docs.append(parent)

        return parent_docs[:k]

    def _find_parent_chunk(self, small_doc: Document) -> Document:
        """Find the large chunk whose text contains the small chunk.

        Falls back to the small chunk itself when no parent contains it
        (possible when splitter boundaries don't nest cleanly).
        """
        small_content = small_doc.page_content
        for large_doc in self.large_chunks:
            if small_content in large_doc.page_content:
                return large_doc
        return small_doc  # Fallback to small doc if no parent found


class ContextualCompression(AdvancedRetriever):
    """Compress retrieved chunks to focus on relevant information."""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Retrieve, then LLM-compress each chunk down to relevant content."""
        # Initial retrieval (over-fetch; compression may discard chunks)
        docs = self.medium_vectorstore.similarity_search(question, k=k * 2)

        # Compress each document
        compression_prompt = ChatPromptTemplate.from_template("""
        Given this question: {question}
        
        Extract only the most relevant information from this text that helps answer the question.
        Remove any irrelevant details while preserving key facts and context.
        
        Text: {text}
        
        Relevant extract:
        """)

        compressed_docs: List[Document] = []
        for doc in docs:
            compressed_content = (
                compression_prompt | self.llm | StrOutputParser()
            ).invoke({"question": question, "text": doc.page_content})

            # Only keep if compression resulted in meaningful content
            # (50 chars is a heuristic floor for a useful extract).
            if len(compressed_content.strip()) > 50:
                compressed_docs.append(
                    Document(
                        page_content=compressed_content,
                        metadata=doc.metadata,
                    )
                )

        return compressed_docs[:k]


class CrossEncoderReranking(AdvancedRetriever):
    """Use cross-encoder for better relevance scoring."""

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Over-fetch with the bi-encoder, rerank with the cross-encoder."""
        # Initial retrieval with higher k
        initial_docs = self.medium_vectorstore.similarity_search(question, k=k * 3)

        if not initial_docs:
            return []

        # Prepare query-document pairs for cross-encoder
        query_doc_pairs = [[question, doc.page_content] for doc in initial_docs]

        # Get relevance scores
        scores = self.cross_encoder.predict(query_doc_pairs)

        # Sort documents by relevance score (descending)
        doc_score_pairs = list(zip(initial_docs, scores))
        doc_score_pairs.sort(key=lambda x: x[1], reverse=True)  # type: ignore

        # Return top k documents
        return [doc for doc, score in doc_score_pairs[:k]]


class SemanticRouting(AdvancedRetriever):
    """Route queries to specialized retrievers based on query type."""

    def __init__(self, link: str):
        super().__init__(link)
        self.query_classifier_prompt = ChatPromptTemplate.from_template("""
        Classify this query into one of these categories:
        1. FACTUAL - Asking for specific facts, definitions, or data
        2. CONCEPTUAL - Asking for explanations, processes, or how things work
        3. COMPARATIVE - Comparing different concepts, methods, or approaches
        4. ANALYTICAL - Requiring analysis, reasoning, or synthesis
        
        Query: {question}
        
        Respond with only the category name (FACTUAL, CONCEPTUAL, COMPARATIVE, or ANALYTICAL):
        """)

    def retrieve(self, question: str, k: int = 5) -> List[Document]:
        """Classify the question, then dispatch to the matching strategy."""
        # Classify the query
        query_type = (
            self.query_classifier_prompt | self.llm | StrOutputParser()
        ).invoke({"question": question}).strip()

        # Route to appropriate retrieval strategy; anything unrecognized
        # (including malformed classifier output) falls through to ANALYTICAL.
        if query_type == "FACTUAL":
            return self._factual_retrieval(question, k)
        elif query_type == "CONCEPTUAL":
            return self._conceptual_retrieval(question, k)
        elif query_type == "COMPARATIVE":
            return self._comparative_retrieval(question, k)
        else:  # ANALYTICAL
            return self._analytical_retrieval(question, k)

    def _factual_retrieval(self, question: str, k: int) -> List[Document]:
        """Precise retrieval for factual queries (small chunks)."""
        return self.small_vectorstore.similarity_search(question, k=k)

    def _conceptual_retrieval(self, question: str, k: int) -> List[Document]:
        """Broader context for conceptual queries (large chunks)."""
        return self.large_vectorstore.similarity_search(question, k=k)

    def _comparative_retrieval(self, question: str, k: int) -> List[Document]:
        """Multi-aspect retrieval for comparative queries."""
        # Extract comparison terms
        comparison_prompt = ChatPromptTemplate.from_template("""
        Extract the main concepts being compared in this question: {question}
        List them separated by commas:
        """)

        concepts = (
            comparison_prompt | self.llm | StrOutputParser()
        ).invoke({"question": question})

        all_docs: List[Document] = []
        # max(1, ...) guards k=1, where k//2 == 0 would request zero results.
        per_concept_k = max(1, k // 2)
        for concept in concepts.split(','):
            docs = self.medium_vectorstore.similarity_search(
                concept.strip(), k=per_concept_k
            )
            all_docs.extend(docs)

        return self._deduplicate_documents(all_docs)[:k]

    def _analytical_retrieval(self, question: str, k: int) -> List[Document]:
        """Comprehensive retrieval for analytical queries."""
        # Lazily build (and cache) the multi-query retriever: constructing it
        # re-downloads and re-embeds the page, so doing it per call is wasteful.
        if not hasattr(self, "_multi_query_retriever"):
            self._multi_query_retriever = MultiQueryRetrieval(self.link)
        return self._multi_query_retriever.retrieve(question, k)


# Integration functions for your main app

def _answer_from_docs(docs: List[Document], question: str) -> str:
    """Generate an answer from retrieved docs (shared by all strategies)."""
    template = """Answer the following question based on this context:

    {context}

    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(temperature=0)

    final_chain = prompt | llm | StrOutputParser()

    context = "\n\n".join(doc.page_content for doc in docs)
    return final_chain.invoke({"context": context, "question": question})


def get_answer_using_multi_query(link: str, question: str) -> str:
    """Multi-Query Retrieval implementation."""
    retriever = MultiQueryRetrieval(link)
    docs = retriever.retrieve(question)
    return _answer_from_docs(docs, question)


def get_answer_using_parent_child(link: str, question: str) -> str:
    """Parent-Child Retrieval implementation."""
    retriever = ParentChildRetrieval(link)
    docs = retriever.retrieve(question)
    return _answer_from_docs(docs, question)


def get_answer_using_contextual_compression(link: str, question: str) -> str:
    """Contextual Compression implementation."""
    retriever = ContextualCompression(link)
    docs = retriever.retrieve(question)
    return _answer_from_docs(docs, question)


def get_answer_using_cross_encoder(link: str, question: str) -> str:
    """Cross-Encoder Reranking implementation."""
    retriever = CrossEncoderReranking(link)
    docs = retriever.retrieve(question)
    return _answer_from_docs(docs, question)


def get_answer_using_semantic_routing(link: str, question: str) -> str:
    """Semantic Routing implementation."""
    retriever = SemanticRouting(link)
    docs = retriever.retrieve(question)
    return _answer_from_docs(docs, question)


# Example usage
# if __name__ == "__main__":
#     link = "https://lilianweng.github.io/posts/2023-06-23-agent/"
#     question = "What is task decomposition for LLM agents?"
#
#     # Test all advanced retrieval techniques
#     techniques = [
#         ("Multi-Query Retrieval", get_answer_using_multi_query),
#         ("Parent-Child Retrieval", get_answer_using_parent_child),
#         ("Contextual Compression", get_answer_using_contextual_compression),
#         ("Cross-Encoder Reranking", get_answer_using_cross_encoder),
#         ("Semantic Routing", get_answer_using_semantic_routing),
#     ]
#
#     for name, func in techniques:
#         print(f"\n=== {name} ===")
#         try:
#             answer = func(link, question)
#             print(answer)
#         except Exception as e:
#             print(f"Error: {e}")
#         print("-" * 50)