Rifqi Hafizuddin committed on
Commit
0707f2b
·
1 Parent(s): 9e16c22

[KM 436-439] adjust endpoint for new features

Browse files
src/agents/orchestration.py CHANGED
@@ -35,6 +35,11 @@ Intent Routing:
35
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
  - other -> needs_search=True, search_query=<standalone rewritten query>
 
 
 
 
 
38
  """),
39
  MessagesPlaceholder(variable_name="history"),
40
  ("user", "{message}")
 
35
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
  - other -> needs_search=True, search_query=<standalone rewritten query>
38
+
39
+ Source Routing (set source_hint):
40
+ - Columns, tables, sheets, data types, schema, row counts, statistics -> source_hint=schema
41
+ - Document content, paragraphs, reports, articles, text -> source_hint=document
42
+ - Unclear or spans both -> source_hint=both
43
  """),
44
  MessagesPlaceholder(variable_name="history"),
45
  ("user", "{message}")
src/api/v1/chat.py CHANGED
@@ -9,6 +9,9 @@ from src.db.postgres.models import ChatMessage, MessageSource
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
  from src.rag.retriever import retriever
 
 
 
12
  from src.db.redis.connection import get_redis
13
  from src.config.settings import settings
14
  from src.middlewares.logging import get_logger, log_execution
@@ -61,7 +64,7 @@ def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
61
  seen = set()
62
  sources = []
63
  for result in results:
64
- if "document_id" in result["metadata"]["data"]:
65
  meta = result["metadata"]
66
  key = (meta.get("data", {}).get("document_id"), meta.get("data", {}).get("page_label"))
67
  if key not in seen:
@@ -88,6 +91,22 @@ def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
88
  return sources
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  async def get_cached_response(redis, cache_key: str) -> Optional[str]:
92
  cached = await redis.get(cache_key)
93
  if cached:
@@ -182,16 +201,25 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
182
 
183
  if not intent_result.get("needs_search"):
184
  retrieval_task.cancel()
 
 
 
 
185
  raw_results = []
186
  else:
187
  search_query = intent_result.get("search_query", request.message)
188
  logger.info(f"Searching for: {search_query}")
189
  if search_query != request.message:
190
  retrieval_task.cancel()
 
 
 
 
191
  raw_results = await retriever.retrieve(
192
  query=search_query,
193
  user_id=request.user_id,
194
  db=db,
 
195
  )
196
  else:
197
  raw_results = await retrieval_task
@@ -199,6 +227,27 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
199
  context = _format_context(raw_results)
200
  sources = _extract_sources(raw_results)
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # Step 3: Direct response for greetings / non-document intents
203
  if intent_result.get("direct_response"):
204
  response = intent_result["direct_response"]
 
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
  from src.rag.retriever import retriever
12
+ from src.rag.base import RetrievalResult
13
+ from src.query.query_executor import query_executor
14
+ from src.query.base import QueryResult
15
  from src.db.redis.connection import get_redis
16
  from src.config.settings import settings
17
  from src.middlewares.logging import get_logger, log_execution
 
64
  seen = set()
65
  sources = []
66
  for result in results:
67
+ if "document_id" in result["metadata"].get("data", {}):
68
  meta = result["metadata"]
69
  key = (meta.get("data", {}).get("document_id"), meta.get("data", {}).get("page_label"))
70
  if key not in seen:
 
91
  return sources
92
 
93
 
94
+ def _format_query_results(results: list[QueryResult]) -> str:
95
+ if not results:
96
+ return ""
97
+ lines = []
98
+ for r in results:
99
+ name = r.metadata.get("client_name", r.source_id)
100
+ lines.append(f"[Query result — {name}, tables: {r.table_or_file}]")
101
+ lines.append(f"SQL: {r.metadata.get('sql', '')}")
102
+ if r.columns and r.rows:
103
+ lines.append(" | ".join(r.columns))
104
+ for row in r.rows[:20]:
105
+ lines.append(" | ".join(str(row.get(c, "")) for c in r.columns))
106
+ lines.append(f"({r.row_count} rows total)\n")
107
+ return "\n".join(lines)
108
+
109
+
110
  async def get_cached_response(redis, cache_key: str) -> Optional[str]:
111
  cached = await redis.get(cache_key)
112
  if cached:
 
201
 
202
  if not intent_result.get("needs_search"):
203
  retrieval_task.cancel()
204
+ try:
205
+ await retrieval_task
206
+ except asyncio.CancelledError:
207
+ pass
208
  raw_results = []
209
  else:
210
  search_query = intent_result.get("search_query", request.message)
211
  logger.info(f"Searching for: {search_query}")
212
  if search_query != request.message:
213
  retrieval_task.cancel()
214
+ try:
215
+ await retrieval_task
216
+ except asyncio.CancelledError:
217
+ pass
218
  raw_results = await retriever.retrieve(
219
  query=search_query,
220
  user_id=request.user_id,
221
  db=db,
222
+ source_hint=intent_result.get("source_hint", "both"),
223
  )
224
  else:
225
  raw_results = await retrieval_task
 
227
  context = _format_context(raw_results)
228
  sources = _extract_sources(raw_results)
229
 
230
+ source_hint = intent_result.get("source_hint", "both")
231
+ if source_hint in ("schema", "both"):
232
+ retrieval_objects = [
233
+ RetrievalResult(
234
+ content=r["content"],
235
+ metadata=r["metadata"],
236
+ score=0.0,
237
+ source_type=r["metadata"].get("source_type", ""),
238
+ )
239
+ for r in raw_results
240
+ ]
241
+ query_results = await query_executor.execute(
242
+ results=retrieval_objects,
243
+ user_id=request.user_id,
244
+ db=db,
245
+ question=intent_result.get("search_query") or request.message,
246
+ )
247
+ query_context = _format_query_results(query_results)
248
+ if query_context:
249
+ context = query_context + "\n\n" + context
250
+
251
  # Step 3: Direct response for greetings / non-document intents
252
  if intent_result.get("direct_response"):
253
  response = intent_result["direct_response"]
src/api/v1/db_client.py CHANGED
@@ -458,7 +458,7 @@ async def ingest_database_client(
458
  db_type=client.db_type,
459
  credentials=creds,
460
  ) as engine:
461
- total = await db_pipeline_service.run(user_id=user_id, engine=engine)
462
  except NotImplementedError as e:
463
  raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
464
  except Exception as e:
 
458
  db_type=client.db_type,
459
  credentials=creds,
460
  ) as engine:
461
+ total = await db_pipeline_service.run(user_id=user_id, client_id=client_id, engine=engine)
462
  except NotImplementedError as e:
463
  raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
464
  except Exception as e: