| """ | |
| Gemini File Search Chatbot (Beta Version) | |
| This chatbot uses Google Gemini File Search API for RAG. | |
| It provides a simpler architecture: Main Agent + Gemini Agent | |
| """ | |
import os
import json
import time
import logging
import traceback
from pathlib import Path
from typing import Dict, List, Any, Optional, TypedDict

from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

from src.config.loader import load_config
from src.llm.adapters import get_llm_client
from src.config.paths import CONVERSATIONS_DIR
from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class GeminiState(TypedDict):
    """State for the Gemini chatbot conversation flow"""
    conversation_id: str
    messages: List[Any]
    current_query: str
    query_context: Optional[Dict[str, Any]]
    gemini_result: Optional[GeminiFileSearchResult]
    final_response: Optional[str]
    agent_logs: List[str]
    conversation_context: Dict[str, Any]
    session_start_time: float
    last_ai_message_time: float
    filters: Optional[Dict[str, Any]]

class GeminiRAGChatbot:
    """Gemini File Search RAG chatbot (Beta version)"""

    def __init__(self, config_path: str = "src/config/settings.yaml"):
        """Initialize the Gemini chatbot"""
        logger.info("INITIALIZING: Gemini File Search Chatbot (Beta)")
        self.config = load_config(config_path)

        # Get the LLM provider from config
        reader_config = self.config.get("reader", {})
        default_type = reader_config.get("default_type", "INF_PROVIDERS")
        provider_name = default_type.lower()
        self.llm_adapter = get_llm_client(provider_name, self.config)

        # Initialize the Gemini File Search client
        try:
            self.gemini_client = GeminiFileSearchClient()
            logger.info("Gemini File Search client initialized")
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {e}")
            raise RuntimeError(f"Gemini client initialization failed: {e}")

        # Build the LangGraph
        self.graph = self._build_graph()

        # Log LangSmith tracing status if configured via environment variables
        langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
        if langsmith_enabled:
            logger.info("LangSmith tracing enabled")
            langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
            logger.info(f"LangSmith project: {langsmith_project}")

        # Conversations directory
        self.conversations_dir = CONVERSATIONS_DIR
        try:
            self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
        except (PermissionError, OSError) as e:
            logger.warning(f"Could not create conversations directory: {e}")
            self.conversations_dir = Path("conversations")
            self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)

        logger.info("Gemini File Search Chatbot initialized")

    def _build_graph(self) -> StateGraph:
        """Build the LangGraph for the Gemini chatbot"""
        graph = StateGraph(GeminiState)

        # Add nodes
        graph.add_node("main_agent", self._main_agent)
        graph.add_node("gemini_agent", self._gemini_agent)

        # Define the flow: main_agent -> gemini_agent -> END
        graph.set_entry_point("main_agent")
        graph.add_edge("main_agent", "gemini_agent")
        graph.add_edge("gemini_agent", END)

        return graph.compile()

    def _main_agent(self, state: GeminiState) -> GeminiState:
        """Main Agent: extracts filters and prepares the query"""
        logger.info("MAIN AGENT: Processing query")

        query = state["current_query"]
        messages = state["messages"]

        # Extract UI filters if present in the query
        ui_filters = self._extract_ui_filters(query)

        # Extract context from the conversation
        context = self._extract_context_from_conversation(messages, ui_filters)

        # Store context and filters
        state["query_context"] = context
        state["filters"] = context.get("filters", {})

        logger.info(f"MAIN AGENT: Filters extracted: {state['filters']}")
        return state
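
    # GeminiFileSearchClient.search is assumed to return a GeminiFileSearchResult that
    # exposes at least `.answer` (the generated text) and `.sources` (the retrieved
    # documents); those are the only attributes the agent below relies on.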
    def _gemini_agent(self, state: GeminiState) -> GeminiState:
        """Gemini Agent: performs the file search and generates the response"""
        logger.info("GEMINI AGENT: Starting file search")

        query = state["current_query"]
        filters = state.get("filters", {})

        # Perform the Gemini file search
        try:
            result = self.gemini_client.search(query=query, filters=filters)
            logger.info(f"GEMINI AGENT: Search completed, {len(result.sources)} sources found")

            # Enhance the response with document references
            enhanced_response = self._enhance_response_with_references(
                result.answer,
                result.sources,
                query
            )

            state["gemini_result"] = result
            state["final_response"] = enhanced_response
            state["last_ai_message_time"] = time.time()
            state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
        except Exception as e:
            logger.error(f"GEMINI AGENT ERROR: {e}")
            traceback.print_exc()
            state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
            state["last_ai_message_time"] = time.time()

        return state

    def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
        """Enhance the Gemini response with document references and clean formatting"""
        if not sources or not answer:
            return answer

        # Use the LLM to add document references and improve formatting
        try:
            llm = self.llm_adapter

            # Prepare document summaries for the LLM
            doc_summaries = []
            for idx, doc in enumerate(sources, 1):
                metadata = getattr(doc, 'metadata', {}) if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
                content = getattr(doc, 'page_content', '') if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
                filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
                year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
                source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
                district = metadata.get('district', '') if isinstance(metadata, dict) else ''

                doc_info = f"{filename}"
                if year and year != 'Unknown':
                    doc_info += f" ({year})"
                if source and source != 'Unknown':
                    doc_info += f" - {source}"
                if district:
                    doc_info += f" - {district}"
                doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")

            prompt = f"""You are enhancing a response from a document search system. The original response is:
{answer}
The following documents were retrieved and used to generate this response:
{chr(10).join(doc_summaries)}
CRITICAL RULES:
1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
2. The response should ONLY contain information from the retrieved documents listed above
3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
4. Add document references [Doc i] at the end of sentences that use information from specific documents
5. Only reference documents that are actually used in the response
6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
7. Keep the response natural, conversational, and well-formatted
8. Use proper formatting: paragraphs, line breaks, and structure for readability
9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
10. If multiple documents support the same claim, use [Doc i, Doc j] format
11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""

            # Invoke the LLM once and extract the text from the result
            llm_result = llm.invoke(prompt)
            enhanced = llm_result.content if hasattr(llm_result, 'content') else str(llm_result)

            # Fallback: if the LLM output looks empty or truncated, return the original with basic formatting
            if not enhanced or len(enhanced) < len(answer) * 0.5:
                logger.warning("LLM enhancement failed, using original response with basic formatting")
                # Basic formatting: add line breaks after periods for readability
                formatted = answer.replace('. ', '.\n\n')
                if sources:
                    ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
                    formatted += f"\n\n*Based on documents: {ref_list}*"
                return formatted

            return enhanced
        except Exception as e:
            logger.warning(f"Failed to enhance response with references: {e}")
            # Fallback: add basic formatting and references at the end
            formatted = answer.replace('. ', '.\n\n')  # Basic paragraph formatting
            if sources:
                ref_list = ", ".join([f"[Doc {i+1}]" for i in range(min(len(sources), 5))])
                formatted += f"\n\n*Based on documents: {ref_list}*"
            return formatted
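
    # The method below assumes the UI embeds filters in the raw query text in a layout
    # like the following (the values here are purely illustrative):
    #
    #   FILTER CONTEXT:
    #   Sources: annual_report, census
    #   Years: 2021, 2022
    #   Districts: None
    #   Filenames: None
    #   USER QUERY: How many schools were built?
    #
    # A value of "None" leaves that filter unset; comma-separated values become lists.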
    def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
        """Extract UI filters embedded in the query, if present"""
        filters = {}
        if "FILTER CONTEXT:" not in query:
            return filters

        filter_section = query.split("FILTER CONTEXT:")[1]
        if "USER QUERY:" in filter_section:
            filter_section = filter_section.split("USER QUERY:")[0]
        filter_section = filter_section.strip()

        # Each filter sits on its own line; "None" means the filter is unset
        filter_labels = {
            "Sources:": "sources",
            "Years:": "year",
            "Districts:": "district",
            "Filenames:": "filenames",
        }
        for label, key in filter_labels.items():
            if label not in filter_section:
                continue
            matching_lines = [line for line in filter_section.split('\n') if line.strip().startswith(label)]
            if not matching_lines:
                continue
            value_str = matching_lines[0].split(label)[1].strip()
            if value_str and value_str != "None":
                filters[key] = [v.strip() for v in value_str.split(",")]

        return filters

    def _extract_context_from_conversation(
        self,
        messages: List[Any],
        ui_filters: Dict[str, List[str]]
    ) -> Dict[str, Any]:
        """Extract context from the conversation history"""
        # Use UI filters if available
        filters = ui_filters.copy() if ui_filters else {}

        # For Gemini, the filters are passed directly to the search call,
        # where they are used to add context to the query.
        return {
            "filters": filters,
            "has_filters": bool(filters)
        }
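
    # chat() returns a plain dict with the keys 'response', 'rag_result' (holding
    # 'sources' and 'answer'), 'agent_logs', 'actual_rag_query', and 'gemini_result',
    # so a caller (for example a UI layer) can render the answer and its sources separately.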
    def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
        """Main chat interface"""
        logger.info(f"GEMINI CHAT: Processing '{user_input[:50]}...'")

        # Load the conversation
        conversation_file = self.conversations_dir / f"{conversation_id}.json"
        conversation = self._load_conversation(conversation_file)

        # Add the user message
        conversation["messages"].append(HumanMessage(content=user_input))

        # Prepare the state
        state = GeminiState(
            conversation_id=conversation_id,
            messages=conversation["messages"],
            current_query=user_input,
            query_context=None,
            gemini_result=None,
            final_response=None,
            agent_logs=[],
            conversation_context=conversation.get("context", {}),
            session_start_time=conversation["session_start_time"],
            last_ai_message_time=conversation["last_ai_message_time"],
            filters=None
        )

        # Run the graph
        final_state = self.graph.invoke(state)

        # Add the AI response to the conversation
        if final_state["final_response"]:
            conversation["messages"].append(AIMessage(content=final_state["final_response"]))

        # Update the conversation
        conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
        conversation["context"] = final_state["conversation_context"]

        # Save the conversation
        self._save_conversation(conversation_file, conversation)

        # Format sources for display
        sources = []
        gemini_result = final_state.get("gemini_result")
        if gemini_result:
            sources = self.gemini_client.format_sources_for_display(gemini_result)
            logger.info(f"GEMINI CHAT: Formatted {len(sources)} sources for display")

        return {
            'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
            'rag_result': {
                'sources': sources,
                'answer': final_state["final_response"]
            },
            'agent_logs': final_state["agent_logs"],
            'actual_rag_query': final_state["current_query"],
            'gemini_result': gemini_result  # Include the raw result for tracking
        }

    def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
        """Load a conversation from file"""
        if conversation_file.exists():
            try:
                with open(conversation_file) as f:
                    data = json.load(f)
                messages = []
                for msg_data in data.get("messages", []):
                    if msg_data["type"] == "human":
                        messages.append(HumanMessage(content=msg_data["content"]))
                    elif msg_data["type"] == "ai":
                        messages.append(AIMessage(content=msg_data["content"]))
                data["messages"] = messages
                return data
            except Exception as e:
                logger.warning(f"Could not load conversation: {e}")

        return {
            "messages": [],
            "session_start_time": time.time(),
            "last_ai_message_time": time.time(),
            "context": {}
        }
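
    # Conversations are persisted as plain JSON; the layout written by the method below
    # looks roughly like this (the values are illustrative):
    # {
    #   "messages": [{"type": "human", "content": "..."}, {"type": "ai", "content": "..."}],
    #   "session_start_time": 1700000000.0,
    #   "last_ai_message_time": 1700000042.0,
    #   "context": {}
    # }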
    def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
        """Save a conversation to file"""
        try:
            conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)

            messages_data = []
            for msg in conversation["messages"]:
                if isinstance(msg, HumanMessage):
                    messages_data.append({"type": "human", "content": msg.content})
                elif isinstance(msg, AIMessage):
                    messages_data.append({"type": "ai", "content": msg.content})

            conversation_data = {
                "messages": messages_data,
                "session_start_time": conversation["session_start_time"],
                "last_ai_message_time": conversation["last_ai_message_time"],
                "context": conversation.get("context", {})
            }

            with open(conversation_file, 'w') as f:
                json.dump(conversation_data, f, indent=2)
        except Exception as e:
            logger.error(f"Could not save conversation: {e}")

def get_gemini_chatbot():
    """Get a Gemini chatbot instance"""
    return GeminiRAGChatbot()
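
if __name__ == "__main__":
    # Minimal usage sketch, assuming src/config/settings.yaml and the Gemini credentials
    # expected by GeminiFileSearchClient are available in this environment; the question
    # below is hypothetical.
    chatbot = get_gemini_chatbot()
    result = chatbot.chat("Which districts are covered in the 2022 reports?", conversation_id="demo")
    print(result["response"])
    for source in result["rag_result"]["sources"]:
        print(source)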