ishaq101 Claude Sonnet 4.6 committed on
Commit
7323952
·
1 Parent(s): 50c04b4

feat: persist RAG sources to DB and return them in room detail

Browse files

- Add MessageSource model and message_sources table
- save_messages now stores sources linked to assistant message
- GET /room/{room_id} returns sources per message (empty [] for user role)
- Fix page_label int->str cast to prevent asyncpg DataError
- Orchestrator now receives conversation history for context-aware query rewriting
- Chatbot includes full conversation history for coherent multi-turn responses

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

.gitignore CHANGED
@@ -32,4 +32,5 @@ playground_retriever.py
32
  playground_chat.py
33
  playground_flush_cache.py
34
  playground_create_user.py
35
- API_CONTRACT.md
 
 
32
  playground_chat.py
33
  playground_flush_cache.py
34
  playground_create_user.py
35
+ API_CONTRACT.md
36
+ context_engineering/
src/agents/orchestration.py CHANGED
@@ -1,10 +1,10 @@
1
  """Orchestrator agent for intent recognition and planning."""
2
 
3
  from langchain_openai import AzureChatOpenAI
4
- from langchain_core.prompts import ChatPromptTemplate
5
- from langchain_core.output_parsers import JsonOutputParser
6
  from src.config.settings import settings
7
  from src.middlewares.logging import get_logger
 
8
 
9
  logger = get_logger("orchestrator")
10
 
@@ -22,43 +22,43 @@ class OrchestratorAgent:
22
  )
23
 
24
  self.prompt = ChatPromptTemplate.from_messages([
25
- ("system", """You are an orchestrator agent. Analyze user's message and determine:
26
-
27
- 1. What is user's intent? (question, greeting, goodbye, other)
28
- 2. Do we need to search user's documents for relevant information?
29
- 3. If search is needed, what query should we use?
30
- 4. If no search needed, provide a direct response.
31
-
32
- Respond in JSON format with these fields:
33
- - intent: string (question, greeting, goodbye, other)
34
- - needs_search: boolean
35
- - search_query: string (if needed)
36
- - direct_response: string (if no search needed)
37
-
38
  Intent Routing:
39
- - question -> needs_search=True, search_query=message
40
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
41
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
42
- - other -> needs_search=True, search_query=message
43
  """),
 
44
  ("user", "{message}")
45
  ])
46
 
47
- self.chain = (
48
- self.prompt
49
- | self.llm
50
- | JsonOutputParser()
51
- )
52
 
53
- async def analyze_message(self, message: str) -> dict:
54
- """Analyze user message and determine next actions."""
 
 
 
55
  try:
56
  logger.info(f"Analyzing message: {message[:50]}...")
57
 
58
- result = await self.chain.ainvoke({"message": message})
 
59
 
60
- logger.info(f"Intent: {result.get('intent')}, Needs search: {result.get('needs_search')}")
61
- return result
62
 
63
  except Exception as e:
64
  logger.error("Message analysis failed", error=str(e))
 
1
  """Orchestrator agent for intent recognition and planning."""
2
 
3
  from langchain_openai import AzureChatOpenAI
4
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 
5
  from src.config.settings import settings
6
  from src.middlewares.logging import get_logger
7
+ from src.models.structured_output import IntentClassification
8
 
9
  logger = get_logger("orchestrator")
10
 
 
22
  )
23
 
24
  self.prompt = ChatPromptTemplate.from_messages([
25
+ ("system", """You are an orchestrator agent. You receive recent conversation history and the user's latest message.
26
+
27
+ Your task:
28
+ 1. Determine intent: question, greeting, goodbye, or other
29
+ 2. Decide whether to search the user's documents (needs_search)
30
+ 3. If search is needed, rewrite the user's message into a STANDALONE search query that incorporates necessary context from conversation history. If the user says "tell me more" or "how many papers?", the search_query must spell out the full topic explicitly from history.
31
+ 4. If no search needed, provide a short direct_response (plain text only, no markdown formatting).
32
+
 
 
 
 
 
33
  Intent Routing:
34
+ - question -> needs_search=True, search_query=<standalone rewritten query>
35
  - greeting -> needs_search=False, direct_response="Hello! How can I assist you today?"
36
  - goodbye -> needs_search=False, direct_response="Goodbye! Have a great day!"
37
+ - other -> needs_search=True, search_query=<standalone rewritten query>
38
  """),
39
+ MessagesPlaceholder(variable_name="history"),
40
  ("user", "{message}")
41
  ])
42
 
