mtyrrell committed
Commit f5bde8f · 1 Parent(s): 8a344c6
Files changed (4):
  1. app/main.py +140 -373
  2. app/models.py +4 -18
  3. app/nodes.py +132 -220
  4. app/utils.py +21 -24
app/main.py CHANGED
@@ -1,93 +1,74 @@
- #CHATFED_ORCHESTRATOR
  import gradio as gr
- from fastapi import FastAPI, UploadFile, File, Form, Request
- from fastapi.responses import StreamingResponse, JSONResponse
  from langserve import add_routes
  from langgraph.graph import StateGraph, START, END
- from typing import Optional
  import uvicorn
  import os
  from datetime import datetime
  import logging
- from contextlib import asynccontextmanager
  from langchain_core.runnables import RunnableLambda
  import asyncio
  import json
- from functools import partial
  import base64

  from utils import getconfig
- from nodes import detect_file_type_node, ingest_node, geojson_direct_result_node, retrieve_node, generate_node_streaming, route_workflow, process_query_streaming
  from models import GraphState, ChatUIInput, ChatUIFileInput

  config = getconfig("params.cfg")
- RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
- GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
- INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
- GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
- MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
-
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

-
  #----------------------------------------
- # CORE WORKFLOW GRAPH
  #----------------------------------------
- # graph setup
  workflow = StateGraph(GraphState)
  workflow.add_node("detect_file_type", detect_file_type_node)
  workflow.add_node("ingest", ingest_node)
  workflow.add_node("geojson_direct", geojson_direct_result_node)
  workflow.add_node("retrieve", retrieve_node)
- workflow.add_node("generate", generate_node_streaming)  # Changed to generate_node_streaming

- # Add edges
  workflow.add_edge(START, "detect_file_type")
  workflow.add_edge("detect_file_type", "ingest")

- # Conditional routing after ingestion
  workflow.add_conditional_edges(
      "ingest",
      route_workflow,
-     {
-         "geojson_direct": "geojson_direct",
-         "standard": "retrieve"
-     }
  )

- # Standard workflow
  workflow.add_edge("retrieve", "generate")
  workflow.add_edge("generate", END)
-
- # GeoJSON direct workflow
  workflow.add_edge("geojson_direct", END)

  compiled_graph = workflow.compile()


  async def chatui_adapter(data):
-     """Updated to return content without SSE formatting for LangServe"""
      try:
-         # Handle both dict and Pydantic model input
-         if hasattr(data, 'text'):
-             text = data.text
-         elif isinstance(data, dict) and 'text' in data:
-             text = data['text']
-         else:
-             logger.error(f"Unexpected input structure: {data}")
-             yield "Error: Invalid input format"
-             return

-         # Collect all content and sources
          full_response = ""
          sources_collected = None

-         # Use the streaming function and return content directly (no SSE formatting)
          async for result in process_query_streaming(
              query=text,
-             file_upload=None,  # No file upload for text-only ChatUI
              reports_filter="",
              sources_filter="",
              subtype_filter="",
@@ -98,26 +79,19 @@ async def chatui_adapter(data):
              content = result.get("content", "")

              if result_type == "data":
-                 # Accumulate the response text
                  full_response += content
                  yield content
-
              elif result_type == "sources":
-                 # Store sources for later
                  sources_collected = content
-
              elif result_type == "end":
-                 # Append sources to the final response if we have them
                  if sources_collected:
                      sources_text = "\n\n**Sources:**\n"
                      for i, source in enumerate(sources_collected, 1):
                          sources_text += f"{i}. [{source.get('title', 'Unknown')}]({source.get('link', '#')})\n"
                      yield sources_text
-
              elif result_type == "error":
                  yield f"Error: {content}"
              else:
-                 # Fallback for plain string results
                  yield str(result)

              await asyncio.sleep(0)
@@ -127,146 +101,26 @@ async def chatui_adapter(data):
          yield f"Error: {str(e)}"


- async def process_query_streaming_with_file_content(
-     query: str,
-     file_content: Optional[bytes] = None,
-     filename: Optional[str] = None,
-     reports_filter: str = "",
-     sources_filter: str = "",
-     subtype_filter: str = "",
-     year_filter: str = "",
-     output_format: str = "structured"
- ):
-     """
-     Modified streaming function that accepts file_content directly instead of file_upload
-     """
-     start_time = datetime.now()
-     session_id = f"chatui_{start_time.strftime('%Y%m%d_%H%M%S')}"
-
-     try:
-         # Process ingestion first (non-streaming)
-         initial_state = {
-             "query": query,
-             "context": "",
-             "ingestor_context": "",
-             "result": "",
-             "sources": [],
-             "reports_filter": reports_filter or "",
-             "sources_filter": sources_filter or "",
-             "subtype_filter": subtype_filter or "",
-             "year_filter": year_filter or "",
-             "file_content": file_content,
-             "filename": filename,
-             "file_type": "unknown",
-             "workflow_type": "standard",
-             "metadata": {
-                 "session_id": session_id,
-                 "start_time": start_time.isoformat(),
-                 "has_file_attachment": file_content is not None
-             }
-         }
-
-         # Detect file type - merge the returned state with initial state
-         state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
-
-         # Ingest if file provided - merge the returned state
-         state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}
-
-         # Route workflow
-         workflow_type = route_workflow(state_after_ingest)
-
-         if workflow_type == "geojson_direct":
-             # For GeoJSON, return direct result
-             final_state = geojson_direct_result_node(state_after_ingest)
-             if output_format == "structured":
-                 yield {"type": "data", "content": final_state["result"]}
-                 yield {"type": "end", "content": ""}
-             else:
-                 yield final_state["result"]
-         else:
-             # For standard workflow, retrieve first - merge the returned state
-             state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}
-
-             # Initialize variables for both output formats
-             sources_collected = None
-             accumulated_response = "" if output_format == "gradio" else None
-
-             # Then stream generation
-             async for partial_state in generate_node_streaming(state_after_retrieve):
-                 if "result" in partial_state:
-                     if output_format == "structured":
-                         yield {"type": "data", "content": partial_state["result"]}
-                     else:
-                         # Accumulate the content and yield the full accumulated response
-                         accumulated_response += partial_state["result"]
-                         yield accumulated_response
-
-                 # Collect sources for later
-                 if "sources" in partial_state:
-                     sources_collected = partial_state["sources"]
-
-             # Handle sources based on output format
-             if sources_collected:
-                 if output_format == "structured":
-                     yield {"type": "sources", "content": sources_collected}
-                 else:
-                     # Append sources to accumulated response
-                     sources_text = "\n\n**Sources:**\n"
-                     for i, source in enumerate(sources_collected, 1):
-                         if isinstance(source, dict):
-                             title = source.get('title', 'Unknown')
-                             link = source.get('link', '#')
-                             sources_text += f"{i}. [{title}]({link})\n"
-                         else:
-                             sources_text += f"{i}. {source}\n"
-
-                     accumulated_response += sources_text
-                     yield accumulated_response
-
-             if output_format == "structured":
-                 yield {"type": "end", "content": ""}
-
-     except Exception as e:
-         logger.error(f"Streaming pipeline failed: {str(e)}")
-         if output_format == "structured":
-             yield {"type": "error", "content": f"Error: {str(e)}"}
-         else:
-             yield f"Error: {str(e)}"
-
-
  async def chatui_file_adapter(data):
-     """New adapter for file uploads with streaming response"""
      try:
-         logger.info(f"=== CHATUI FILE ADAPTER CALLED ===")
-         logger.info(f"Input data type: {type(data)}")

-         # Handle both dict and Pydantic model input
-         if hasattr(data, 'text'):
-             text = data.text
-             files = getattr(data, 'files', None)
-         elif isinstance(data, dict):
-             text = data.get('text', '')
-             files = data.get('files', None)
-         else:
-             logger.error(f"Unexpected input structure: {data}")
-             yield "Error: Invalid input format"
-             return

          logger.info(f"Text: {text[:100]}...")
          logger.info(f"Files present: {files is not None and len(files) > 0 if files else False}")

-         # Process file if provided
          file_content = None
          filename = None

          if files and len(files) > 0:
-             # Get the first file
              file_info = files[0]
              logger.info(f"Processing file: {file_info.get('name', 'unknown')}")

              if file_info.get('type') == 'base64' and file_info.get('content'):
                  try:
-                     # Decode base64 content
                      file_content = base64.b64decode(file_info['content'])
                      filename = file_info.get('name', 'uploaded_file')
                      logger.info(f"Decoded file: {filename}, size: {len(file_content)} bytes")
@@ -275,8 +129,8 @@ async def chatui_file_adapter(data):
                      yield f"Error: Failed to decode uploaded file - {str(e)}"
                      return

-         # Use the modified streaming function that handles file content directly
-         async for result in process_query_streaming_with_file_content(
              query=text,
              file_content=file_content,
              filename=filename,
@@ -284,7 +138,7 @@ async def chatui_file_adapter(data):
              sources_filter="",
              subtype_filter="",
              year_filter="",
-             output_format="structured"  # Use structured format for better control
          ):
              if isinstance(result, dict):
                  result_type = result.get("type", "data")
@@ -292,9 +146,7 @@ async def chatui_file_adapter(data):

              if result_type == "data":
                  yield content
-
              elif result_type == "sources":
-                 # Format sources nicely
                  if content:
                      sources_text = "\n\n**Sources:**\n"
                      for i, source in enumerate(content, 1):
@@ -305,11 +157,9 @@ async def chatui_file_adapter(data):
                          else:
                              sources_text += f"{i}. {source}\n"
                      yield sources_text
-
              elif result_type == "error":
                  yield f"Error: {content}"
              else:
-                 # Fallback for plain string results
                  yield str(result)

              await asyncio.sleep(0)
@@ -319,52 +169,11 @@ async def chatui_file_adapter(data):
          yield f"Error: {str(e)}"


- # GRADIO TEST UI
- def create_gradio_interface():
-     with gr.Blocks(title="ChatFed Orchestrator") as demo:
-         gr.Markdown("# ChatFed Orchestrator")
-         gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
-
-         with gr.Row():
-             with gr.Column():
-                 query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
-                 file_input = gr.File(label="Upload Document (PDF/DOCX/GeoJSON)", file_types=[".pdf", ".docx", ".geojson", ".json"])
-
-                 with gr.Accordion("Filters (Optional)", open=False):
-                     reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
-                     sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
-                     subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
-                     year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
-
-                 submit_btn = gr.Button("Submit", variant="primary")
-
-             with gr.Column():
-                 output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
-
-         # Use streaming function
-         submit_btn.click(
-             fn=partial(process_query_streaming, output_format="gradio"),
-             inputs=[query_input, file_input, reports_filter_input, sources_filter_input,
-                     subtype_filter_input, year_filter_input],
-             outputs=output,
-             show_progress="minimal"
-         )
-
-     return demo

- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     logger.info("ChatFed Orchestrator starting up...")
-     yield
-     logger.info("Orchestrator shutting down...")
-
- app = FastAPI(
-     title="ChatFed Orchestrator",
-     version="1.0.0",
-     lifespan=lifespan,
-     docs_url=None,
-     redoc_url=None
- )

  @app.get("/health")
  async def health_check():
@@ -376,144 +185,33 @@ async def root():
          "message": "ChatFed Orchestrator API",
          "endpoints": {
              "health": "/health",
-             "chatfed-ui-stream": "/chatfed-ui-stream",
-             "chatfed-with-file": "/chatfed-with-file",
-             "chatfed-with-file-stream": "/chatfed-with-file-stream",  # New Langserve route
          }
      }


- # # FILE UPLOAD ADAPTER
- async def chatfed_with_file_adapter(
-     query: str,
-     file_content: Optional[bytes] = None,
-     filename: Optional[str] = None,
-     reports_filter: str = "",
-     sources_filter: str = "",
-     subtype_filter: str = "",
-     year_filter: str = "",
-     session_id: Optional[str] = None,
-     user_id: Optional[str] = None
- ):
-     """Async streaming adapter for file uploads"""
-     try:
-         # Use the same streaming logic as the working text endpoint
-         start_time = datetime.now()
-         if not session_id:
-             current_session_id = f"api_{start_time.strftime('%Y%m%d_%H%M%S')}"
-         else:
-             current_session_id = session_id
-
-         # Create initial state
-         initial_state = {
-             "query": query,
-             "context": "",
-             "ingestor_context": "",
-             "result": "",
-             "sources": [],
-             "reports_filter": reports_filter or "",
-             "sources_filter": sources_filter or "",
-             "subtype_filter": subtype_filter or "",
-             "year_filter": year_filter or "",
-             "file_content": file_content,
-             "filename": filename,
-             "file_type": "unknown",
-             "workflow_type": "standard",
-             "metadata": {
-                 "session_id": current_session_id,
-                 "user_id": user_id,
-                 "start_time": start_time.isoformat(),
-                 "has_file_attachment": file_content is not None
-             }
-         }
-
-         # Process non-streaming steps first
-         state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
-         state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}
-
-         # Route workflow
-         workflow_type = route_workflow(state_after_ingest)
-
-         if workflow_type == "geojson_direct":
-             # For GeoJSON, return direct result
-             final_state = geojson_direct_result_node(state_after_ingest)
-             yield final_state["result"]
-         else:
-             # For standard workflow, retrieve first
-             state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}
-
-             # Then stream generation
-             async for partial_state in generate_node_streaming(state_after_retrieve):
-                 if "result" in partial_state:
-                     yield partial_state["result"]
-                 await asyncio.sleep(0)  # Make it properly async
-
-     except Exception as e:
-         logger.error(f"File upload streaming failed: {str(e)}")
-         yield f"Error: {str(e)}"
-
- # NON-STREAMING FILE UPLOAD
- # @app.post("/chatfed-with-file")
- # async def chatfed_with_file(
- #     query: str = Form(...),
- #     file: Optional[UploadFile] = File(None),
- #     reports_filter: Optional[str] = Form(""),
- #     sources_filter: Optional[str] = Form(""),
- #     subtype_filter: Optional[str] = Form(""),
- #     year_filter: Optional[str] = Form(""),
- #     session_id: Optional[str] = Form(None),
- #     user_id: Optional[str] = Form(None)
- # ):
- #     """Endpoint for queries with optional file attachments + streaming"""
-
- #     # Read file content first
- #     file_content = None
- #     filename = None
-
- #     if file:
- #         file_content = await file.read()
- #         filename = file.filename
-
- #     # Stream the response instead of collecting chunks
- #     async def stream_generator():
- #         async for chunk in chatfed_with_file_adapter(
- #             query=query,
- #             file_content=file_content,
- #             filename=filename,
- #             reports_filter=reports_filter,
- #             sources_filter=sources_filter,
- #             subtype_filter=subtype_filter,
- #             year_filter=year_filter,
- #             session_id=session_id,
- #             user_id=user_id
- #         ):
- #             yield chunk
-
- #     return StreamingResponse(
- #         stream_generator(),
- #         media_type="text/plain"
- #     )
-
- # MAIN FILE UPLOAD STREAMING ENDPOINT
  @app.post("/chatfed-with-file")
- async def chatfed_with_file_stream(
      query: str = Form(...),
-     file: Optional[UploadFile] = File(None),
-     reports_filter: Optional[str] = Form(""),
-     sources_filter: Optional[str] = Form(""),
-     subtype_filter: Optional[str] = Form(""),
-     year_filter: Optional[str] = Form(""),
-     session_id: Optional[str] = Form(None),
-     user_id: Optional[str] = Form(None)
  ):
-     """File upload endpoint with proper SSE streaming for ChatUI"""
-
-     logger.info(f"=== FILE UPLOAD ENDPOINT CALLED ===")

      logger.info(f"Query: {query[:100]}...")
      logger.info(f"File: {file.filename if file else 'None'}")

-     # Read file content
      file_content = None
      filename = None

@@ -522,11 +220,11 @@ async def chatfed_with_file_stream(
          filename = file.filename

      async def sse_generator():
-         """Generate Server-Sent Events format for ChatUI"""
          try:
              token_id = 0

-             async for chunk in chatfed_with_file_adapter(
                  query=query,
                  file_content=file_content,
                  filename=filename,
@@ -534,24 +232,34 @@ async def chatfed_with_file_stream(
                  sources_filter=sources_filter,
                  subtype_filter=subtype_filter,
                  year_filter=year_filter,
-                 session_id=session_id,
-                 user_id=user_id
              ):
-                 if isinstance(chunk, str) and chunk.strip():
-                     # Format as SSE data that ChatUI expects
-                     token_data = {
-                         "token": chunk,
-                         "text": chunk,
-                         "content": chunk
-                     }
-                     yield f"data: {json.dumps(token_data)}\n\n"
-                     token_id += 1

                  await asyncio.sleep(0)

-             # Send end marker
              yield f"data: [DONE]\n\n"
-             logger.info("Generator stream ended")

          except Exception as e:
              logger.error(f"SSE generation error: {str(e)}")
@@ -565,11 +273,15 @@ async def chatfed_with_file_stream(
              "Cache-Control": "no-cache",
              "Connection": "keep-alive",
              "Access-Control-Allow-Origin": "*",
-             "Access-Control-Allow-Headers": "*",
          }
      )

- # Add the existing text-only Langserve route
  add_routes(
      app,
      RunnableLambda(chatui_adapter),
@@ -580,7 +292,7 @@ add_routes(
      enable_public_trace_link_endpoint=True,
  )

- # Add the new file upload Langserve route
  add_routes(
      app,
      RunnableLambda(chatui_file_adapter),
@@ -591,17 +303,72 @@ add_routes(
      enable_public_trace_link_endpoint=True,
  )

  if __name__ == "__main__":
-     # Create Gradio interface
      demo = create_gradio_interface()
-
-     # Mount Gradio app to FastAPI
      app = gr.mount_gradio_app(app, demo, path="/gradio")

      host = os.getenv("HOST", "0.0.0.0")
      port = int(os.getenv("PORT", "7860"))

-     logger.info(f"Starting FastAPI server on {host}:{port}")
-     # logger.info(f"Gradio UI available at: http://{host}:{port}/gradio")

      uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
app/main.py after this commit:

  import gradio as gr
+ from fastapi import FastAPI, UploadFile, File, Form
+ from fastapi.responses import StreamingResponse
  from langserve import add_routes
  from langgraph.graph import StateGraph, START, END
  import uvicorn
  import os
  from datetime import datetime
  import logging
  from langchain_core.runnables import RunnableLambda
  import asyncio
  import json
  import base64

  from utils import getconfig
+ from nodes import (
+     detect_file_type_node, ingest_node, geojson_direct_result_node,
+     retrieve_node, generate_node_streaming, route_workflow, process_query_streaming
+ )
  from models import GraphState, ChatUIInput, ChatUIFileInput

  config = getconfig("params.cfg")

  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

  #----------------------------------------
+ # LANGGRAPH WORKFLOW SETUP
  #----------------------------------------
+ # NOTE: Currently using manual node chaining for stability.
+ # This graph is prepared for future agentic workflow expansion.
+
  workflow = StateGraph(GraphState)
  workflow.add_node("detect_file_type", detect_file_type_node)
  workflow.add_node("ingest", ingest_node)
  workflow.add_node("geojson_direct", geojson_direct_result_node)
  workflow.add_node("retrieve", retrieve_node)
+ workflow.add_node("generate", generate_node_streaming)

  workflow.add_edge(START, "detect_file_type")
  workflow.add_edge("detect_file_type", "ingest")

  workflow.add_conditional_edges(
      "ingest",
      route_workflow,
+     {"geojson_direct": "geojson_direct", "standard": "retrieve"}
  )

  workflow.add_edge("retrieve", "generate")
  workflow.add_edge("generate", END)
  workflow.add_edge("geojson_direct", END)

  compiled_graph = workflow.compile()
+ # Future: Replace manual node chaining with: compiled_graph.astream(initial_state)
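For reference, the future path that comment points to would drive the compiled graph directly instead of calling the node functions one by one. A minimal sketch, not part of this commit: the initial_state keys mirror GraphState in app/models.py, and the exact astream semantics depend on the installed langgraph version.

    # Hypothetical usage of the compiled LangGraph; assumes compiled_graph
    # from app/main.py is in scope. Not committed code.
    import asyncio

    async def run_graph_once():
        initial_state = {
            "query": "What does the uploaded report say about emissions?",
            "context": "", "ingestor_context": "", "result": "", "sources": [],
            "reports_filter": "", "sources_filter": "", "subtype_filter": "",
            "year_filter": "", "file_content": None, "filename": None,
            "file_type": "unknown", "workflow_type": "standard", "metadata": {},
        }
        # astream yields a per-node state update as each node finishes
        async for update in compiled_graph.astream(initial_state):
            print(update)

    asyncio.run(run_graph_once())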

+ #----------------------------------------
+ # CHATUI ADAPTERS
+ #----------------------------------------
+
  async def chatui_adapter(data):
+     """Text-only adapter for ChatUI"""
      try:
+         text = data.text if hasattr(data, 'text') else data.get('text', '')

          full_response = ""
          sources_collected = None

          async for result in process_query_streaming(
              query=text,
+             file_upload=None,
              reports_filter="",
              sources_filter="",
              subtype_filter="",
  ...
              content = result.get("content", "")

              if result_type == "data":
                  full_response += content
                  yield content
              elif result_type == "sources":
                  sources_collected = content
              elif result_type == "end":
                  if sources_collected:
                      sources_text = "\n\n**Sources:**\n"
                      for i, source in enumerate(sources_collected, 1):
                          sources_text += f"{i}. [{source.get('title', 'Unknown')}]({source.get('link', '#')})\n"
                      yield sources_text
              elif result_type == "error":
                  yield f"Error: {content}"
              else:
                  yield str(result)

              await asyncio.sleep(0)
  ...
          yield f"Error: {str(e)}"


  async def chatui_file_adapter(data):
+     """File upload adapter for ChatUI"""
      try:
+         logger.info("=== CHATUI FILE ADAPTER CALLED ===")

+         text = data.text if hasattr(data, 'text') else data.get('text', '')
+         files = getattr(data, 'files', None) if hasattr(data, 'files') else data.get('files', None)

          logger.info(f"Text: {text[:100]}...")
          logger.info(f"Files present: {files is not None and len(files) > 0 if files else False}")

          file_content = None
          filename = None

          if files and len(files) > 0:
              file_info = files[0]
              logger.info(f"Processing file: {file_info.get('name', 'unknown')}")

              if file_info.get('type') == 'base64' and file_info.get('content'):
                  try:
                      file_content = base64.b64decode(file_info['content'])
                      filename = file_info.get('name', 'uploaded_file')
                      logger.info(f"Decoded file: {filename}, size: {len(file_content)} bytes")
  ...
                      yield f"Error: Failed to decode uploaded file - {str(e)}"
                      return

+         # Use the unified streaming function
+         async for result in process_query_streaming(
              query=text,
              file_content=file_content,
              filename=filename,
  ...
              sources_filter="",
              subtype_filter="",
              year_filter="",
+             output_format="structured"
          ):
              if isinstance(result, dict):
                  result_type = result.get("type", "data")
  ...
              if result_type == "data":
                  yield content
              elif result_type == "sources":
                  if content:
                      sources_text = "\n\n**Sources:**\n"
                      for i, source in enumerate(content, 1):
  ...
                          else:
                              sources_text += f"{i}. {source}\n"
                      yield sources_text
              elif result_type == "error":
                  yield f"Error: {content}"
              else:
                  yield str(result)

              await asyncio.sleep(0)
  ...
          yield f"Error: {str(e)}"


+ #----------------------------------------
+ # FASTAPI SETUP
+ #----------------------------------------

+ app = FastAPI(title="ChatFed Orchestrator", version="1.0.0")

  @app.get("/health")
  async def health_check():
  ...
          "message": "ChatFed Orchestrator API",
          "endpoints": {
              "health": "/health",
+             "chatfed-ui-stream": "/chatfed-ui-stream (LangServe)",
+             "chatfed-with-file": "/chatfed-with-file (FastAPI/SSE)",
+             "chatfed-with-file-stream": "/chatfed-with-file-stream (LangServe)",
+             "gradio": "/gradio"
          }
      }

  @app.post("/chatfed-with-file")
+ async def chatfed_with_file_endpoint(
      query: str = Form(...),
+     file: UploadFile = File(None),
+     reports_filter: str = Form(""),
+     sources_filter: str = Form(""),
+     subtype_filter: str = Form(""),
+     year_filter: str = Form(""),
+     session_id: str = Form(None),
+     user_id: str = Form(None)
  ):
+     """
+     File upload endpoint with SSE streaming for ChatUI.
+     This endpoint is currently used by ChatUI's fileUploadUrl config.
+     """
+     logger.info("=== FILE UPLOAD ENDPOINT CALLED ===")
      logger.info(f"Query: {query[:100]}...")
      logger.info(f"File: {file.filename if file else 'None'}")

      file_content = None
      filename = None
  ...
          filename = file.filename

      async def sse_generator():
+         """Generate Server-Sent Events for ChatUI"""
          try:
              token_id = 0

+             async for chunk in process_query_streaming(
                  query=query,
                  file_content=file_content,
                  filename=filename,
  ...
                  sources_filter=sources_filter,
                  subtype_filter=subtype_filter,
                  year_filter=year_filter,
+                 output_format="structured"
              ):
+                 if isinstance(chunk, dict):
+                     chunk_type = chunk.get("type", "data")
+                     content = chunk.get("content", "")

+                     if chunk_type == "data" and content:
+                         token_data = {"token": content, "text": content, "content": content}
+                         yield f"data: {json.dumps(token_data)}\n\n"
+                         token_id += 1
+                     elif chunk_type == "sources" and content:
+                         # Format sources for display
+                         sources_text = "\n\n**Sources:**\n"
+                         for i, source in enumerate(content, 1):
+                             if isinstance(source, dict):
+                                 title = source.get('title', 'Unknown')
+                                 link = source.get('link', '#')
+                                 sources_text += f"{i}. [{title}]({link})\n"
+                         token_data = {"token": sources_text, "text": sources_text, "content": sources_text}
+                         yield f"data: {json.dumps(token_data)}\n\n"
+                     elif chunk_type == "error":
+                         error_data = {"error": content}
+                         yield f"data: {json.dumps(error_data)}\n\n"
+
                  await asyncio.sleep(0)

              yield f"data: [DONE]\n\n"
+             logger.info("SSE stream completed")

          except Exception as e:
              logger.error(f"SSE generation error: {str(e)}")
  ...
              "Cache-Control": "no-cache",
              "Connection": "keep-alive",
              "Access-Control-Allow-Origin": "*",
          }
      )


+ #----------------------------------------
+ # LANGSERVE ROUTES
+ #----------------------------------------
+
+ # Text-only endpoint
  add_routes(
      app,
      RunnableLambda(chatui_adapter),
  ...
      enable_public_trace_link_endpoint=True,
  )

+ # File upload endpoint (LangServe version for future migration)
  add_routes(
      app,
      RunnableLambda(chatui_file_adapter),
  ...
      enable_public_trace_link_endpoint=True,
  )


+ #----------------------------------------
+ # GRADIO INTERFACE
+ #----------------------------------------
+
+ def create_gradio_interface():
+     with gr.Blocks(title="ChatFed Orchestrator") as demo:
+         gr.Markdown("# ChatFed Orchestrator")
+         gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context.")
+
+         with gr.Row():
+             with gr.Column():
+                 query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
+                 file_input = gr.File(
+                     label="Upload Document (PDF/DOCX/GeoJSON)",
+                     file_types=[".pdf", ".docx", ".geojson", ".json"]
+                 )
+
+                 with gr.Accordion("Filters (Optional)", open=False):
+                     reports_filter = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
+                     sources_filter = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
+                     subtype_filter = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
+                     year_filter = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
+
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+             with gr.Column():
+                 output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
+
+         async def gradio_handler(query, file, reports, sources, subtype, year):
+             """Handler for Gradio interface"""
+             result = ""
+             async for chunk in process_query_streaming(
+                 query=query,
+                 file_upload=file,
+                 reports_filter=reports,
+                 sources_filter=sources,
+                 subtype_filter=subtype,
+                 year_filter=year,
+                 output_format="gradio"
+             ):
+                 result = chunk  # Each chunk is the full accumulated text
+                 yield result
+
+         submit_btn.click(
+             fn=gradio_handler,
+             inputs=[query_input, file_input, reports_filter, sources_filter, subtype_filter, year_filter],
+             outputs=output,
+         )
+
+     return demo


+ #----------------------------------------
+ # MAIN
+ #----------------------------------------
+
  if __name__ == "__main__":
      demo = create_gradio_interface()
      app = gr.mount_gradio_app(app, demo, path="/gradio")

      host = os.getenv("HOST", "0.0.0.0")
      port = int(os.getenv("PORT", "7860"))

+     logger.info(f"Starting ChatFed Orchestrator on {host}:{port}")
+     logger.info(f"Gradio UI: http://{host}:{port}/gradio")
+     logger.info(f"API Docs: http://{host}:{port}/docs")

      uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
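For quick testing of the `/chatfed-with-file` SSE endpoint added above, a client can POST multipart form data and read the `data:` lines until the `[DONE]` marker. A minimal sketch using httpx; the host, port, and file path are illustrative, while the endpoint path, form fields, and SSE framing come from the diff:

    # Hypothetical client for the /chatfed-with-file endpoint above.
    import httpx

    def stream_chatfed(query: str, path: str = "report.pdf"):
        with open(path, "rb") as f:
            with httpx.stream(
                "POST",
                "http://localhost:7860/chatfed-with-file",
                data={"query": query},
                files={"file": (path, f, "application/pdf")},
                timeout=300.0,
            ) as response:
                for line in response.iter_lines():
                    if not line.startswith("data: "):
                        continue
                    body = line[len("data: "):]
                    if body == "[DONE]":
                        break
                    print(body)  # JSON with "token"/"text"/"content" keys

    stream_chatfed("Summarize the attached report.")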
app/models.py CHANGED
@@ -1,14 +1,14 @@
- # Models
  from typing import Optional, Dict, Any, List
  from typing_extensions import TypedDict
  from pydantic import BaseModel

  class GraphState(TypedDict):
      query: str
      context: str
      ingestor_context: str
      result: str
-     sources: Optional[List[Dict[str, str]]]  # Added for ChatUI sources
      reports_filter: str
      sources_filter: str
      subtype_filter: str
@@ -19,26 +19,12 @@ class GraphState(TypedDict):
      file_type: Optional[str]
      workflow_type: Optional[str]  # 'standard' or 'geojson_direct'

- class ChatFedInput(TypedDict):
-     query: str
-     reports_filter: Optional[str]
-     sources_filter: Optional[str]
-     subtype_filter: Optional[str]
-     year_filter: Optional[str]
-     session_id: Optional[str]
-     user_id: Optional[str]
-     file_content: Optional[bytes]
-     filename: Optional[str]
-
- class ChatFedOutput(TypedDict):
-     result: str
-     metadata: Dict[str, Any]
-
  class ChatUIInput(BaseModel):
      text: str

- # New model for file upload support
  class ChatUIFileInput(BaseModel):
      text: str
      files: Optional[List[Dict[str, Any]]] = None
app/models.py after this commit:

  from typing import Optional, Dict, Any, List
  from typing_extensions import TypedDict
  from pydantic import BaseModel

  class GraphState(TypedDict):
+     """State object passed through LangGraph workflow"""
      query: str
      context: str
      ingestor_context: str
      result: str
+     sources: List[Dict[str, str]]  # Always present, no Optional needed
      reports_filter: str
      sources_filter: str
      subtype_filter: str
  ...
      file_type: Optional[str]
      workflow_type: Optional[str]  # 'standard' or 'geojson_direct'

  class ChatUIInput(BaseModel):
+     """Input model for text-only ChatUI requests"""
      text: str

  class ChatUIFileInput(BaseModel):
+     """Input model for ChatUI requests with file attachments"""
      text: str
      files: Optional[List[Dict[str, Any]]] = None
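ChatUIFileInput keeps each entry in `files` as a plain dict. From the keys chatui_file_adapter reads (`type`, `name`, `content`), a request body for the LangServe file route would look roughly like this sketch; the `input` wrapper is standard LangServe invoke framing, and the route path is set in the add_routes arguments elided above:

    # Hypothetical request payload for the chatui_file_adapter route.
    import base64

    with open("report.pdf", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "input": {
            "text": "Summarize the attached report.",
            "files": [
                {"type": "base64", "name": "report.pdf", "content": encoded}
            ],
        }
    }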
 
app/nodes.py CHANGED
@@ -1,36 +1,32 @@
- from utils import detect_file_type, convert_context_to_list
  from models import GraphState
  from datetime import datetime
  import tempfile
  import os
  from gradio_client import Client, file
  import logging
- from utils import getconfig
  import dotenv
- from typing_extensions import TypedDict
  import httpx
  import json
- from typing import Generator

  dotenv.load_dotenv()

  logger = logging.getLogger(__name__)

  config = getconfig("params.cfg")
  RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
  GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
  INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
  GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
- MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")

- ingestor_url = INGESTOR
- retriever_url = RETRIEVER
- generator_url = GENERATOR
- geojson_ingestor_url = GEOJSON_INGESTOR

- # CORE PROCESSING NODES
  #----------------------------------------
- # File type detection node
  def detect_file_type_node(state: GraphState) -> GraphState:
      """Detect file type and determine workflow"""
      file_type = "unknown"
@@ -38,12 +34,7 @@ def detect_file_type_node(state: GraphState) -> GraphState:

      if state.get("file_content") and state.get("filename"):
          file_type = detect_file_type(state["filename"], state["file_content"])
-
-     # Determine workflow based on file type
-     if file_type == "geojson":
-         workflow_type = "geojson_direct"
-     else:
-         workflow_type = "standard"

      metadata = state.get("metadata", {})
      metadata.update({
@@ -57,12 +48,11 @@ def detect_file_type_node(state: GraphState) -> GraphState:
          "metadata": metadata
      }

- # Module functions
  def ingest_node(state: GraphState) -> GraphState:
      """Process file through appropriate ingestor based on file type"""
      start_time = datetime.now()

-     # If no file provided, skip this step
      if not state.get("file_content") or not state.get("filename"):
          logger.info("No file provided, skipping ingestion")
          return {"ingestor_context": "", "metadata": state.get("metadata", {})}
@@ -72,35 +62,23 @@ def ingest_node(state: GraphState) -> GraphState:

      try:
          # Choose ingestor based on file type
-         if file_type == "geojson":
-             ingestor_url = GEOJSON_INGESTOR
-             logger.info(f"Using GeoJSON ingestor: {ingestor_url}")
-         else:
-             ingestor_url = INGESTOR
-             logger.info(f"Using standard ingestor: {ingestor_url}")

          client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN"))

-         # Create a temporary file to upload
          with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
              tmp_file.write(state["file_content"])
              tmp_file_path = tmp_file.name

          try:
-             # Call the ingestor's ingest endpoint
-             ingestor_context = client.predict(
-                 file(tmp_file_path),
-                 api_name="/ingest"
-             )
-
              logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")

-             # Handle error cases
              if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                  raise Exception(ingestor_context)
-
          finally:
-             # Clean up temporary file
              os.unlink(tmp_file_path)

          duration = (datetime.now() - start_time).total_seconds()
@@ -112,10 +90,7 @@ def ingest_node(state: GraphState) -> GraphState:
              "ingestor_used": ingestor_url
          })

-         return {
-             "ingestor_context": ingestor_context,
-             "metadata": metadata
-         }

      except Exception as e:
          duration = (datetime.now() - start_time).total_seconds()
@@ -129,13 +104,12 @@ def ingest_node(state: GraphState) -> GraphState:
          })
          return {"ingestor_context": "", "metadata": metadata}

  def geojson_direct_result_node(state: GraphState) -> GraphState:
-     """For GeoJSON files, return ingestor results directly without retrieval/generation"""
      logger.info("Processing GeoJSON file - returning direct results")

      ingestor_context = state.get("ingestor_context", "")
-
-     # For GeoJSON files, the ingestor result is the final result
      result = ingestor_context if ingestor_context else "No results from GeoJSON processing."

      metadata = state.get("metadata", {})
@@ -144,12 +118,11 @@ def geojson_direct_result_node(state: GraphState) -> GraphState:
          "result_length": len(result)
      })

-     return {
-         "result": result,
-         "metadata": metadata
-     }

  def retrieve_node(state: GraphState) -> GraphState:
      start_time = datetime.now()
      logger.info(f"Retrieval: {state['query'][:50]}...")

@@ -187,38 +160,29 @@ def retrieve_node(state: GraphState) -> GraphState:
      return {"context": "", "metadata": metadata}


- # MAIN STREAMING GENERATOR
  async def generate_node_streaming(state: GraphState) -> Generator[GraphState, None, None]:
-     """Streaming version that calls generator's FastAPI endpoint"""
      start_time = datetime.now()
      logger.info(f"Generation (streaming): {state['query'][:50]}...")

      try:
-         # Get MAX_CONTEXT_CHARS at the beginning so it's available throughout the function
-         MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
-         logger.info(f"Using MAX_CONTEXT_CHARS: {MAX_CONTEXT_CHARS}")
-
-         # Combine retriever context with ingestor context
          retrieved_context = state.get("context", "")
          ingestor_context = state.get("ingestor_context", "")

-         logger.info(f"Original context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}")

-         # Convert contexts to list format expected by generator
          context_list = []
          total_context_chars = 0

          if ingestor_context:
-             # Truncate ingestor context if it's too long
-             if len(ingestor_context) > MAX_CONTEXT_CHARS:
-                 logger.warning(f"Truncating ingestor context from {len(ingestor_context)} to {MAX_CONTEXT_CHARS} characters")
-                 truncated_ingestor = ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]"
-             else:
-                 truncated_ingestor = ingestor_context

-             # Add ingestor context
              context_list.append({
                  "answer": truncated_ingestor,
                  "answer_metadata": {
@@ -231,154 +195,101 @@ async def generate_node_streaming(state: GraphState) -> Generator[GraphState, No
              total_context_chars += len(truncated_ingestor)

          if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS:
-             # Convert retrieved context to list and add
              retrieved_list = convert_context_to_list(retrieved_context)
-
-             # Add retrieved context items until we hit the limit
              remaining_chars = MAX_CONTEXT_CHARS - total_context_chars
              for item in retrieved_list:
                  item_text = item.get("answer", "")
                  if len(item_text) <= remaining_chars:
                      context_list.append(item)
                      remaining_chars -= len(item_text)
                  else:
-                     # Truncate this item and stop
-                     if remaining_chars > 100:  # Only add if we have meaningful space left
                          item["answer"] = item_text[:remaining_chars-50] + "...\n[Content truncated]"
                          context_list.append(item)
                      break

-         # Calculate final context size
          final_context_size = sum(len(item.get("answer", "")) for item in context_list)
          logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})")

-         # Prepare the request payload
-         payload = {
-             "query": state["query"],
-             "context": context_list
-         }

-         # Determine generator URL - handle both Hugging Face and direct URLs
          generator_url = GENERATOR
-
          if not generator_url.startswith('http'):
-             # Allows for easy specification of space in config (converts to URL)
-             # Replace '/' with '-' for Hugging Face space URLs
-             # Force the replacement to ensure it works
-             space_name = generator_url.replace('/', '-').replace('_', '-')
              generator_url = f"https://{space_name}.hf.space"
-
-         # Try FastAPI endpoint first, fallback to Gradio if needed
-         fastapi_success = False

-         try:
-             # Make streaming request to generator's FastAPI endpoint
-             async with httpx.AsyncClient(timeout=300.0, verify=False) as client:

-                 async with client.stream(
-                     "POST",
-                     f"{generator_url}/generate/stream",
-                     json=payload,
-                     headers={"Content-Type": "application/json"}
-                 ) as response:
-                     if response.status_code != 200:
-                         error_text = await response.aread()
-                         raise Exception(f"FastAPI endpoint returned status {response.status_code}")
-
-                     current_text = ""
-                     sources = None
-                     event_type = None

-                     async for line in response.aiter_lines():
-                         if not line.strip():
-                             continue

-                         # Parse SSE format
-                         if line.startswith("event: "):
-                             event_type = line[7:].strip()
-                             continue
-                         elif line.startswith("data: "):
-                             data_content = line[6:].strip()
-
-                             if event_type == "data":
-                                 # Text chunk
-                                 try:
-                                     chunk = json.loads(data_content)
-                                     if isinstance(chunk, str):
-                                         current_text += chunk
-
-                                         metadata = state.get("metadata", {})
-                                         metadata.update({
-                                             "generation_duration": (datetime.now() - start_time).total_seconds(),
-                                             "result_length": len(current_text),
-                                             "generation_success": True,
-                                             "streaming": True,
-                                             "generator_type": "fastapi",
-                                             "context_chars_used": final_context_size
-                                         })
-
-                                         yield {
-                                             "result": chunk,  # Send only the new chunk
-                                             "metadata": metadata
-                                         }
-                                 except json.JSONDecodeError:
-                                     # Handle plain text chunks
-                                     current_text += data_content
-
-                                     metadata = state.get("metadata", {})
-                                     metadata.update({
-                                         "generation_duration": (datetime.now() - start_time).total_seconds(),
-                                         "result_length": len(current_text),
-                                         "generation_success": True,
-                                         "streaming": True,
-                                         "generator_type": "fastapi",
-                                         "context_chars_used": final_context_size
-                                     })
-
-                                     yield {
-                                         "result": data_content,
-                                         "metadata": metadata
-                                     }
-
-                             elif event_type == "sources":
-                                 # Sources data
-                                 try:
-                                     sources_data = json.loads(data_content)
-                                     sources = sources_data.get("sources", [])
-
-                                     # Update state with sources
-                                     metadata = state.get("metadata", {})
-                                     metadata.update({
-                                         "sources_received": True,
-                                         "sources_count": len(sources)
-                                     })
-
-                                     yield {
-                                         "sources": sources,
-                                         "metadata": metadata
-                                     }
-                                 except json.JSONDecodeError:
-                                     logger.warning(f"Failed to parse sources data: {data_content}")

-                             elif event_type == "end":
-                                 # Stream ended
-                                 logger.info("Generator stream ended")
-                                 fastapi_success = True
-                                 break

-                             elif event_type == "error":
-                                 # Error occurred
-                                 try:
-                                     error_data = json.loads(data_content)
-                                     raise Exception(error_data.get("error", "Unknown error"))
-                                 except json.JSONDecodeError:
-                                     raise Exception(data_content)
-
-         except Exception as fastapi_error:
-             # Handle FastAPI-specific errors
-             logger.warning(f"FastAPI endpoint failed: {str(fastapi_error)}")
-             raise fastapi_error

      except Exception as e:
          duration = (datetime.now() - start_time).total_seconds()
@@ -393,27 +304,38 @@ async def generate_node_streaming(state: GraphState) -> Generator[GraphState, No
      })
      yield {"result": f"Error: {str(e)}", "metadata": metadata}

- # Conditional routing function
- def route_workflow(state: GraphState) -> str:
-     """Route to appropriate workflow based on file type"""
-     workflow_type = state.get("workflow_type", "standard")
-     return workflow_type


- async def process_query_streaming(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
-                                   subtype_filter: str = "", year_filter: str = "",
-                                   output_format: str = "structured"):
      """
-     Unified streaming function that yields partial results

      Args:
-         output_format: "structured" for dict format, "gradio" for plain text format
      """
-     file_content = None
-     filename = None
-
      if file_upload is not None:
          try:
              with open(file_upload.name, 'rb') as f:
@@ -429,10 +351,10 @@ async def process_query_streaming(query: str, file_upload, reports_filter: str =
              return

      start_time = datetime.now()
-     session_id = f"gradio_{start_time.strftime('%Y%m%d_%H%M%S')}"

      try:
-         # Process ingestion first (non-streaming)
          initial_state = {
              "query": query,
              "context": "",
@@ -454,51 +376,41 @@ async def process_query_streaming(query: str, file_upload, reports_filter: str =
              }
          }

-         # Detect file type - merge the returned state with initial state
-         state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
-
-         # Ingest if file provided - merge the returned state
-         state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}

-         # Route workflow
-         workflow_type = route_workflow(state_after_ingest)

          if workflow_type == "geojson_direct":
-             # For GeoJSON, return direct result
-             final_state = geojson_direct_result_node(state_after_ingest)
              if output_format == "structured":
                  yield {"type": "data", "content": final_state["result"]}
                  yield {"type": "end", "content": ""}
              else:
                  yield final_state["result"]
          else:
-             # For standard workflow, retrieve first - merge the returned state
-             state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}

-             # Initialize variables for both output formats
              sources_collected = None
              accumulated_response = "" if output_format == "gradio" else None

-             # Then stream generation
-             async for partial_state in generate_node_streaming(state_after_retrieve):
                  if "result" in partial_state:
                      if output_format == "structured":
                          yield {"type": "data", "content": partial_state["result"]}
                      else:
-                         # Accumulate the content and yield the full accumulated response
                          accumulated_response += partial_state["result"]
                          yield accumulated_response

-                 # Collect sources for later
                  if "sources" in partial_state:
                      sources_collected = partial_state["sources"]

-             # Handle sources based on output format
              if sources_collected:
                  if output_format == "structured":
                      yield {"type": "sources", "content": sources_collected}
                  else:
-                     # Append sources to accumulated response
                      sources_text = "\n\n**Sources:**\n"
                      for i, source in enumerate(sources_collected, 1):
                          if isinstance(source, dict):
app/nodes.py after this commit:

+ from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
  from models import GraphState
  from datetime import datetime
  import tempfile
  import os
  from gradio_client import Client, file
  import logging
  import dotenv
  import httpx
  import json
+ from typing import Generator, Optional

  dotenv.load_dotenv()

  logger = logging.getLogger(__name__)

+ # Load config once at module level
  config = getconfig("params.cfg")
  RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
  GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
  INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
  GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
+ MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))
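These config.get calls read an INI-style params.cfg with configparser sections. A minimal file satisfying them might look like the following sketch; the URLs are the fallbacks from the code above, and the MAX_CONTEXT_CHARS value is purely illustrative:

    ; Hypothetical params.cfg matching the keys read above
    [retriever]
    RETRIEVER = https://giz-chatfed-retriever.hf.space

    [generator]
    GENERATOR = https://giz-chatfed-generator.hf.space

    [ingestor]
    INGESTOR = https://mtyrrell-chatfed-ingestor.hf.space
    GEOJSON_INGESTOR = https://giz-eudr-chatfed-ingestor.hf.space

    [general]
    MAX_CONTEXT_CHARS = 20000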
24
 
 
 
 
 
25
 
 
26
  #----------------------------------------
27
+ # LANGGRAPH NODE FUNCTIONS
28
+ #----------------------------------------
29
+
30
  def detect_file_type_node(state: GraphState) -> GraphState:
31
  """Detect file type and determine workflow"""
32
  file_type = "unknown"
 
34
 
35
  if state.get("file_content") and state.get("filename"):
36
  file_type = detect_file_type(state["filename"], state["file_content"])
37
+ workflow_type = "geojson_direct" if file_type == "geojson" else "standard"
 
 
 
 
 
38
 
39
  metadata = state.get("metadata", {})
40
  metadata.update({
 
48
  "metadata": metadata
49
  }
50
 
51
+
52
  def ingest_node(state: GraphState) -> GraphState:
53
  """Process file through appropriate ingestor based on file type"""
54
  start_time = datetime.now()
55
 
 
56
  if not state.get("file_content") or not state.get("filename"):
57
  logger.info("No file provided, skipping ingestion")
58
  return {"ingestor_context": "", "metadata": state.get("metadata", {})}
 
62
 
63
  try:
64
  # Choose ingestor based on file type
65
+ ingestor_url = GEOJSON_INGESTOR if file_type == "geojson" else INGESTOR
66
+ logger.info(f"Using ingestor: {ingestor_url}")
 
 
 
 
67
 
68
  client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN"))
69
 
70
+ # Create temporary file for upload
71
  with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
72
  tmp_file.write(state["file_content"])
73
  tmp_file_path = tmp_file.name
74
 
75
  try:
76
+ ingestor_context = client.predict(file(tmp_file_path), api_name="/ingest")
 
 
 
 
 
77
  logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
78
 
 
79
  if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
80
  raise Exception(ingestor_context)
 
81
  finally:
 
82
  os.unlink(tmp_file_path)
83
 
84
  duration = (datetime.now() - start_time).total_seconds()
 
90
  "ingestor_used": ingestor_url
91
  })
92
 
93
+ return {"ingestor_context": ingestor_context, "metadata": metadata}
 
 
 
94
 
95
  except Exception as e:
96
  duration = (datetime.now() - start_time).total_seconds()
 
104
  })
105
  return {"ingestor_context": "", "metadata": metadata}
106
 
107
+
108
  def geojson_direct_result_node(state: GraphState) -> GraphState:
109
+ """For GeoJSON files, return ingestor results directly"""
110
  logger.info("Processing GeoJSON file - returning direct results")
111
 
112
  ingestor_context = state.get("ingestor_context", "")
 
 
113
  result = ingestor_context if ingestor_context else "No results from GeoJSON processing."
114
 
115
  metadata = state.get("metadata", {})
 
118
  "result_length": len(result)
119
  })
120
 
121
+ return {"result": result, "metadata": metadata}
122
+
 
 
123
 
124
  def retrieve_node(state: GraphState) -> GraphState:
125
+ """Retrieve relevant context from vector store"""
126
  start_time = datetime.now()
127
  logger.info(f"Retrieval: {state['query'][:50]}...")
128
 
 
160
  return {"context": "", "metadata": metadata}
161
 
162
 
 
 
 
163
async def generate_node_streaming(state: GraphState) -> AsyncGenerator[GraphState, None]:
+    """Streaming generation using generator's FastAPI endpoint"""
    start_time = datetime.now()
    logger.info(f"Generation (streaming): {state['query'][:50]}...")

    try:
+        # Combine contexts
        retrieved_context = state.get("context", "")
        ingestor_context = state.get("ingestor_context", "")

+        logger.info(f"Context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}")

+        # Build context list with truncation
        context_list = []
        total_context_chars = 0

        if ingestor_context:
+            truncated_ingestor = (
+                ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]"
+                if len(ingestor_context) > MAX_CONTEXT_CHARS
+                else ingestor_context
+            )

            context_list.append({
                "answer": truncated_ingestor,
                "answer_metadata": {

            total_context_chars += len(truncated_ingestor)

        if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS:
            retrieved_list = convert_context_to_list(retrieved_context)
            remaining_chars = MAX_CONTEXT_CHARS - total_context_chars
+
            for item in retrieved_list:
                item_text = item.get("answer", "")
                if len(item_text) <= remaining_chars:
                    context_list.append(item)
                    remaining_chars -= len(item_text)
                else:
+                    if remaining_chars > 100:
                        item["answer"] = item_text[:remaining_chars-50] + "...\n[Content truncated]"
                        context_list.append(item)
                    break

        final_context_size = sum(len(item.get("answer", "")) for item in context_list)
        logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})")

+        payload = {"query": state["query"], "context": context_list}

+        # Normalize generator URL
        generator_url = GENERATOR
        if not generator_url.startswith('http'):
+            space_name = generator_url.replace('/', '-')
            generator_url = f"https://{space_name}.hf.space"

+        # Stream from generator
+        async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
+            async with client.stream(
+                "POST",
+                f"{generator_url}/generate/stream",
+                json=payload,
+                headers={"Content-Type": "application/json"}
+            ) as response:
+                if response.status_code != 200:
+                    raise Exception(f"Generator returned status {response.status_code}")

+                current_text = ""
+                sources = None
+                event_type = None
+
+                async for line in response.aiter_lines():
+                    if not line.strip():
+                        continue

+                    if line.startswith("event: "):
+                        event_type = line[7:].strip()
+                        continue
+                    elif line.startswith("data: "):
+                        data_content = line[6:].strip()

+                        if event_type == "data":
+                            try:
+                                chunk = json.loads(data_content)
+                                if isinstance(chunk, str):
+                                    current_text += chunk
+                            except json.JSONDecodeError:
+                                current_text += data_content
+                                chunk = data_content

+                            metadata = state.get("metadata", {})
+                            metadata.update({
+                                "generation_duration": (datetime.now() - start_time).total_seconds(),
+                                "result_length": len(current_text),
+                                "generation_success": True,
+                                "streaming": True,
+                                "context_chars_used": final_context_size
+                            })

+                            yield {"result": chunk, "metadata": metadata}
+
+                        elif event_type == "sources":
+                            try:
+                                sources_data = json.loads(data_content)
+                                sources = sources_data.get("sources", [])
+
+                                metadata = state.get("metadata", {})
+                                metadata.update({
+                                    "sources_received": True,
+                                    "sources_count": len(sources)
+                                })
+
+                                yield {"sources": sources, "metadata": metadata}
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse sources: {data_content}")
+
+                        elif event_type == "end":
+                            logger.info("Generator stream ended")
+                            break
+
+                        elif event_type == "error":
+                            try:
+                                error_data = json.loads(data_content)
+                                raise Exception(error_data.get("error", "Unknown error"))
+                            except json.JSONDecodeError:
+                                raise Exception(data_content)

    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()

        })
        yield {"result": f"Error: {str(e)}", "metadata": metadata}

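For reference, the loop above assumes the generator's /generate/stream endpoint speaks standard server-sent events. A minimal sketch of the framing it expects; the event names come from the handlers above, but the payloads are invented for illustration:

# Illustrative SSE frames only; payload contents are hypothetical.
sample_frames = [
    'event: data',
    'data: "Partial answer text..."',
    'event: sources',
    'data: {"sources": [{"filename": "report.pdf"}]}',
    'event: end',
    'data: ""',
]

event_type = None
for line in sample_frames:
    if line.startswith("event: "):
        event_type = line[7:].strip()
    elif line.startswith("data: "):
        print(f"{event_type}: {line[6:].strip()}")
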
+def route_workflow(state: GraphState) -> str:
+    """Conditional routing based on workflow type"""
+    return state.get("workflow_type", "standard")

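The returned string keys into the graph's conditional-edge mapping. A quick sanity check, with plain dicts standing in for GraphState:

print(route_workflow({"workflow_type": "geojson_direct"}))  # geojson_direct
print(route_workflow({}))                                   # standard (the default)
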
+#----------------------------------------
+# UNIFIED STREAMING PROCESSOR
+#----------------------------------------

+async def process_query_streaming(
+    query: str,
+    file_upload=None,
+    file_content: Optional[bytes] = None,
+    filename: Optional[str] = None,
+    reports_filter: str = "",
+    sources_filter: str = "",
+    subtype_filter: str = "",
+    year_filter: str = "",
+    output_format: str = "structured"
+):
    """
+    Unified streaming function supporting both file objects and raw content.

    Args:
+        query: User query string
+        file_upload: File object from Gradio (optional)
+        file_content: Raw file bytes (optional, alternative to file_upload)
+        filename: Filename for raw content (required if file_content provided)
+        output_format: "structured" yields dicts, "gradio" yields accumulated text
    """
+    # Handle file_upload if provided (Gradio use case)
    if file_upload is not None:
        try:
            with open(file_upload.name, 'rb') as f:

            return

    start_time = datetime.now()
+    session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}"

    try:
+        # Build initial state
        initial_state = {
            "query": query,
            "context": "",

            }
        }

+        # Execute workflow nodes
+        state = merge_state(initial_state, detect_file_type_node(initial_state))
+        state = merge_state(state, ingest_node(state))

+        workflow_type = route_workflow(state)

        if workflow_type == "geojson_direct":
+            final_state = geojson_direct_result_node(state)
            if output_format == "structured":
                yield {"type": "data", "content": final_state["result"]}
                yield {"type": "end", "content": ""}
            else:
                yield final_state["result"]
        else:
+            state = merge_state(state, retrieve_node(state))

            sources_collected = None
            accumulated_response = "" if output_format == "gradio" else None

+            async for partial_state in generate_node_streaming(state):
                if "result" in partial_state:
                    if output_format == "structured":
                        yield {"type": "data", "content": partial_state["result"]}
                    else:
                        accumulated_response += partial_state["result"]
                        yield accumulated_response

                if "sources" in partial_state:
                    sources_collected = partial_state["sources"]

+            # Format and yield sources
            if sources_collected:
                if output_format == "structured":
                    yield {"type": "sources", "content": sources_collected}
                else:
                    sources_text = "\n\n**Sources:**\n"
                    for i, source in enumerate(sources_collected, 1):
                        if isinstance(source, dict):
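
A minimal consumption sketch for process_query_streaming, assuming it is imported from this module; the query and file payload are invented for illustration:

import asyncio
from nodes import process_query_streaming  # assumed import path

async def demo():
    # "structured" mode yields {"type": ..., "content": ...} events;
    # "gradio" mode would instead yield the accumulated response text.
    async for event in process_query_streaming(
        query="Summarize the uploaded plots",                       # hypothetical query
        file_content=b'{"type": "FeatureCollection", "features": []}',
        filename="plots.geojson",                                   # hypothetical filename
        output_format="structured",
    ):
        print(event["type"], str(event["content"])[:80])

asyncio.run(demo())
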
app/utils.py CHANGED
@@ -2,28 +2,24 @@ import configparser
 import logging
 import os
 import ast
-import re
+import json
 from dotenv import load_dotenv
 from typing import Optional, Dict, Any, List
+from models import GraphState

-# Local .env file
 load_dotenv()

 logger = logging.getLogger(__name__)

 def getconfig(configfile_path: str):
-    """
-    Read the config file
-    Params
-    ----------------
-    configfile_path: file path of .cfg file
-    """
+    """Read the config file"""
     config = configparser.ConfigParser()
     try:
         config.read_file(open(configfile_path))
         return config
     except:
         logging.warning("config file not found")
+        return None


 def get_auth(provider: str) -> dict:
@@ -33,7 +29,7 @@ def get_auth(provider: str) -> dict:
         "qdrant": {"api_key": os.getenv("QDRANT_API_KEY")},
     }

-    provider = provider.lower() # Normalize to lowercase
+    provider = provider.lower()

     if provider not in auth_configs:
         raise ValueError(f"Unsupported provider: {provider}")
@@ -42,24 +38,22 @@ def get_auth(provider: str) -> dict:
     api_key = auth_config.get("api_key")

     if not api_key:
-        logging.warning(f"No API key found for provider '{provider}'. Please set the appropriate environment variable.")
+        logging.warning(f"No API key found for provider '{provider}'")
         auth_config["api_key"] = None

     return auth_config

-# File type detection
+
 def detect_file_type(filename: str, file_content: bytes = None) -> str:
     """Detect file type based on extension and content"""
     if not filename:
         return "unknown"

-    # Get file extension
     _, ext = os.path.splitext(filename.lower())

-    # Define file type mappings
     file_type_mappings = {
         '.geojson': 'geojson',
-        '.json': 'json', # Could be geojson, will check content
+        '.json': 'json',
         '.pdf': 'text',
         '.docx': 'text',
         '.doc': 'text',
@@ -75,29 +69,28 @@ def detect_file_type(filename: str, file_content: bytes = None) -> str:
     # For JSON files, check if it's actually GeoJSON
     if detected_type == 'json' and file_content:
         try:
-            import json
             content_str = file_content.decode('utf-8')
             data = json.loads(content_str)
-            # Check if it has GeoJSON structure
-            if isinstance(data, dict) and ('type' in data and data.get('type') == 'FeatureCollection'):
+            if isinstance(data, dict) and data.get('type') == 'FeatureCollection':
                 detected_type = 'geojson'
-            elif isinstance(data, dict) and ('type' in data and data.get('type') in ['Feature', 'Point', 'LineString', 'Polygon', 'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection']):
+            elif isinstance(data, dict) and data.get('type') in [
+                'Feature', 'Point', 'LineString', 'Polygon',
+                'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection'
+            ]:
                 detected_type = 'geojson'
         except:
-            pass # Keep as json if parsing fails
+            pass

     logger.info(f"Detected file type: {detected_type} for file: {filename}")
     return detected_type

-# Helper function to convert retrieval context to expected format
+
 def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
     """Convert string context to list format expected by generator"""
     try:
-        # Try to parse as list first
         if context.startswith('['):
             return ast.literal_eval(context)
         else:
-            # If it's a string, wrap it in a simple format
             return [{
                 "answer": context,
                 "answer_metadata": {
@@ -108,7 +101,6 @@ def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
                 }
             }]
     except:
-        # Fallback: simple string wrapping
         return [{
             "answer": context,
             "answer_metadata": {
@@ -117,4 +109,9 @@ def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
                 "year": "Unknown",
                 "source": "Retriever"
             }
-    }]
+    }]
+
+
+def merge_state(base_state: GraphState, updates: dict) -> GraphState:
+    """Helper to merge node updates into base state"""
+    return {**base_state, **updates}
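
Taken together, the utils helpers behave roughly as follows. This is a usage sketch, not part of the commit:

from utils import detect_file_type, convert_context_to_list, merge_state  # assumed import path

# Extension-based detection, with content sniffing for .json files
print(detect_file_type("area.geojson"))                                 # geojson
print(detect_file_type("data.json", b'{"type": "FeatureCollection"}'))  # geojson
print(detect_file_type("report.pdf"))                                   # text

# A plain string is wrapped into the list-of-dicts shape the generator expects
ctx = convert_context_to_list("Some retrieved passage")
print(ctx[0]["answer_metadata"]["source"])  # Retriever

# merge_state is a shallow dict merge; node outputs overwrite base keys
print(merge_state({"query": "q", "context": ""}, {"context": "filled"}))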