conversation history structure
- app/main.py +82 -11
- app/models.py +5 -1
- app/nodes.py +21 -9
app/main.py
CHANGED

@@ -66,22 +66,40 @@ compiled_graph = workflow.compile()
 #----------------------------------------
 
 async def chatui_adapter(data):
-    """Text-only adapter for ChatUI"""
+    """Text-only adapter for ChatUI with structured message support"""
     try:
-
-
-
+        # Extract query - prefer structured messages over legacy text field
+        if hasattr(data, 'messages') and data.messages:
+            messages = data.messages
+            # Extract latest user query
+            user_messages = [msg for msg in messages if msg.role == 'user']
+            query = user_messages[-1].content if user_messages else ""
+
+            # Log conversation context
+            logger.info(f"Processing query: {query}")
+            logger.info(f"Total messages in conversation: {len(messages)}")
+            logger.info(f"User messages: {len(user_messages)}, Assistant messages: {len([m for m in messages if m.role == 'assistant'])}")
+
+            # Optional: Build conversation context for generation (last N turns)
+            conversation_context = build_conversation_context(messages, max_turns=3)
+            logger.info(f"Conversation context: {len(conversation_context)} characters")
+        else:
+            # Fallback to legacy text field
+            query = data.text if hasattr(data, 'text') else data.get('text', '')
+            conversation_context = None
+            logger.info(f"Processing query (legacy): {query}")
 
         full_response = ""
         sources_collected = None
 
         async for result in process_query_streaming(
-            query=
+            query=query,
             file_upload=None,
             reports_filter="",
             sources_filter="",
             subtype_filter="",
-            year_filter=""
+            year_filter="",
+            conversation_context=conversation_context  # Pass to processing function
         ):
             if isinstance(result, dict):
                 result_type = result.get("type", "data")

@@ -111,9 +129,23 @@ async def chatui_adapter(data):
 
 
 async def chatui_file_adapter(data):
-    """File upload adapter for ChatUI"""
+    """File upload adapter for ChatUI with structured message support"""
    try:
-
+        # Extract query - prefer structured messages
+        if hasattr(data, 'messages') and data.messages:
+            messages = data.messages
+            user_messages = [msg for msg in messages if msg.role == 'user']
+            query = user_messages[-1].content if user_messages else ""
+
+            logger.info(f"Processing query: {query}")
+            logger.info(f"Total messages: {len(messages)}")
+
+            conversation_context = build_conversation_context(messages, max_turns=3)
+        else:
+            query = data.text if hasattr(data, 'text') else data.get('text', '')
+            conversation_context = None
+            logger.info(f"Processing query (legacy): {query}")
+
         files = getattr(data, 'files', None) if hasattr(data, 'files') else data.get('files', None)
 
         file_content = None

@@ -135,14 +167,15 @@ async def chatui_file_adapter(data):
         sources_collected = None
 
         async for result in process_query_streaming(
-            query=
+            query=query,
             file_content=file_content,
             filename=filename,
             reports_filter="",
             sources_filter="",
             subtype_filter="",
             year_filter="",
-            output_format="structured"
+            output_format="structured",
+            conversation_context=conversation_context
         ):
             if isinstance(result, dict):
                 result_type = result.get("type", "data")

@@ -153,7 +186,6 @@ async def chatui_file_adapter(data):
             elif result_type == "sources":
                 sources_collected = content
             elif result_type == "end":
-                # Send sources at the end, like the text-only adapter
                 if sources_collected:
                     sources_text = "\n\n**Sources:**\n"
                     for i, source in enumerate(sources_collected, 1):

@@ -176,6 +208,45 @@ async def chatui_file_adapter(data):
         yield f"Error: {str(e)}"
 
 
+def build_conversation_context(messages: List, max_turns: int = 3, max_chars: int = 2000) -> str:
+    """
+    Build conversation context from structured messages.
+    Keeps the most recent turns within character budget.
+
+    Args:
+        messages: List of Message objects
+        max_turns: Maximum number of conversation turns (user+assistant pairs) to include
+        max_chars: Maximum total characters in context
+    """
+    context_parts = []
+    char_count = 0
+    turn_count = 0
+
+    # Process messages in reverse to keep most recent
+    for msg in reversed(messages):
+        role_label = msg.role.upper()
+        content = msg.content
+
+        # Estimate if adding this message would exceed limits
+        msg_text = f"{role_label}: {content}"
+        msg_chars = len(msg_text)
+
+        if char_count + msg_chars > max_chars:
+            break
+
+        if msg.role in ['user', 'assistant']:
+            if msg.role == 'user':
+                turn_count += 1
+                if turn_count > max_turns:
+                    break
+
+            context_parts.insert(0, msg_text)
+            char_count += msg_chars
+
+    context = "\n\n".join(context_parts)
+    logger.info(f"Built conversation context: {turn_count} turns, {char_count} chars")
+    return context
+
 #----------------------------------------
 # FASTAPI SETUP - for future use
 #----------------------------------------
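For reference, a quick sketch of how the new build_conversation_context helper behaves on a short history. The Message stand-in below is hypothetical (the real model lives in app/models.py and is assumed to expose role and content), and logger must be configured as in app/main.py.

from dataclasses import dataclass

@dataclass
class Message:  # hypothetical stand-in for the app/models.py Message
    role: str
    content: str

history = [
    Message("user", "What were the 2023 emissions figures?"),
    Message("assistant", "Emissions fell about 4% in 2023."),
    Message("user", "How does that compare to 2022?"),
]

# With the defaults (max_turns=3, max_chars=2000) all three messages fit.
# The helper walks the list in reverse, counts user messages as turns,
# and re-inserts at index 0, so the result reads oldest-first:
#
#   USER: What were the 2023 emissions figures?
#
#   ASSISTANT: Emissions fell about 4% in 2023.
#
#   USER: How does that compare to 2022?
context = build_conversation_context(history, max_turns=3)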
app/models.py
CHANGED

@@ -23,9 +23,13 @@ class GraphState(TypedDict):
 
 class ChatUIInput(BaseModel):
     """Input model for text-only ChatUI requests"""
-    text: str
+    text: str  # Legacy: full concatenated prompt (for backward compatibility)
+    messages: Optional[List[Message]] = None  # Structured conversation history
+    preprompt: Optional[str] = None
 
 class ChatUIFileInput(BaseModel):
     """Input model for ChatUI requests with file attachments"""
     text: str
     files: Optional[List[Dict[str, Any]]] = None
+    messages: Optional[List[Message]] = None  # Structured conversation history
+    preprompt: Optional[str] = None
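The new messages fields reference a Message model whose definition is not part of this diff. A minimal Pydantic sketch consistent with how the adapters use it (msg.role, msg.content) might look like this; treat anything beyond those two fields as an assumption.

from pydantic import BaseModel

class Message(BaseModel):
    """One turn of ChatUI conversation history (sketch; the actual
    definition is not shown in this diff)."""
    role: str     # e.g. 'user', 'assistant', 'system'
    content: str  # the message text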
app/nodes.py
CHANGED

@@ -403,11 +403,17 @@ async def process_query_streaming(
     reports_filter: str = "",
     sources_filter: str = "",
     subtype_filter: str = "",
     year_filter: str = "",
-    output_format: str = "structured"
+    output_format: str = "structured",
+    conversation_context: Optional[str] = None  # NEW: conversation context
 ):
     """
-    Unified streaming function
+    Unified streaming function with conversation context support.
+
+    Args:
+        query: Latest user query
+        conversation_context: Optional conversation history for generation context
+        ... (other args remain the same)
     """
     # Handle file_upload if provided
     if file_upload is not None:

@@ -427,10 +433,15 @@ async def process_query_streaming(
     start_time = datetime.now()
     session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}"
 
+    # Log retrieval strategy
+    logger.info(f"Retrieval query: {query[:100]}...")
+    if conversation_context:
+        logger.info(f"Generation will use conversation context ({len(conversation_context)} chars)")
+
     try:
         # Build initial state
         initial_state = {
-            "query": query,
+            "query": query,  # Use ONLY latest query for retrieval
             "context": "",
             "ingestor_context": "",
             "result": "",

@@ -443,23 +454,23 @@ async def process_query_streaming(
             "filename": filename,
             "file_type": "unknown",
             "workflow_type": "standard",
+            "conversation_context": conversation_context,  # Store for generation
             "metadata": {
                 "session_id": session_id,
                 "start_time": start_time.isoformat(),
-                "has_file_attachment": file_content is not None
+                "has_file_attachment": file_content is not None,
+                "has_conversation_context": conversation_context is not None
             }
         }
 
         # Execute workflow nodes
         if file_content and filename:
-            # File present: detect type and process
             state = merge_state(initial_state, detect_file_type_node(initial_state))
             state = merge_state(state, ingest_node(state))
 
             workflow_type = route_workflow(state)
 
             if workflow_type == "direct_output":
-                # NEW file with direct output enabled: show results and return
                 final_state = direct_output_node(state)
                 if output_format == "structured":
                     yield {"type": "data", "content": final_state["result"]}

@@ -468,13 +479,14 @@ async def process_query_streaming(
                 yield final_state["result"]
                 return
             else:
-                #
+                # Retrieve using ONLY the latest query
                 state = merge_state(state, retrieve_node(state))
         else:
-            # No file:
+            # No file: retrieve using latest query only
             state = merge_state(initial_state, retrieve_node(initial_state))
 
         # Generate response with streaming
+        # The generator can optionally use conversation_context for better responses
         sources_collected = None
         accumulated_response = "" if output_format == "gradio" else None
 
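The diff stores conversation_context in the graph state and leaves consumption to the generation node, which is not shown here. One plausible way such a node could fold the context into its prompt, keeping retrieval on the latest query alone as the comments above require (a sketch under those assumptions, not the repository's actual generate node):

def build_generation_prompt(state: dict) -> str:
    """Sketch of prompt assembly in a generation node. Assumes the
    state keys set by process_query_streaming; the real node may differ."""
    parts = []
    conversation_context = state.get("conversation_context")
    if conversation_context:
        # Prior turns shape the answer's wording only; retrieval has
        # already run on the latest query by itself.
        parts.append(f"Conversation so far:\n{conversation_context}")
    parts.append(f"Retrieved context:\n{state.get('context', '')}")
    parts.append(f"Question: {state.get('query', '')}")
    parts.append("Answer the question using the retrieved context.")
    return "\n\n".join(parts)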