43
+ # with_structured_output uses function calling — guarantees valid schema regardless of LLM response style
44
+ self.chain = self.prompt | self.llm.with_structured_output(IntentClassification)
45
+
46
+ async def analyze_message(self, message: str, history: list = None) -> dict:
47
+ """Analyze user message and determine next actions.
48
 
49
+ Args:
50
+ message: The current user message.
51
+ history: Recent conversation as LangChain BaseMessage objects (oldest-first).
52
+ Used to rewrite ambiguous follow-ups into standalone search queries.
53
+ """
54
  try:
55
  logger.info(f"Analyzing message: {message[:50]}...")
56
 
57
+ history_messages = history or []
58
+ result: IntentClassification = await self.chain.ainvoke({"message": message, "history": history_messages})
59
 
60
+ logger.info(f"Intent: {result.intent}, Needs search: {result.needs_search}, Search query: {result.search_query[:50] if result.search_query else ''}")
61
+ return result.model_dump()
62
 
63
  except Exception as e:
64
  logger.error("Message analysis failed", error=str(e))
src/api/v1/chat.py CHANGED
@@ -5,7 +5,7 @@ import uuid
5
  from fastapi import APIRouter, Depends, HTTPException
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
  from src.db.postgres.connection import get_db
8
- from src.db.postgres.models import ChatMessage
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
  from src.rag.retriever import retriever
@@ -13,7 +13,8 @@ from src.db.redis.connection import get_redis
13
  from src.config.settings import settings
14
  from src.middlewares.logging import get_logger, log_execution
15
  from sse_starlette.sse import EventSourceResponse
16
- from langchain_core.messages import HumanMessage
 
17
  from pydantic import BaseModel
18
  from typing import List, Dict, Any, Optional
19
  import json
@@ -83,10 +84,41 @@ async def cache_response(redis, cache_key: str, response: str):
83
  await redis.setex(cache_key, 86400, json.dumps(response))
84
 
85
 
86
- async def save_messages(db: AsyncSession, room_id: str, user_content: str, assistant_content: str):
87
- """Persist user and assistant messages to chat_messages table."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
89
- db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="assistant", content=assistant_content))
 
 
 
 
 
 
 
 
 
 
90
  await db.commit()
91
 
92
 
@@ -102,7 +134,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
102
  """
103
  redis = await get_redis()
104
 
105
- cache_key = f"{settings.redis_prefix}chat:{request.user_id}:{request.message}"
106
  cached = await get_cached_response(redis, cache_key)
107
  if cached:
108
  logger.info("Returning cached response")
@@ -123,11 +155,15 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
123
  sources: List[Dict[str, Any]] = []
124
 
125
  if intent_result is None:
126
- # Step 2: Launch retrieval optimistically while orchestrator decides in parallel
127
  retrieval_task = asyncio.create_task(
128
  retriever.retrieve(request.message, request.user_id, db)
129
  )
130
- intent_result = await orchestrator.analyze_message(request.message)
 
 
 
 
131
 
132
  if not intent_result.get("needs_search"):
133
  retrieval_task.cancel()
@@ -152,7 +188,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
152
  if intent_result.get("direct_response"):
153
  response = intent_result["direct_response"]
154
  await cache_response(redis, cache_key, response)
155
- await save_messages(db, request.room_id, request.message, response)
156
 
157
  async def stream_direct():
158
  yield {"event": "sources", "data": json.dumps([])}
@@ -161,7 +197,9 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
161
  return EventSourceResponse(stream_direct())
162
 
163
  # Step 4: Stream answer token-by-token as LLM generates it
164
- messages = [HumanMessage(content=request.message)]
 
 
165
 
166
  async def stream_response():
167
  full_response = ""
@@ -171,7 +209,7 @@ async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
171
  yield {"event": "chunk", "data": token}
172
  yield {"event": "done", "data": ""}
173
  await cache_response(redis, cache_key, full_response)
174
- await save_messages(db, request.room_id, request.message, full_response)
175
 
176
  return EventSourceResponse(stream_response())
177
 
 
5
  from fastapi import APIRouter, Depends, HTTPException
6
  from sqlalchemy.ext.asyncio import AsyncSession
7
  from src.db.postgres.connection import get_db
8
+ from src.db.postgres.models import ChatMessage, MessageSource
9
  from src.agents.orchestration import orchestrator
10
  from src.agents.chatbot import chatbot
11
  from src.rag.retriever import retriever
 
13
  from src.config.settings import settings
14
  from src.middlewares.logging import get_logger, log_execution
15
  from sse_starlette.sse import EventSourceResponse
