akryldigital committed
Commit 9db763a · verified · 1 Parent(s): b632fe0

create base Agent factory

src/agents/base_multi_agent_chatbot.py ADDED
@@ -0,0 +1,903 @@
+ """
+ Base Multi-Agent Chatbot - Abstract base class with sophisticated query analysis
+
+ This module extracts the core multi-agent logic from MultiAgentRAGChatbot:
+ - Sophisticated LLM-based query analysis
+ - Filter extraction and validation
+ - Query rewriting
+ - Conversation management
+ - Main agent, RAG agent, Response agent logic
+
+ Subclasses only need to implement:
+ - _perform_retrieval(): The actual retrieval mechanism (text-based RAG vs visual search)
+ """
+
+ import re
+ import json
+ import time
+ import logging
+ import traceback
+ from pathlib import Path
+ from datetime import datetime
+ from dataclasses import dataclass
+ from typing import Dict, List, Any, Optional, TypedDict, Union
+ from abc import ABC, abstractmethod
+
+ from langgraph.graph import StateGraph, END
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
+
+ from src.llm.adapters import get_llm_client
+ from src.config.paths import PROJECT_DIR, CONVERSATIONS_DIR
+ from src.config.loader import load_config
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class QueryContext:
+     """Context extracted from conversation"""
+     has_district: bool = False
+     has_source: bool = False
+     has_year: bool = False
+     extracted_district: Optional[Union[str, List[str]]] = None
+     extracted_source: Optional[Union[str, List[str]]] = None
+     extracted_year: Optional[Union[str, List[str]]] = None
+     ui_filters: Optional[Dict[str, List[str]]] = None
+     confidence_score: float = 0.0
+     needs_follow_up: bool = False
+     follow_up_question: Optional[str] = None
+
+     def __post_init__(self):
+         self._process_multiple("extracted_source")
+         self._process_multiple("extracted_district")
+
+     def _process_multiple(self, key):
+         """Title-case a field that may hold a string, a list of strings, or None."""
+         value = getattr(self, key)
+         if isinstance(value, list):
+             setattr(self, key, [d.title() for d in value])
+         else:
+             setattr(self, key, value.title() if value else None)
+
+
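+ # Example (derived from __post_init__ above): QueryContext(extracted_district=["kampala", "gulu"])
+ # is title-cased to ["Kampala", "Gulu"]; a scalar "kampala" becomes "Kampala"; None stays None.
+ # extracted_year is deliberately left untouched.
+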
+ class MultiAgentState(TypedDict):
+     """State for the multi-agent conversation flow"""
+     conversation_id: str
+     messages: List[Any]
+     current_query: str
+     query_context: Optional[QueryContext]
+     rag_query: Optional[str]
+     rag_filters: Optional[Dict[str, Any]]
+     retrieved_documents: Optional[List[Any]]
+     final_response: Optional[str]
+     agent_logs: List[str]
+     conversation_context: Dict[str, Any]
+     session_start_time: float
+     last_ai_message_time: float
+
+
+ class BaseMultiAgentChatbot(ABC):
+     """
+     Abstract base class for multi-agent chatbots.
+
+     Provides all the sophisticated logic from MultiAgentRAGChatbot:
+     - LLM-based query analysis
+     - Filter extraction and validation
+     - Query rewriting
+     - Main agent, RAG agent, Response agent
+
+     Subclasses only need to implement:
+     - _perform_retrieval(): The actual retrieval mechanism
+     """
+
+     def __init__(self, config_path: str = "src/config/settings.yaml"):
+         """Initialize the base multi-agent chatbot"""
+         self.config = load_config(config_path)
+
+         # Get LLM provider from config
+         reader_config = self.config.get("reader", {})
+         default_type = reader_config.get("default_type", "INF_PROVIDERS")
+         provider_name = default_type.lower()
+
+         self.llm_adapter = get_llm_client(provider_name, self.config)
+
+         # Create LangChain-compatible wrapper
+         class LLMWrapper:
+             def __init__(self, adapter):
+                 self.adapter = adapter
+
+             def invoke(self, messages):
+                 if isinstance(messages, list):
+                     formatted_messages = []
+                     for msg in messages:
+                         if hasattr(msg, 'content'):
+                             role = "user" if msg.__class__.__name__ == "HumanMessage" else "assistant"
+                             formatted_messages.append({"role": role, "content": msg.content})
+                         else:
+                             formatted_messages.append({"role": "user", "content": str(msg)})
+                 else:
+                     formatted_messages = [{"role": "user", "content": str(messages)}]
+
+                 response = self.adapter.generate(formatted_messages)
+
+                 # Minimal object exposing .content, mirroring a LangChain message
+                 class MockResponse:
+                     def __init__(self, content):
+                         self.content = content
+
+                 return MockResponse(response.content)
+
+         self.llm = LLMWrapper(self.llm_adapter)
+
+         # Load dynamic data (filter options)
+         self._load_dynamic_data()
+
+         # Build the multi-agent graph
+         self.graph = self._build_graph()
+
+         # Conversations directory
+         self.conversations_dir = CONVERSATIONS_DIR
+         try:
+             self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
+         except (PermissionError, OSError) as e:
+             logger.warning(f"Could not create conversations directory at {self.conversations_dir}: {e}")
+             self.conversations_dir = Path("conversations")
+             try:
+                 self.conversations_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
+             except (PermissionError, OSError) as e2:
+                 logger.error(f"Could not create conversations directory at {self.conversations_dir}: {e2}")
+                 raise RuntimeError(f"Failed to create conversations directory: {e2}")
+
+         logger.info("🤖 Base Multi-Agent Chatbot initialized")
+
+     def _load_dynamic_data(self):
+         """Load dynamic data from filter_options.json"""
+         try:
+             fo = PROJECT_DIR / "src" / "config" / "filter_options.json"
+             if fo.exists():
+                 with open(fo) as f:
+                     data = json.load(f)
+                 self.year_whitelist = [str(y).strip() for y in data.get("years", [])]
+                 self.source_whitelist = [str(s).strip() for s in data.get("sources", [])]
+                 self.district_whitelist = [str(d).strip() for d in data.get("districts", [])]
+             else:
+                 self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
+                 self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
+                 self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
+         except Exception as e:
+             logger.warning(f"Could not load filter options: {e}")
+             self.year_whitelist = ['2018', '2019', '2020', '2021', '2022', '2023', '2024']
+             self.source_whitelist = ['Consolidated', 'Local Government', 'Ministry, Department and Agency']
+             self.district_whitelist = ['Kampala', 'Gulu', 'Kalangala']
+
+         # Enrich district list
+         try:
+             from add_district_metadata import DistrictMetadataProcessor
+             proc = DistrictMetadataProcessor()
+             names = set()
+             for key, mapping in proc.district_mappings.items():
+                 if getattr(mapping, 'is_district', True):
+                     names.add(mapping.name)
+             if names:
+                 merged = list(self.district_whitelist)
+                 for n in sorted(names):
+                     if n not in merged:
+                         merged.append(n)
+                 self.district_whitelist = merged
+                 logger.info(f"🧭 District whitelist enriched: {len(self.district_whitelist)} entries")
+         except Exception as e:
+             logger.info(f"ℹ️ Could not enrich districts: {e}")
+
+         # Calculate current year dynamically
+         self.current_year = str(datetime.now().year)
+         self.previous_year = str(datetime.now().year - 1)
+
+         logger.info("📊 ACTUAL FILTER VALUES:")
+         logger.info(f"  Years: {self.year_whitelist}")
+         logger.info(f"  Sources: {self.source_whitelist}")
+         logger.info(f"  Districts: {len(self.district_whitelist)} districts (first 30: {self.district_whitelist[:30]})")
+
+     def _normalize_district_name(self, district: str) -> Optional[str]:
+         """Normalize district name with fuzzy matching - ALWAYS returns title case for Qdrant compatibility"""
+         if not district:
+             return None
+
+         district = district.strip()
+         district_title = district.title()
+
+         # Check if district exists in whitelist (case-insensitive)
+         district_lower = district.lower()
+         whitelist_lower = {d.lower(): d for d in self.district_whitelist}
+
+         # Direct match (case-insensitive) - always return title case
+         if district_lower in whitelist_lower:
+             return district_title  # Return title case, not whitelist value
+
+         # Remove "District" suffix and try again
+         district_name = district.replace(" District", "").replace(" district", "").strip()
+         district_name_lower = district_name.lower()
+         district_name_title = district_name.title()
+
+         if district_name_lower in whitelist_lower:
+             return district_name_title  # Return title case
+
+         # Common misspellings and abbreviations - return correct case
+         misspelling_map = {
+             "kalagala": "Kalangala",
+             "kalangala": "Kalangala",
+             "gulu": "Gulu",
+             "kampala": "Kampala",
+             "padr": "Pader",
+             "padre": "Pader",
+             "pader": "Pader",
+             "kcc": "Kcca",   # Match whitelist format
+             "kcca": "Kcca",  # Match whitelist format
+             "kimboga": "Kiboga",
+             "kiboga": "Kiboga",
+             "jinja": "Jinja",
+             "mbale": "Mbale",
+             "mbarara": "Mbarara",
+             "soroti": "Soroti",
+             "lira": "Lira",
+             "arua": "Arua",
+             "masaka": "Masaka",
+             "fort portal": "Fort Portal",
+             "fortportal": "Fort Portal",
+         }
+
+         if district_name_lower in misspelling_map:
+             return misspelling_map[district_name_lower]  # Already title case
+
+         # Fuzzy matching (case-insensitive) - return title case
+         for whitelist_district in self.district_whitelist:
+             if district_name_lower == whitelist_district.lower():
+                 return district_name_title  # Return title case
+
+             if len(district_name) >= 4 and len(whitelist_district) >= 4:
+                 if district_name_lower in whitelist_district.lower() or whitelist_district.lower() in district_name_lower:
+                     min_len = min(len(district_name), len(whitelist_district))
+                     max_len = max(len(district_name), len(whitelist_district))
+                     if min_len / max_len >= 0.8:
+                         return district_name_title  # Return title case
+
+         # Last resort: if input looks valid, return title case anyway
+         # This handles cases where whitelist might be incomplete
+         if len(district_name) >= 3:
+             return district_name_title
+
+         return None
+
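+     # Example behaviour (derived from the method above): "kalagala district" is
+     # stripped of its "District" suffix, then corrected via the misspelling map
+     # to "Kalangala"; "KCC" maps to "Kcca"; an unknown-but-plausible name such
+     # as "Newtown" falls through to the last resort and comes back as "Newtown".
+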
+     def _build_graph(self) -> StateGraph:
+         """Build the multi-agent LangGraph"""
+         graph = StateGraph(MultiAgentState)
+
+         # Add nodes for each agent
+         graph.add_node("main_agent", self._main_agent)
+         graph.add_node("rag_agent", self._rag_agent)
+         graph.add_node("response_agent", self._response_agent)
+
+         # Define the flow
+         graph.set_entry_point("main_agent")
+
+         # Main agent decides next step
+         graph.add_conditional_edges(
+             "main_agent",
+             self._should_call_rag,
+             {
+                 "follow_up": END,
+                 "call_rag": "rag_agent"
+             }
+         )
+
+         # RAG agent calls response agent
+         graph.add_edge("rag_agent", "response_agent")
+
+         # Response agent returns to main agent
+         graph.add_edge("response_agent", "main_agent")
+
+         return graph.compile()
+
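+     # Compiled flow, as wired above:
+     #
+     #     main_agent --follow_up--> END
+     #         | call_rag
+     #         v
+     #     rag_agent --> response_agent --> main_agent
+     #
+     # On the second visit to main_agent, final_response is already set, so
+     # _should_call_rag routes to "follow_up" and the graph ends.
+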
+     def _should_call_rag(self, state: MultiAgentState) -> str:
+         """Determine if we should call RAG or ask follow-up"""
+         if state.get("final_response"):
+             return "follow_up"
+
+         context = state["query_context"]
+         if context and context.needs_follow_up:
+             return "follow_up"
+         return "call_rag"
+
+     def _main_agent(self, state: MultiAgentState) -> MultiAgentState:
+         """Main Agent: Handles conversation flow and follow-ups"""
+         logger.info("🎯 MAIN AGENT: Starting analysis")
+
+         if state.get("final_response"):
+             logger.info("🎯 MAIN AGENT: Final response already exists, ending")
+             return state
+
+         query = state["current_query"]
+         messages = state["messages"]
+
+         logger.info("🎯 MAIN AGENT: Extracting UI filters from query")
+         ui_filters = self._extract_ui_filters(query)
+         logger.info(f"🎯 MAIN AGENT: UI filters extracted: {ui_filters}")
+
+         # Analyze query context using LLM
+         logger.info("🎯 MAIN AGENT: Analyzing query context")
+         context = self._analyze_query_context(query, messages, ui_filters)
+
+         state["agent_logs"].append(f"MAIN AGENT: Context analyzed - district={context.has_district}, source={context.has_source}, year={context.has_year}")
+         logger.info("🎯 MAIN AGENT: Context analysis complete")
+
+         state["query_context"] = context
+
+         # If follow-up needed, generate response
+         if context.needs_follow_up:
+             logger.info("🎯 MAIN AGENT: Follow-up needed, generating question")
+             response = context.follow_up_question
+             state["final_response"] = response
+             state["last_ai_message_time"] = time.time()
+         else:
+             logger.info("🎯 MAIN AGENT: No follow-up needed, proceeding to RAG")
+
+         return state
+
+     def _rag_agent(self, state: MultiAgentState) -> MultiAgentState:
+         """RAG Agent: Rewrites queries and applies filters"""
+         logger.info("🔍 RAG AGENT: Starting query rewriting and filter preparation")
+
+         context = state["query_context"]
+         messages = state["messages"]
+
+         # Rewrite query for RAG
+         logger.info("🔍 RAG AGENT: Rewriting query for optimal retrieval")
+         rag_query = self._rewrite_query_for_rag(messages, context)
+         logger.info(f"🔍 RAG AGENT: Query rewritten: '{rag_query}'")
+
+         # Build filters
+         logger.info(f"🔍 RAG AGENT: Building filters from context: {context}")
+         filters = self._build_filters(context)
+         logger.info(f"🔍 RAG AGENT: Filters built: {filters}")
+
+         state["agent_logs"].append(f"RAG AGENT: Query='{rag_query}', Filters={filters}")
+
+         state["rag_query"] = rag_query
+         state["rag_filters"] = filters
+
+         return state
+
+     def _response_agent(self, state: MultiAgentState) -> MultiAgentState:
+         """Response Agent: Generates final answer from retrieved documents"""
+         logger.info("📝 RESPONSE AGENT: Starting document retrieval and answer generation")
+
+         rag_query = state["rag_query"]
+         filters = state["rag_filters"]
+
+         logger.info(f"📝 RESPONSE AGENT: Calling retrieval with query: '{rag_query}'")
+         logger.info(f"📝 RESPONSE AGENT: Using filters: {filters}")
+
+         try:
+             # Call subclass-specific retrieval method
+             result = self._perform_retrieval(rag_query, filters)
+
+             state["retrieved_documents"] = result.sources
+             state["agent_logs"].append(f"RESPONSE AGENT: Retrieved {len(result.sources)} documents")
+
+             logger.info(f"📝 RESPONSE AGENT: Retrieved {len(result.sources)} documents")
+
+             # Check highest similarity score: prefer the reranked score, fall back
+             # to the original score, then to a bare .score attribute
+             highest_score = 0.0
+             if result.sources:
+                 for doc in result.sources:
+                     if hasattr(doc, 'metadata'):
+                         metadata = getattr(doc, 'metadata', {}) or {}
+                         score = metadata.get('reranked_score') or metadata.get('original_score', 0.0)
+                     else:
+                         score = getattr(doc, 'score', 0.0)
+                     if score > highest_score:
+                         highest_score = score
+
+             logger.info(f"📝 RESPONSE AGENT: Highest similarity score: {highest_score:.4f}")
+
+             # If highest score is too low, use LLM knowledge only
+             if highest_score <= 0.15:
+                 logger.warning("⚠️ RESPONSE AGENT: Low similarity score, using LLM knowledge only")
+                 response = self._generate_conversational_response_without_docs(
+                     state["current_query"],
+                     state["messages"]
+                 )
+             else:
+                 # Generate conversational response with documents
+                 response = self._generate_conversational_response(
+                     state["current_query"],
+                     result.sources,
+                     result.answer,
+                     state["messages"],
+                     filters  # Pass filters for coverage validation
+                 )
+
+             state["final_response"] = response
+             state["last_ai_message_time"] = time.time()
+
+             logger.info("📝 RESPONSE AGENT: Answer generation complete")
+
+         except Exception as e:
+             logger.error(f"❌ RESPONSE AGENT ERROR: {e}")
+             traceback.print_exc()
+             state["final_response"] = "I apologize, but I encountered an error while retrieving information. Please try again."
+             state["last_ai_message_time"] = time.time()
+
+         return state
+
+     @abstractmethod
+     def _perform_retrieval(self, query: str, filters: Dict[str, Any]) -> Any:
+         """
+         Perform retrieval - must be implemented by subclasses.
+
+         Args:
+             query: The rewritten query
+             filters: The filters to apply
+
+         Returns:
+             Result object with .sources and .answer attributes
+         """
+         pass
+
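+     # A minimal subclass sketch (illustrative only -- `MyRetriever` and its
+     # `search()` API are hypothetical placeholders, not part of this codebase):
+     #
+     #     from types import SimpleNamespace
+     #
+     #     class TextRAGChatbot(BaseMultiAgentChatbot):
+     #         def _perform_retrieval(self, query, filters):
+     #             docs = MyRetriever().search(query, filters=filters)
+     #             return SimpleNamespace(sources=docs, answer="")
+     #
+     # Any object exposing `.sources` (a list of documents) and `.answer` (a
+     # string) satisfies the contract documented above.
+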
+     def _extract_ui_filters(self, query: str) -> Dict[str, List[str]]:
+         """Extract UI filters from query"""
+         filters = {}
+
+         if "FILTER CONTEXT:" in query:
+             filter_section = query.split("FILTER CONTEXT:")[1]
+             if "USER QUERY:" in filter_section:
+                 filter_section = filter_section.split("USER QUERY:")[0]
+             filter_section = filter_section.strip()
+
+             # Each label introduces a comma-separated list on its own line;
+             # the literal value "None" means no filter was set
+             labels = {
+                 "Sources:": "sources",
+                 "Years:": "years",
+                 "Districts:": "districts",
+                 "Filenames:": "filenames",
+             }
+             for label, key in labels.items():
+                 line = next(
+                     (ln for ln in filter_section.split('\n') if ln.strip().startswith(label)),
+                     None
+                 )
+                 if line is None:
+                     continue
+                 values_str = line.split(label, 1)[1].strip()
+                 if values_str and values_str != "None":
+                     filters[key] = [v.strip() for v in values_str.split(",")]
+
+         return filters
+
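+     # Example: a query the UI wrapped as
+     #     "FILTER CONTEXT:\nDistricts: Kampala, Gulu\nYears: 2023\nUSER QUERY: ..."
+     # yields {"districts": ["Kampala", "Gulu"], "years": ["2023"]}.
+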
+     def _analyze_query_context(self, query: str, messages: List[Any], ui_filters: Dict[str, List[str]]) -> QueryContext:
+         """Analyze query context using LLM - EXACT COPY FROM v1"""
+         logger.info(f"🔍 QUERY ANALYSIS: '{query[:50]}...' | UI filters: {ui_filters}")
+
+         # Build conversation context
+         conversation_context = ""
+         for msg in messages[-6:]:
+             if isinstance(msg, HumanMessage):
+                 conversation_context += f"User: {msg.content}\n"
+             elif isinstance(msg, AIMessage):
+                 conversation_context += f"Assistant: {msg.content}\n"
+
+         # Create analysis prompt - ENHANCED FOR BETTER EXTRACTION
+         analysis_prompt = ChatPromptTemplate.from_messages([
+             SystemMessage(content=f"""You are the Main Agent in an advanced multi-agent RAG system for audit report analysis.
+
+ 🎯 PRIMARY GOAL: Intelligently analyze user queries and determine the optimal conversation flow, whether that's answering directly, asking follow-ups, or proceeding to RAG retrieval.
+
+ 🧠 INTELLIGENCE LEVEL: You are a sophisticated conversational AI that can handle any type of user interaction - from greetings to complex audit queries.
+
+ 📊 YOUR EXPERTISE: You specialize in analyzing audit reports from various sources (Local Government, Ministry, Hospital, etc.) across different years and districts in Uganda.
+
+ 🔍 AVAILABLE FILTERS:
+ - Years: {', '.join(self.year_whitelist)}
+ - Current year: {self.current_year}, Previous year: {self.previous_year}
+ - Sources: {', '.join(self.source_whitelist)}
+ - Districts: {', '.join(self.district_whitelist[:50])}... (and {len(self.district_whitelist)-50} more)
+
+ 🎛️ UI FILTERS PROVIDED: {ui_filters}
+
+ 📋 UI FILTER HANDLING:
+ - If UI filters contain multiple values, extract ALL values
+ - UI filters take PRIORITY over conversation context
+
+ ⚠️ CRITICAL EXTRACTION RULES:
+
+ 1. **RELATIVE YEAR REFERENCES** - Convert to explicit years:
+    - "last couple of years" / "last 2 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}] (2 years)
+    - "last few years" / "last 3 years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}] (3 years)
+    - "recent years" → [{self.previous_year}, {str(int(self.previous_year)-1)}, {str(int(self.previous_year)-2)}]
+    - "this year" → ["{self.current_year}"]
+    - "last year" → ["{self.previous_year}"]
+
+ 2. **DISTRICT TYPOS & ABBREVIATIONS** - Correct common mistakes:
+    - "KCC" or "KCCA" → "KCCA" (Kampala Capital City Authority)
+    - "Padr" or "Padre" → "Pader"
+    - "Kimboga" → "Kiboga"
+    - "Kalagala" → "Kalangala"
+
+ 3. **MULTIPLE VALUES** - Extract ALL mentioned items:
+    - If user says "Kampala, Kiboga, and Pader" → extract ALL THREE districts
+    - If user says "2022, 2023, 2024" → extract ALL THREE years
+    - Use "+" or "and" or "," as separators
+
+ 🧭 CONVERSATION FLOW INTELLIGENCE:
+
+ 1. **GREETINGS & GENERAL CHAT**:
+    - If user greets you, respond warmly and guide them
+
+ 2. **AUDIT QUERIES**:
+    - Extract values matching the available lists (with typo correction)
+    - DO NOT hallucinate values not mentioned by user
+
+ 3. **SMART FOLLOW-UP STRATEGY**:
+    - If user provides 2+ pieces of info, proceed to RAG
+    - If user provides 1 piece of info, ask for missing piece
+    - If user provides 0 pieces of info, ask for clarification
+    - NEVER ask the same question twice
+
+ 🎯 DECISION LOGIC:
+ - If query is a greeting/general chat → needs_follow_up: true
+ - If query has 2+ pieces of info → needs_follow_up: false, proceed to RAG
+ - If query has 1 piece of info → needs_follow_up: true, ask for missing piece
+ - If query has 0 pieces of info → needs_follow_up: true, ask for clarification
+
+ RESPOND WITH JSON ONLY:
+ {{
+     "has_district": boolean,
+     "has_source": boolean,
+     "has_year": boolean,
+     "extracted_district": "single or array or null",
+     "extracted_source": "single or array or null",
+     "extracted_year": "single or array or null",
+     "confidence_score": 0.0-1.0,
+     "needs_follow_up": boolean,
+     "follow_up_question": "question or null"
+ }}"""),
+             HumanMessage(content=f"""Query: {query}
+
+ Conversation Context:
+ {conversation_context}
+
+ CRITICAL: Analyze the FULL conversation context above.
+ Analyze this query using ONLY the exact values provided above:""")
+         ])
+
+         try:
+             response = self.llm.invoke(analysis_prompt.format_messages())
+
+             # Clean and parse JSON
+             content = response.content.strip()
+             if content.startswith("```json"):
+                 content = content.replace("```json", "").replace("```", "").strip()
+             elif content.startswith("```"):
+                 content = content.replace("```", "").strip()
+
+             # Remove comments
+             content = re.sub(r'//.*?$', '', content, flags=re.MULTILINE)
+             content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
+
+             analysis = json.loads(content)
+             logger.info("🔍 QUERY ANALYSIS: ✅ Parsed successfully")
+
+             # Validate extracted values (same logic as v1)
+             extracted_district = analysis.get("extracted_district")
+             extracted_source = analysis.get("extracted_source")
+             extracted_year = analysis.get("extracted_year")
+
+             # Validate district
+             if extracted_district:
+                 if isinstance(extracted_district, list):
+                     valid_districts = []
+                     for district in extracted_district:
+                         normalized = self._normalize_district_name(district)
+                         if normalized:
+                             valid_districts.append(normalized)
+                     extracted_district = valid_districts[0] if len(valid_districts) == 1 else (valid_districts if valid_districts else None)
+                 else:
+                     extracted_district = self._normalize_district_name(extracted_district)
+
+             # Validate source
+             if extracted_source:
+                 if isinstance(extracted_source, list):
+                     valid_sources = [s for s in extracted_source if s in self.source_whitelist]
+                     extracted_source = valid_sources[0] if len(valid_sources) == 1 else (valid_sources if valid_sources else None)
+                 else:
+                     extracted_source = extracted_source if extracted_source in self.source_whitelist else None
+
+             # Validate year
+             if extracted_year:
+                 if isinstance(extracted_year, list):
+                     valid_years = [str(y) for y in extracted_year if str(y) in self.year_whitelist]
+                     extracted_year = valid_years[0] if len(valid_years) == 1 else (valid_years if valid_years else None)
+                 else:
+                     extracted_year = str(extracted_year) if str(extracted_year) in self.year_whitelist else None
+
+             # Create QueryContext
+             context = QueryContext(
+                 has_district=bool(extracted_district),
+                 has_source=bool(extracted_source),
+                 has_year=bool(extracted_year),
+                 extracted_district=extracted_district,
+                 extracted_source=extracted_source,
+                 extracted_year=extracted_year,
+                 ui_filters=ui_filters,
+                 confidence_score=analysis.get("confidence_score", 0.0),
+                 needs_follow_up=analysis.get("needs_follow_up", False),
+                 follow_up_question=analysis.get("follow_up_question")
+             )
+
+             # If filenames provided, skip follow-ups
+             if ui_filters and ui_filters.get("filenames"):
+                 context.needs_follow_up = False
+                 context.follow_up_question = None
+
+             # Smart decision logic (same as v1): two or more extracted values is
+             # enough information to proceed to RAG instead of asking a follow-up
+             if context.needs_follow_up:
+                 info_count = sum([bool(context.extracted_district), bool(context.extracted_source), bool(context.extracted_year)])
+                 if info_count >= 2:
+                     context.needs_follow_up = False
+                     context.follow_up_question = None
+
+             return context
+
+         except Exception as e:
+             logger.error(f"❌ Query analysis failed: {e}")
+             return QueryContext(
+                 has_district=bool(ui_filters.get("districts")),
+                 has_source=bool(ui_filters.get("sources")),
+                 has_year=bool(ui_filters.get("years")),
+                 ui_filters=ui_filters,
+                 confidence_score=0.5,
+                 needs_follow_up=False
+             )
+
+     def _rewrite_query_for_rag(self, messages: List[Any], context: QueryContext) -> str:
+         """Rewrite query for optimal RAG retrieval - EXACT COPY FROM v1"""
+         logger.info("🔄 QUERY REWRITING: Starting")
+
+         # Build conversation context
+         conversation_lines = []
+         for msg in messages[-6:]:
+             if isinstance(msg, HumanMessage):
+                 conversation_lines.append(f"User: {msg.content}")
+             elif isinstance(msg, AIMessage):
+                 conversation_lines.append(f"Assistant: {msg.content}")
+
+         convo_text = "\n".join(conversation_lines)
+
+         # Create rewrite prompt
+         rewrite_prompt = ChatPromptTemplate.from_messages([
+             SystemMessage(content="""You are a query rewriter for RAG retrieval.
+
+ GOAL: Create the best possible search query for document retrieval.
+
+ CRITICAL RULES:
+ 1. Focus on the core information need
+ 2. Remove meta-verbs like "summarize", "list", "compare", "how much", "what"
+ 3. DO NOT include filter details (years, districts, sources)
+ 4. Output ONE clear sentence suitable for vector search
+
+ EXAMPLES:
+ - "What are the top challenges in budget allocation?" → "budget allocation challenges"
+ - "How were PDM administrative costs utilized?" → "PDM administrative costs utilization"
+
+ OUTPUT FORMAT:
+ EXPLANATION: [reasoning]
+ QUERY: [one clean sentence]"""),
+             HumanMessage(content=f"""Conversation:
+ {convo_text}
+
+ Rewrite the best retrieval query:""")
+         ])
+
+         try:
+             response = self.llm.invoke(rewrite_prompt.format_messages())
+             rewritten = response.content.strip()
+
+             # Extract QUERY line
+             lines = rewritten.split('\n')
+             for line in lines:
+                 if line.strip().startswith('QUERY:'):
+                     query_line = line.replace('QUERY:', '').strip()
+                     if len(query_line) > 5:
+                         return query_line
+
+             # Fallback
+             for msg in reversed(messages):
+                 if isinstance(msg, HumanMessage):
+                     return msg.content
+             return "audit report information"
+
+         except Exception as e:
+             logger.error(f"❌ QUERY REWRITING: Error: {e}")
+             for msg in reversed(messages):
+                 if isinstance(msg, HumanMessage):
+                     return msg.content
+             return "audit report information"
+
+     def _build_filters(self, context: QueryContext) -> Dict[str, Any]:
+         """Build filters for RAG retrieval"""
+         logger.info(f"🔧 FILTER BUILDING: Building filters from context: {context}")
+         filters = {}
+
+         # Check for filename filtering first
+         if context.ui_filters and context.ui_filters.get("filenames"):
+             filters["filenames"] = context.ui_filters["filenames"]
+             logger.info(f"🔧 FILTER BUILDING: Using filename filter: {filters}")
+             return filters
+
+         # UI filters take priority
+         if context.ui_filters:
+             if context.ui_filters.get("sources"):
+                 filters["sources"] = context.ui_filters["sources"]
+             if context.ui_filters.get("years"):
+                 filters["year"] = context.ui_filters["years"]
+             if context.ui_filters.get("districts"):
+                 # Title case for Qdrant compatibility
+                 normalized_districts = [d.title() for d in context.ui_filters['districts']]
+                 filters["district"] = normalized_districts
+
+             # Merge with extracted context
+             if not filters.get("district") and context.extracted_district:
+                 if isinstance(context.extracted_district, list):
+                     # Normalize each district - _normalize_district_name returns correct case
+                     normalized = [self._normalize_district_name(d) for d in context.extracted_district]
+                     filters["district"] = [d for d in normalized if d]
+                 else:
+                     normalized = self._normalize_district_name(context.extracted_district)
+                     if normalized:
+                         filters["district"] = [normalized]
+
+             if not filters.get("year") and context.extracted_year:
+                 filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year
+
+             if not filters.get("sources") and context.extracted_source:
+                 filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source
+         else:
+             # Use extracted context (no UI filters)
+             if context.extracted_source:
+                 filters["sources"] = [context.extracted_source] if not isinstance(context.extracted_source, list) else context.extracted_source
+             if context.extracted_year:
+                 filters["year"] = [context.extracted_year] if not isinstance(context.extracted_year, list) else context.extracted_year
+             if context.extracted_district:
+                 if isinstance(context.extracted_district, list):
+                     # Normalize each district - _normalize_district_name returns correct case
+                     normalized = [self._normalize_district_name(d) for d in context.extracted_district]
+                     filters["district"] = [d for d in normalized if d]
+                 else:
+                     normalized = self._normalize_district_name(context.extracted_district)
+                     if normalized:
+                         filters["district"] = [normalized]
+
+         return filters
+
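+     # Resulting filter shape (derived from the logic above), e.g. for a query
+     # about Kampala and Gulu in 2023 from Local Government reports:
+     #     {"district": ["Kampala", "Gulu"], "year": ["2023"], "sources": ["Local Government"]}
+     # When UI filenames are present, the dict is just {"filenames": [...]}.
+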
+     @abstractmethod
+     def _generate_conversational_response(self, query: str, documents: List[Any], rag_answer: str, messages: List[Any], filters: Optional[Dict[str, Any]] = None) -> str:
+         """Generate conversational response - must be implemented by subclasses"""
+         pass
+
+     @abstractmethod
+     def _generate_conversational_response_without_docs(self, query: str, messages: List[Any]) -> str:
+         """Generate response without documents - must be implemented by subclasses"""
+         pass
+
+     def chat(self, user_input: str, conversation_id: str = "default") -> Dict[str, Any]:
+         """Main chat interface"""
+         logger.info(f"💬 MULTI-AGENT CHAT: Processing '{user_input[:50]}...'")
+
+         # Load conversation
+         conversation_file = self.conversations_dir / f"{conversation_id}.json"
+         conversation = self._load_conversation(conversation_file)
+
+         # Add user message
+         conversation["messages"].append(HumanMessage(content=user_input))
+
+         # Prepare state
+         state = MultiAgentState(
+             conversation_id=conversation_id,
+             messages=conversation["messages"],
+             current_query=user_input,
+             query_context=None,
+             rag_query=None,
+             rag_filters=None,
+             retrieved_documents=None,
+             final_response=None,
+             agent_logs=[],
+             conversation_context=conversation.get("context", {}),
+             session_start_time=conversation["session_start_time"],
+             last_ai_message_time=conversation["last_ai_message_time"]
+         )
+
+         # Run multi-agent graph
+         final_state = self.graph.invoke(state)
+
+         # Add AI response to conversation
+         if final_state["final_response"]:
+             conversation["messages"].append(AIMessage(content=final_state["final_response"]))
+
+         # Update conversation
+         conversation["last_ai_message_time"] = final_state["last_ai_message_time"]
+         conversation["context"] = final_state["conversation_context"]
+
+         # Save conversation
+         self._save_conversation(conversation_file, conversation)
+
+         # Return response
+         return {
+             'response': final_state["final_response"],
+             'rag_result': {
+                 'sources': final_state["retrieved_documents"] or [],
+                 'answer': final_state["final_response"]
+             },
+             'agent_logs': final_state["agent_logs"],
+             'actual_rag_query': final_state.get("rag_query", "")
+         }
+
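+     # Typical call pattern for a concrete subclass (illustrative; see the
+     # hypothetical TextRAGChatbot sketch near _perform_retrieval above):
+     #
+     #     bot = TextRAGChatbot()
+     #     result = bot.chat("How were PDM funds used in Kampala in 2023?", "session-1")
+     #     print(result["response"])           # final answer or follow-up question
+     #     print(result["actual_rag_query"])   # the rewritten retrieval query
+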
+     def _load_conversation(self, conversation_file: Path) -> Dict[str, Any]:
+         """Load conversation from file"""
+         if conversation_file.exists():
+             try:
+                 with open(conversation_file) as f:
+                     data = json.load(f)
+                 messages = []
+                 for msg_data in data.get("messages", []):
+                     if msg_data["type"] == "human":
+                         messages.append(HumanMessage(content=msg_data["content"]))
+                     elif msg_data["type"] == "ai":
+                         messages.append(AIMessage(content=msg_data["content"]))
+                 data["messages"] = messages
+                 return data
+             except Exception as e:
+                 logger.warning(f"Could not load conversation: {e}")
+
+         return {
+             "messages": [],
+             "session_start_time": time.time(),
+             "last_ai_message_time": time.time(),
+             "context": {}
+         }
+
+     def _save_conversation(self, conversation_file: Path, conversation: Dict[str, Any]):
+         """Save conversation to file"""
+         try:
+             conversation_file.parent.mkdir(parents=True, mode=0o777, exist_ok=True)
+
+             messages_data = []
+             for msg in conversation["messages"]:
+                 if isinstance(msg, HumanMessage):
+                     messages_data.append({"type": "human", "content": msg.content})
+                 elif isinstance(msg, AIMessage):
+                     messages_data.append({"type": "ai", "content": msg.content})
+
+             conversation_data = {
+                 "messages": messages_data,
+                 "session_start_time": conversation["session_start_time"],
+                 "last_ai_message_time": conversation["last_ai_message_time"],
+                 "context": conversation.get("context", {})
+             }
+
+             with open(conversation_file, 'w') as f:
+                 json.dump(conversation_data, f, indent=2)
+
+         except Exception as e:
+             logger.error(f"Could not save conversation: {e}")