""" DocMind - Multi-Agent System Implements Retriever, Reader, Critic, and Synthesizer agents """ from typing import List, Dict, Tuple from retriever import PaperRetriever import os class RetrieverAgent: """Agent responsible for finding relevant papers""" def __init__(self, retriever: PaperRetriever): self.retriever = retriever def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[Dict, float]]: """ Retrieve relevant papers for the query Returns: List of (paper, relevance_score) tuples """ print(f"🔍 Retriever Agent: Searching for '{query}'...") results = self.retriever.search(query, top_k) print(f" Found {len(results)} relevant papers") return results class ReaderAgent: """Agent responsible for reading and summarizing papers""" def __init__(self, llm_client=None): """ Args: llm_client: Optional LLM client (OpenAI, Anthropic, etc.) If None, uses rule-based summarization """ self.llm_client = llm_client def summarize_paper(self, paper: Dict) -> str: """ Generate a summary of a single paper Args: paper: Paper dictionary with title, abstract, etc. Returns: Summary string """ if self.llm_client: return self._llm_summarize(paper) else: return self._rule_based_summarize(paper) def _rule_based_summarize(self, paper: Dict) -> str: """Simple extractive summary (first 3 sentences)""" abstract = paper['abstract'] sentences = abstract.split('. ') summary = '. '.join(sentences[:3]) + '.' return { 'title': paper['title'], 'arxiv_id': paper['arxiv_id'], 'authors': paper['authors'][:3], 'summary': summary, 'year': paper['published'][:4] } def _llm_summarize(self, paper: Dict) -> str: """Use LLM to generate intelligent summary""" prompt = f"""Summarize this research paper in 2-3 sentences, focusing on: 1. The main contribution/idea 2. The key methodology or approach 3. 
        # This is a placeholder - replace with an actual LLM call.
        response = "LLM summary would go here"

        return {
            'title': paper['title'],
            'arxiv_id': paper['arxiv_id'],
            'authors': paper['authors'][:3],
            'summary': response,
            'year': paper['published'][:4]
        }

    def read_papers(self, papers: List[Tuple[Dict, float]]) -> List[Dict]:
        """
        Read and summarize multiple papers

        Args:
            papers: List of (paper, score) tuples from the retriever

        Returns:
            List of summaries
        """
        print(f"📖 Reader Agent: Reading {len(papers)} papers...")
        summaries = []
        for paper, score in papers:
            summary = self.summarize_paper(paper)
            summary['relevance_score'] = score
            summaries.append(summary)
        print(f"   Generated {len(summaries)} summaries")
        return summaries


class CriticAgent:
    """Agent responsible for evaluating and filtering summaries"""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def critique(self, summaries: List[Dict], query: str) -> List[Dict]:
        """
        Evaluate summaries for quality and relevance

        Args:
            summaries: List of paper summaries
            query: Original user query

        Returns:
            Filtered and scored summaries
        """
        print(f"🔎 Critic Agent: Evaluating {len(summaries)} summaries...")

        # Simple rule-based filtering
        filtered = []
        for summary in summaries:
            # Check the relevance score threshold
            if summary['relevance_score'] > 0.3:
                # Add a quality score (can be enhanced with an LLM)
                summary['quality_score'] = self._assess_quality(summary, query)
                filtered.append(summary)

        # Sort by combined score (relevance weighted 0.7, quality 0.3)
        filtered.sort(
            key=lambda x: x['relevance_score'] * 0.7 + x['quality_score'] * 0.3,
            reverse=True
        )

        print(f"   Retained {len(filtered)} high-quality summaries")
        return filtered

    def _assess_quality(self, summary: Dict, query: str) -> float:
        """
        Simple heuristic quality assessment (can be enhanced with an LLM)

        Returns:
            Quality score in [0, 1]
        """
        score = 0.5  # Base score

        # Longer summaries might be more informative
        if len(summary['summary']) > 100:
            score += 0.2

        # Recent papers get a bonus
        if int(summary['year']) >= 2024:
            score += 0.3

        return min(score, 1.0)


class SynthesizerAgent:
    """Agent responsible for synthesizing the final answer"""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client

    def synthesize(
        self,
        summaries: List[Dict],
        query: str,
        max_papers: int = 10
    ) -> str:
        """
        Synthesize the final answer from summaries

        Args:
            summaries: List of filtered, quality summaries
            query: Original user query
            max_papers: Maximum papers to include in the response

        Returns:
            Final synthesized response with citations
        """
        print("✨ Synthesizer Agent: Creating final response...")

        if not summaries:
            return "No relevant papers found for your query."

        # Limit to the top papers
        top_summaries = summaries[:max_papers]

        if self.llm_client:
            return self._llm_synthesize(top_summaries, query)
        else:
            return self._rule_based_synthesize(top_summaries, query)

    def _rule_based_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Create a structured response without an LLM"""
        response = f"# Research Summary: {query}\n\n"
        response += f"Based on {len(summaries)} relevant papers from arXiv:\n\n"

        for i, summary in enumerate(summaries, 1):
            response += f"## [{i}] {summary['title']}\n"
            response += f"**Authors:** {', '.join(summary['authors'])}"
            # The author list was truncated to three, so three names shown
            # may mean more in the original paper.
            if len(summary['authors']) >= 3:
                response += " et al."
            response += f"\n**Year:** {summary['year']}\n"
            response += f"**arXiv ID:** {summary['arxiv_id']}\n"
            response += f"**Relevance:** {summary['relevance_score']:.2f}\n\n"
            response += f"{summary['summary']}\n\n"
            response += "---\n\n"

        return response

    def _llm_synthesize(self, summaries: List[Dict], query: str) -> str:
        """Use an LLM to create a coherent synthesis"""
        # Build context from the summaries
        context = ""
        for i, summary in enumerate(summaries, 1):
            context += f"[{i}] {summary['title']} ({summary['arxiv_id']})\n"
            context += f"    {summary['summary']}\n\n"

        prompt = f"""You are a research assistant. Based on the following papers, answer this question:

Question: {query}

Papers:
{context}

Provide a comprehensive answer that:
1. Directly addresses the question
2. Synthesizes information across papers
3. Cites papers by number [1], [2], etc.
4. Highlights key findings and consensus/disagreements
5. Is concise but thorough (3-5 paragraphs)

Answer:"""

        # Placeholder for the LLM call.
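        # A minimal sketch, assuming self.llm_client is an Anthropic Python
        # SDK client this time (any provider works as long as the agents agree
        # on the interface); the model name is an assumption:
        #
        #     message = self.llm_client.messages.create(
        #         model="claude-sonnet-4-20250514",
        #         max_tokens=1024,
        #         messages=[{"role": "user", "content": prompt}],
        #     )
        #     response = message.content[0].text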
        response = "LLM-generated synthesis would go here with citations"

        # Append paper references
        response += "\n\n## References\n"
        for i, summary in enumerate(summaries, 1):
            response += f"[{i}] {summary['title']} "
            response += f"({summary['arxiv_id']}, {summary['year']})\n"

        return response


class DocMindOrchestrator:
    """Main orchestrator that coordinates all agents"""

    def __init__(
        self,
        retriever: PaperRetriever,
        llm_client=None
    ):
        self.retriever_agent = RetrieverAgent(retriever)
        self.reader_agent = ReaderAgent(llm_client)
        self.critic_agent = CriticAgent(llm_client)
        self.synthesizer_agent = SynthesizerAgent(llm_client)

    def process_query(
        self,
        query: str,
        top_k: int = 10,
        max_papers_in_response: int = 5
    ) -> str:
        """
        Process a user query through the full agent pipeline

        Args:
            query: User question
            top_k: Number of papers to retrieve
            max_papers_in_response: Max papers in the final response

        Returns:
            Final synthesized answer
        """
        print(f"\n{'=' * 60}")
        print(f"Processing query: {query}")
        print('=' * 60)

        # Step 1: Retrieve
        papers = self.retriever_agent.retrieve(query, top_k)
        if not papers:
            return "No relevant papers found for your query."

        # Step 2: Read & summarize
        summaries = self.reader_agent.read_papers(papers)

        # Step 3: Critique & filter
        quality_summaries = self.critic_agent.critique(summaries, query)

        # Step 4: Synthesize
        final_response = self.synthesizer_agent.synthesize(
            quality_summaries,
            query,
            max_papers_in_response
        )

        print(f"{'=' * 60}\n")
        return final_response


def main():
    """Example usage of the multi-agent system"""
    from fetch_arxiv_data import ArxivFetcher

    # Setup
    fetcher = ArxivFetcher()
    retriever = PaperRetriever()

    # Load or build the index
    if not retriever.load_index():
        papers = fetcher.load_papers("arxiv_papers.json")
        retriever.build_index(papers)
        retriever.save_index()

    # Create the orchestrator
    orchestrator = DocMindOrchestrator(retriever)

    # Test queries
    test_queries = [
        "What are the latest improvements in diffusion models?",
        "How does RLHF compare to DPO for language model alignment?",
        "What are the main challenges in scaling transformers?"
    ]

    for query in test_queries:
        response = orchestrator.process_query(query, top_k=8, max_papers_in_response=3)
        print(response)
        print("\n" + "=" * 80 + "\n")


if __name__ == "__main__":
    main()
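
# To run the pipeline against a real LLM rather than the rule-based fallbacks,
# pass a client into the orchestrator -- a sketch, assuming the OpenAI Python
# SDK and that the placeholder calls above have been filled in:
#
#     from openai import OpenAI
#     orchestrator = DocMindOrchestrator(retriever, llm_client=OpenAI())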