16
+ from langchain_core.messages import HumanMessage, AIMessage
17
+ from sqlalchemy import select
18
  from pydantic import BaseModel
19
  from typing import List, Dict, Any, Optional
20
  import json
 
84
  await redis.setex(cache_key, 86400, json.dumps(response))
85
 
86
 
87
+ async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
88
+ """Load recent chat messages for a room as LangChain message objects (oldest-first)."""
89
+ result = await db.execute(
90
+ select(ChatMessage)
91
+ .where(ChatMessage.room_id == room_id)
92
+ .order_by(ChatMessage.created_at.asc())
93
+ .limit(limit)
94
+ )
95
+ rows = result.scalars().all()
96
+ return [
97
+ HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
98
+ for row in rows
99
+ ]
100
+
101
+
102
+ async def save_messages(
103
+ db: AsyncSession,
104
+ room_id: str,
105
+ user_content: str,
106
+ assistant_content: str,
107
+ sources: Optional[List[Dict[str, Any]]] = None,
108
+ ):
109
+ """Persist user and assistant messages, and attach sources to the assistant message."""
110
  db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
111
+ assistant_id = str(uuid.uuid4())
112
+ db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant", content=assistant_content))
113
+ for src in (sources or []):
114
+ page = src.get("page_label")
115
+ db.add(MessageSource(
116
+ id=str(uuid.uuid4()),
117
+ message_id=assistant_id,
118
+ document_id=src.get("document_id"),
119
+ filename=src.get("filename"),
120
+ page_label=str(page) if page is not None else None,
121
+ ))
122
  await db.commit()
123
 
