# audit_assistant/src/agents/gemini_chatbot.py
"""
Gemini File Search Chatbot (Beta Version)
This chatbot uses Google Gemini File Search API for RAG.
It provides a simpler architecture: Main Agent + Gemini Agent
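
Example usage (a minimal sketch; assumes GEMINI_API_KEY and the configured
LLM provider credentials are set in the environment):

    chatbot = GeminiRAGChatbot()
    result = chatbot.chat("What were the main audit findings?", conversation_id="demo")
    print(result["response"])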
"""
import os
import json
import time
import logging
import traceback
from pathlib import Path
from typing import Dict, List, Any, Optional, TypedDict
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from src.config.loader import load_config
from src.llm.adapters import get_llm_client
from src.config.paths import CONVERSATIONS_DIR
from src.gemini.file_search import GeminiFileSearchClient, GeminiFileSearchResult
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class GeminiState(TypedDict):
"""State for Gemini chatbot conversation flow"""
    conversation_id: str                              # Stable id; also the on-disk filename stem
    messages: List[Any]                               # Full history of HumanMessage / AIMessage objects
    current_query: str                                # Latest user query (may embed a FILTER CONTEXT block)
    query_context: Optional[Dict[str, Any]]           # Context extracted by the main agent
    gemini_result: Optional[GeminiFileSearchResult]   # Raw result from the Gemini file search
    final_response: Optional[str]                     # Response text returned to the user
    agent_logs: List[str]                             # Per-turn log lines emitted by the agents
    conversation_context: Dict[str, Any]              # Cross-turn context persisted with the conversation
    session_start_time: float                         # Unix timestamp of session start
    last_ai_message_time: float                       # Unix timestamp of the most recent AI reply
    filters: Optional[Dict[str, Any]]                 # UI filters extracted from the query
class GeminiRAGChatbot:
"""Gemini File Search RAG chatbot (Beta version)"""
def __init__(self, config_path: str = "src/config/settings.yaml"):
"""Initialize the Gemini chatbot"""
logger.info("πŸ€– INITIALIZING: Gemini File Search Chatbot (Beta)")
self.config = load_config(config_path)
# Get LLM provider from config
reader_config = self.config.get("reader", {})
default_type = reader_config.get("default_type", "INF_PROVIDERS")
provider_name = default_type.lower()
self.llm_adapter = get_llm_client(provider_name, self.config)
# Initialize Gemini File Search client
try:
self.gemini_client = GeminiFileSearchClient()
logger.info("βœ… Gemini File Search client initialized")
except Exception as e:
logger.error(f"❌ Failed to initialize Gemini client: {e}")
raise RuntimeError(f"Gemini client initialization failed: {e}")
# Build the LangGraph with LangSmith tracing if enabled
self.graph = self._build_graph()
# Enable LangSmith tracing if configured
langsmith_enabled = os.getenv("LANGCHAIN_TRACING_V2", "false").lower() == "true"
if langsmith_enabled:
logger.info("πŸ” LangSmith tracing enabled")
langsmith_project = os.getenv("LANGCHAIN_PROJECT", "gemini-chatbot")
logger.info(f"πŸ“Š LangSmith project: {langsmith_project}")
# Conversations directory
self.conversations_dir = CONVERSATIONS_DIR
try:
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
except (PermissionError, OSError) as e:
logger.warning(f"Could not create conversations directory: {e}")
self.conversations_dir = Path("conversations")
self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
logger.info("βœ… Gemini File Search Chatbot initialized")
def _build_graph(self) -> StateGraph:
"""Build the LangGraph for Gemini chatbot"""
graph = StateGraph(GeminiState)
# Add nodes
graph.add_node("main_agent", self._main_agent)
graph.add_node("gemini_agent", self._gemini_agent)
        # Define the linear flow: main_agent -> gemini_agent -> END
graph.set_entry_point("main_agent")
graph.add_edge("main_agent", "gemini_agent")
graph.add_edge("gemini_agent", END)
return graph.compile()
def _main_agent(self, state: GeminiState) -> GeminiState:
"""Main Agent: Extracts filters and prepares query"""
logger.info("🎯 MAIN AGENT: Processing query")
query = state["current_query"]
messages = state["messages"]
# Extract UI filters if present in query
ui_filters = self._extract_ui_filters(query)
# Extract context from conversation
context = self._extract_context_from_conversation(messages, ui_filters)
# Store context and filters
state["query_context"] = context
state["filters"] = context.get("filters", {})
logger.info(f"🎯 MAIN AGENT: Filters extracted: {state['filters']}")
return state
def _gemini_agent(self, state: GeminiState) -> GeminiState:
"""Gemini Agent: Performs file search and generates response"""
logger.info("πŸ” GEMINI AGENT: Starting file search")
query = state["current_query"]
filters = state.get("filters", {})
# Perform Gemini file search
try:
result = self.gemini_client.search(query=query, filters=filters)
logger.info(f"βœ… GEMINI AGENT: Search completed, {len(result.sources)} sources found")
# Enhance response with document references
enhanced_response = self._enhance_response_with_references(
result.answer,
result.sources,
query
)
state["gemini_result"] = result
state["final_response"] = enhanced_response
state["last_ai_message_time"] = time.time()
state["agent_logs"].append(f"GEMINI AGENT: Found {len(result.sources)} sources")
except Exception as e:
logger.error(f"❌ GEMINI AGENT ERROR: {e}")
traceback.print_exc()
state["final_response"] = "I apologize, but I encountered an error while searching. Please try again."
state["last_ai_message_time"] = time.time()
return state
def _enhance_response_with_references(self, answer: str, sources: List[Any], query: str) -> str:
"""Enhance Gemini response to include document references and format nicely"""
if not sources or not answer:
return answer
# Use LLM to intelligently add document references and format nicely
try:
llm = self.llm_adapter
# Prepare document summaries for the LLM
doc_summaries = []
for idx, doc in enumerate(sources, 1):
                # Sources may be LangChain-style Documents or plain dicts; handle both shapes.
                metadata = doc.metadata if hasattr(doc, 'metadata') else (doc if isinstance(doc, dict) else {})
                content = doc.page_content if hasattr(doc, 'page_content') else (doc.get('content', '') if isinstance(doc, dict) else '')
filename = metadata.get('filename', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
year = metadata.get('year', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
source = metadata.get('source', 'Unknown') if isinstance(metadata, dict) else 'Unknown'
district = metadata.get('district', '') if isinstance(metadata, dict) else ''
doc_info = f"{filename}"
if year and year != 'Unknown':
doc_info += f" ({year})"
if source and source != 'Unknown':
doc_info += f" - {source}"
if district:
doc_info += f" - {district}"
doc_summaries.append(f"[Doc {idx}] {doc_info}: {content[:300]}...")
prompt = f"""You are enhancing a response from a document search system. The original response is:
{answer}
The following documents were retrieved and used to generate this response:
{chr(10).join(doc_summaries)}
CRITICAL RULES:
1. Format the response nicely with proper paragraphs, bullet points, or structured sections where appropriate
2. The response should ONLY contain information from the retrieved documents listed above
3. If the response mentions information NOT found in the retrieved documents, you must REMOVE or CORRECT that information
4. Add document references [Doc i] at the end of sentences that use information from specific documents
5. Only reference documents that are actually used in the response
6. If the response mentions years, sources, or data that don't match the retrieved documents, you must correct it
7. Keep the response natural, conversational, and well-formatted
8. Use proper formatting: paragraphs, line breaks, and structure for readability
9. Don't change the core content that matches the documents, just add references where appropriate and improve formatting
10. If multiple documents support the same claim, use [Doc i, Doc j] format
11. If the response contains information that cannot be verified in the retrieved documents, add a note like: "Note: This information may not be in the retrieved documents."
Return ONLY the enhanced, well-formatted response with references added and any corrections made. Do not include any explanation or meta-commentary."""
            # Invoke the LLM once and normalize the return type (message object or plain string).
            llm_response = llm.invoke(prompt)
            enhanced = llm_response.content if hasattr(llm_response, 'content') else str(llm_response)
            # Fallback: if the LLM output is empty or suspiciously short, keep the
            # original answer and apply basic formatting instead.
            if not enhanced or len(enhanced) < len(answer) * 0.5:
                logger.warning("LLM enhancement failed, using original response with basic formatting")
                return self._format_fallback(answer, sources)
            return enhanced
        except Exception as e:
            logger.warning(f"Failed to enhance response with references: {e}")
            return self._format_fallback(answer, sources)

    def _format_fallback(self, answer: str, sources: List[Any]) -> str:
        """Basic fallback formatting: paragraph breaks plus a simple list of document references."""
        formatted = answer.replace('. ', '.\n\n')
        if sources:
            ref_list = ", ".join(f"[Doc {i + 1}]" for i in range(min(len(sources), 5)))
            formatted += f"\n\n*Based on documents: {ref_list}*"
        return formatted
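
    # A query carrying UI filters is expected to embed a block like the following.
    # The labels are exactly what _extract_ui_filters parses; the values shown
    # are hypothetical examples:
    #
    #   FILTER CONTEXT:
    #   Sources: National Audit Office
    #   Years: 2021, 2022
    #   Districts: District A
    #   Filenames: audit_report_2022.pdf
    #   USER QUERY:
    #   What were the main audit findings?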
    def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
        """Extract UI filters from the query if a FILTER CONTEXT block is present."""
        filters: Dict[str, List[str]] = {}
        if "FILTER CONTEXT:" not in query:
            return filters
        filter_section = query.split("FILTER CONTEXT:")[1]
        if "USER QUERY:" in filter_section:
            filter_section = filter_section.split("USER QUERY:")[0]
        filter_section = filter_section.strip()
        # Each filter line looks like "Label: value1, value2"; map each label to
        # the corresponding filter key.
        label_to_key = {
            "Sources:": "sources",
            "Years:": "year",
            "Districts:": "district",
            "Filenames:": "filenames",
        }
        for line in filter_section.split('\n'):
            line = line.strip()
            for label, key in label_to_key.items():
                if line.startswith(label) and key not in filters:
                    values_str = line.split(label, 1)[1].strip()
                    if values_str and values_str != "None":
                        filters[key] = [v.strip() for v in values_str.split(",")]
        return filters
def _extract_context_from_conversation(
self,
messages: List[Any],
ui_filters: Dict[str, List[str]]
) -> Dict[str, Any]:
"""Extract context from conversation history"""
# Use UI filters if available
filters = ui_filters.copy() if ui_filters else {}
# For Gemini, we pass filters directly to the search function
# The filters will be used to add context to the query
return {
"filters": filters,
"has_filters": bool(filters)
}
def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
"""Main chat interface"""
logger.info(f"πŸ’¬ GEMINI CHAT: Processing '{user_input[:50]}...'")
# Load conversation
conversation_file = self.conversations_dir / f"{conversation_id}.json"
conversation = self._load_conversation(conversation_file)
# Add user message
conversation["messages"].append(HumanMessage(content=user_input))
# Prepare state
state = GeminiState(
conversation_id=conversation_id,
messages=conversation["messages"],
current_query=user_input,
query_context=None,
gemini_result=None,
final_response=None,
agent_logs=[],
conversation_context=conversation.get("context", {}),
session_start_time=conversation["session_start_time"],
last_ai_message_time=conversation["last_ai_message_time"],
filters=None
)
# Run graph
final_state = self.graph.invoke(state)
# Add AI response to conversation
if final_state["final_response"]:
conversation["messages"].append(AIMessage(content=final_state["final_response"]))
# Update conversation
conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
conversation["context"] = final_state["conversation_context"]
# Save conversation
self._save_conversation(conversation_file, conversation)
# Format sources for display
sources = []
gemini_result = final_state.get("gemini_result")
if gemini_result:
sources = self.gemini_client.format_sources_for_display(gemini_result)
logger.info(f"πŸ“‹ GEMINI CHAT: Formatted {len(sources)} sources for display")
return {
'response': final_state["final_response"] or "I apologize, but I couldn't process your request.",
'rag_result': {
'sources': sources,
'answer': final_state["final_response"]
},
'agent_logs': final_state["agent_logs"],
'actual_rag_query': final_state["current_query"],
'gemini_result': gemini_result # Include raw result for tracking
}
def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
"""Load conversation from file"""
if conversation_file.exists():
try:
with open(conversation_file) as f:
data = json.load(f)
messages = []
for msg_data in data.get("messages", []):
if msg_data["type"] == "human":
messages.append(HumanMessage(content=msg_data["content"]))
elif msg_data["type"] == "ai":
messages.append(AIMessage(content=msg_data["content"]))
data["messages"] = messages
return data
except Exception as e:
logger.warning(f"Could not load conversation: {e}")
return {
"messages": [],
"session_start_time": time.time(),
"last_ai_message_time": time.time(),
"context": {}
}
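
    # On disk, each conversation is a JSON file named <conversation_id>.json with
    # the shape below (timestamps are Unix seconds; the values are illustrative):
    #
    #   {
    #     "messages": [{"type": "human", "content": "..."},
    #                  {"type": "ai", "content": "..."}],
    #     "session_start_time": 1700000000.0,
    #     "last_ai_message_time": 1700000042.0,
    #     "context": {}
    #   }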
def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
"""Save conversation to file"""
try:
conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
messages_data = []
for msg in conversation["messages"]:
if isinstance(msg, HumanMessage):
messages_data.append({"type": "human", "content": msg.content})
elif isinstance(msg, AIMessage):
messages_data.append({"type": "ai", "content": msg.content})
conversation_data = {
"messages": messages_data,
"session_start_time": conversation["session_start_time"],
"last_ai_message_time": conversation["last_ai_message_time"],
"context": conversation.get("context", {})
}
with open(conversation_file, 'w') as f:
json.dump(conversation_data, f, indent=2)
except Exception as e:
logger.error(f"Could not save conversation: {e}")
def get_gemini_chatbot():
"""Get Gemini chatbot instance"""
return GeminiRAGChatbot()
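

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes GEMINI_API_KEY and the configured LLM
    # provider credentials are available in the environment).
    chatbot = get_gemini_chatbot()
    result = chatbot.chat("What documents do you have access to?", conversation_id="smoke-test")
    print(result["response"])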