124
 
 
134
  """
135
  redis = await get_redis()
136
 
137
+ cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"
138
  cached = await get_cached_response(redis, cache_key)
139
  if cached:
140
  logger.info("Returning cached response")
 
155
  sources: List[Dict[str, Any]] = []
156
 
157
  if intent_result is None:
158
+ # Step 2: Launch retrieval and history loading in parallel, then run orchestrator
159
  retrieval_task = asyncio.create_task(
160
  retriever.retrieve(request.message, request.user_id, db)
161
  )
162
+ history_task = asyncio.create_task(
163
+ load_history(db, request.room_id, limit=6) # 6 msgs (3 pairs) for orchestrator
164
+ )
165
+ history = await history_task # fast DB query (<100ms), done before orchestrator finishes
166
+ intent_result = await orchestrator.analyze_message(request.message, history)
167
 
168
  if not intent_result.get("needs_search"):
169
  retrieval_task.cancel()
 
188
  if intent_result.get("direct_response"):
189
  response = intent_result["direct_response"]
190
  await cache_response(redis, cache_key, response)
191
+ await save_messages(db, request.room_id, request.message, response, sources=[])
192
 
193
  async def stream_direct():
194
  yield {"event": "sources", "data": json.dumps([])}
 
197
  return EventSourceResponse(stream_direct())
198
 
199
  # Step 4: Stream answer token-by-token as LLM generates it
200
+ # Load full history (10 msgs) for chatbot — richer context than the 6 used by orchestrator
201
+ full_history = await load_history(db, request.room_id, limit=10)
202
+ messages = full_history + [HumanMessage(content=request.message)]
203
 
204
  async def stream_response():
205
  full_response = ""
 
209
  yield {"event": "chunk", "data": token}
210
  yield {"event": "done", "data": ""}
211
  await cache_response(redis, cache_key, full_response)
212
+ await save_messages(db, request.room_id, request.message, full_response, sources=sources)
213
 
214
  return EventSourceResponse(stream_response())
215
 
src/api/v1/room.py CHANGED
@@ -5,10 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
5
  from sqlalchemy import select
6
  from sqlalchemy.orm import selectinload
7
  from src.db.postgres.connection import get_db
8
- from src.db.postgres.models import Room, ChatMessage
9
  from src.middlewares.logging import get_logger, log_execution
10
  from pydantic import BaseModel
11
- from typing import List
12
  from datetime import datetime
13
  import uuid
14
 
@@ -17,11 +17,18 @@ logger = get_logger("room_api")
17
  router = APIRouter(prefix="/api/v1", tags=["Rooms"])
18
 
19
 
 
 
 
 
 
 
20
  class ChatMessageResponse(BaseModel):
21
  id: str
22
  role: str
23
  content: str
24
  created_at: str
 
25
 
26
 
27
  class RoomResponse(BaseModel):
@@ -72,7 +79,7 @@ async def get_room(
72
  result = await db.execute(
73
  select(Room)
74
  .where(Room.id == room_id)
75
- .options(selectinload(Room.messages))
76
  )
77
  room = result.scalars().first()
78
 
@@ -94,7 +101,15 @@ async def get_room(
94
  id=msg.id,
95
  role=msg.role,
96
  content=msg.content,
97
- created_at=msg.created_at.isoformat()
 
 
 
 
 
 
 
 
98
  )
99
  for msg in messages
100
  ]
 
5
  from sqlalchemy import select
6
  from sqlalchemy.orm import selectinload
7
  from src.db.postgres.connection import get_db
8
+ from src.db.postgres.models import Room, ChatMessage, MessageSource
9
  from src.middlewares.logging import get_logger, log_execution
10
  from pydantic import BaseModel
11
+ from typing import List, Optional
12
  from datetime import datetime
13
  import uuid
14
 
 
17
  router = APIRouter(prefix="/api/v1", tags=["Rooms"])
18
 
19
 
20
+ class MessageSourceResponse(BaseModel):
21
+ document_id: Optional[str]
22
+ filename: Optional[str]
23
+ page_label: Optional[str]
24
+
25
+
26
  class ChatMessageResponse(BaseModel):
27
  id: str
28
  role: str
29
  content: str
30
  created_at: str
31
+ sources: List[MessageSourceResponse] = []
32
 
33
 
34
  class RoomResponse(BaseModel):
 
79
  result = await db.execute(
80
  select(Room)
81
  .where(Room.id == room_id)
82
+ .options(selectinload(Room.messages).selectinload(ChatMessage.sources))
83
  )
84
  room = result.scalars().first()
85
 
 
101
  id=msg.id,
102
  role=msg.role,
103
  content=msg.content,
104
+ created_at=msg.created_at.isoformat(),
105
+ sources=[
106
+ MessageSourceResponse(
107
+ document_id=src.document_id,
108
+ filename=src.filename,
109
+ page_label=src.page_label,
110
+ )
111
+ for src in msg.sources
112
+ ],
113
  )
114
  for msg in messages
115
  ]
src/config/agents/system_prompt.md CHANGED
@@ -15,4 +15,13 @@ When no document context is provided:
15
  - Provide general assistance
16
  - Let the user know if you need more context to help better
17
 
 
 
 
 
18
  Always be professional, helpful, and accurate.
 
 
 
 
 
 
15
  - Provide general assistance
16
  - Let the user know if you need more context to help better
17
 
18
+ When the answer needs markdown formatting:
19
+ - Use valid and tidy formatting
20
+ - Avoid over-formatting and emoji
21
+
22
  Always be professional, helpful, and accurate.
23
+
24
+ You have access to the conversation history provided in the messages above. Use it to:
25
+ - Maintain context across multiple turns (resolve references like "it", "that", "them" using earlier messages)
26
+ - Avoid repeating information already established in the conversation
27
+ - Answer follow-up questions coherently without asking the user to restate prior context
src/db/postgres/init_db.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  from sqlalchemy import text
4
  from src.db.postgres.connection import engine, Base
5
- from src.db.postgres.models import Document, Room, ChatMessage, User
6
 
7
 
8
  async def init_db():
 
2
 
3
  from sqlalchemy import text
4
  from src.db.postgres.connection import engine, Base
5
+ from src.db.postgres.models import Document, Room, ChatMessage, User, MessageSource
6
 
7
 
8
  async def init_db():
src/db/postgres/models.py CHANGED
@@ -66,3 +66,18 @@ class ChatMessage(Base):
66
  created_at = Column(DateTime(timezone=True), server_default=func.now())
67
 
68
  room = relationship("Room", back_populates="messages")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  created_at = Column(DateTime(timezone=True), server_default=func.now())
67
 
68
  room = relationship("Room", back_populates="messages")
69
+ sources = relationship("MessageSource", back_populates="message", cascade="all, delete-orphan")
70
+
71
+
72
+ class MessageSource(Base):
73
+ """Sources (RAG references) attached to an assistant message."""
74
+ __tablename__ = "message_sources"
75
+
76
+ id = Column(String, primary_key=True, default=lambda: str(uuid4()))
77
+ message_id = Column(String, ForeignKey("chat_messages.id", ondelete="CASCADE"), nullable=False, index=True)
78
+ document_id = Column(String)
79
+ filename = Column(Text)
80
+ page_label = Column(Text)
81
+ created_at = Column(DateTime(timezone=True), server_default=func.now())
82
+
83
+ message = relationship("ChatMessage", back_populates="sources")