# CHATFED_ORCHESTRATOR
#
# Orchestrates three remote Gradio services — ingestor -> retriever -> generator —
# through a LangGraph state machine, exposed via FastAPI/LangServe routes and a
# Gradio UI (which also serves as an MCP endpoint).
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any, List
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client, file
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
import tempfile

from utils import getconfig

config = getconfig("params.cfg")

# Remote service endpoints (overridable via params.cfg).
RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
# BUGFIX: parse once as int here. The original kept the raw config string at
# module level (unused) and re-read/re-parsed the config inside generate_node
# on every call.
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
class GraphState(TypedDict):
    """State carried between LangGraph nodes for one pipeline run."""
    query: str
    context: str             # context returned by the retriever
    ingestor_context: str    # context extracted from an uploaded file
    result: str              # final generated answer
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
    file_content: Optional[bytes]
    filename: Optional[str]
    metadata: Optional[Dict[str, Any]]


class ChatFedInput(TypedDict):
    """Input schema for the /chatfed LangServe route."""
    query: str
    reports_filter: Optional[str]
    sources_filter: Optional[str]
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]
    file_content: Optional[bytes]
    filename: Optional[str]


class ChatFedOutput(TypedDict):
    """Output schema: generated answer plus per-stage timing/diagnostics."""
    result: str
    metadata: Dict[str, Any]


class ChatUIInput(BaseModel):
    """Input schema for /chatfed-ui-stream (ChatUI posts {'text': ...})."""
    text: str


# ---------------------------------------------------------------------------
# Pipeline nodes
# ---------------------------------------------------------------------------
def ingest_node(state: GraphState) -> GraphState:
    """Extract context from an uploaded file via the remote ingestor service.

    Skips (returning empty ingestor_context) when no file is attached.
    Never raises: failures are recorded in metadata so the pipeline can
    continue with retrieval-only context.

    NOTE(review): the original function contained a second, unreachable
    copy of this entire try/except body after the return paths; removed.
    """
    start_time = datetime.now()

    # No file attached -> nothing to ingest.
    if not state.get("file_content") or not state.get("filename"):
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": state.get("metadata", {})}

    logger.info(f"Ingesting file: {state['filename']}")
    try:
        client = Client(INGESTOR)

        # The remote API needs a real file on disk; keep the original suffix
        # so the ingestor can detect the document type.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
            tmp_file.write(state["file_content"])
            tmp_file_path = tmp_file.name

        try:
            # gradio_client.file() wraps the path in the format the API expects.
            ingestor_context = client.predict(
                file(tmp_file_path),
                api_name="/ingest"
            )
            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")

            # The ingestor signals failure in-band with an "Error:" string.
            if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                raise Exception(ingestor_context)
        finally:
            # Clean up temporary file
            os.unlink(tmp_file_path)

        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
            "ingestion_success": True
        })
        return {"ingestor_context": ingestor_context, "metadata": metadata}

    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e)
        })
        return {"ingestor_context": "", "metadata": metadata}


def retrieve_node(state: GraphState) -> GraphState:
    """Fetch retrieval context for the query from the remote retriever.

    Never raises: failures yield an empty context plus error metadata.
    """
    start_time = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}...")
    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True
        })
        return {"context": context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}


def generate_node(state: GraphState) -> GraphState:
    """Generate the final answer from the combined contexts.

    Uploaded-document context is prioritized: when both sources are present,
    each gets half of MAX_CONTEXT_CHARS; otherwise the single source gets the
    full budget. On failure the error string is returned as the result so the
    caller always receives something displayable.
    """
    start_time = datetime.now()
    logger.info(f"Generation: {state['query'][:50]}...")
    try:
        retrieved_context = state.get("context", "")
        ingestor_context = state.get("ingestor_context", "")

        # Limit context size to prevent token overflow at the generator.
        combined_context = ""
        if ingestor_context and retrieved_context:
            # Split the budget between the two sources (slicing is a no-op
            # when the text is already within budget).
            ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS // 2]
            retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS // 2]
            combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
        elif ingestor_context:
            combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:MAX_CONTEXT_CHARS]}"
        elif retrieved_context:
            combined_context = retrieved_context[:MAX_CONTEXT_CHARS]

        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=combined_context,
            api_name="/generate"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "result_length": len(result) if result else 0,
            "combined_context_length": len(combined_context),
            "generation_success": True
        })
        return {"result": result, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}


# ---------------------------------------------------------------------------
# Graph: ingest -> retrieve -> generate
# ---------------------------------------------------------------------------
workflow = StateGraph(GraphState)
workflow.add_node("ingest", ingest_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "ingest")
workflow.add_edge("ingest", "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()


def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    return_metadata: bool = False
):
    """Run the full pipeline for one query.

    Returns the result string, or {"result": ..., "metadata": ...} when
    return_metadata is True. Never raises: pipeline failures are returned
    as an "Error: ..." result.
    """
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
    try:
        initial_state = {
            "query": query,
            "context": "",
            "ingestor_context": "",
            "result": "",
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "file_content": file_content,
            "filename": filename,
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat(),
                "has_file_attachment": file_content is not None
            }
        }
        final_state = compiled_graph.invoke(initial_state)

        total_duration = (datetime.now() - start_time).total_seconds()
        final_metadata = final_state.get("metadata", {})
        final_metadata.update({
            "total_duration": total_duration,
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        else:
            return final_state["result"]
    except Exception as e:
        total_duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Pipeline failed: {str(e)}")
        if return_metadata:
            error_metadata = {
                "session_id": session_id,
                "total_duration": total_duration,
                "pipeline_success": False,
                "error": str(e)
            }
            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
        else:
            return f"Error: {str(e)}"


def process_query_gradio(query: str, file_upload, reports_filter: str = "",
                         sources_filter: str = "", subtype_filter: str = "",
                         year_filter: str = "") -> str:
    """Gradio interface function with file upload support."""
    file_content = None
    filename = None
    if file_upload is not None:
        try:
            with open(file_upload.name, 'rb') as f:
                file_content = f.read()
            filename = os.path.basename(file_upload.name)
            # BUGFIX: the original log line had a garbled "(unknown)"
            # placeholder where the filename belongs.
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except Exception as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            return f"Error reading file: {str(e)}"

    return process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        file_content=file_content,
        filename=filename,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )


def chatui_adapter(data) -> str:
    """Adapter for ChatUI: accepts a dict or Pydantic model with a 'text' field."""
    try:
        # Handle both dict and Pydantic model input.
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."

        result = process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
        )
        return result
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        return f"Error: {str(e)}"


def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    """LangServe entry point: forwards the typed input to the core pipeline."""
    result = process_query_core(
        query=input_data["query"],
        reports_filter=input_data.get("reports_filter", ""),
        sources_filter=input_data.get("sources_filter", ""),
        subtype_filter=input_data.get("subtype_filter", ""),
        year_filter=input_data.get("year_filter", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        file_content=input_data.get("file_content"),
        filename=input_data.get("filename"),
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])


def create_gradio_interface():
    """Build the Gradio UI (query box, optional file upload, filter accordion)."""
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                file_input = gr.File(label="Upload Document (PDF/DOCX)", file_types=[".pdf", ".docx"])
                with gr.Accordion("Filters (Optional)", open=False):
                    reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                    sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                    subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                    year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, file_input, reports_filter_input, sources_filter_input,
                    subtype_filter_input, year_filter_input],
            outputs=output
        )
    return demo


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log startup/shutdown only."""
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")


app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream",
            "chatfed-with-file": "/chatfed-with-file"
        }
    }


# Additional endpoint for file uploads via API.
# BUGFIX: this route/handler pair was defined twice verbatim in the original;
# the duplicate registration has been removed.
@app.post("/chatfed-with-file")
async def chatfed_with_file(
    query: str = Form(...),
    file: Optional[UploadFile] = File(None),
    reports_filter: Optional[str] = Form(""),
    sources_filter: Optional[str] = Form(""),
    subtype_filter: Optional[str] = Form(""),
    year_filter: Optional[str] = Form(""),
    session_id: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None)
):
    """Endpoint for queries with optional file attachments"""
    file_content = None
    filename = None
    if file:
        file_content = await file.read()
        filename = file.filename

    result = process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        file_content=file_content,
        filename=filename,
        session_id=session_id,
        user_id=user_id,
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])


# LangServe routes (these are the main endpoints)
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)

add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)


def run_gradio_server():
    """Launch the Gradio UI (with MCP server) on port 7861; blocks its thread."""
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        mcp_server=True,
        show_error=True,
        share=False,
        quiet=True
    )


if __name__ == "__main__":
    # Gradio runs in a daemon thread; uvicorn/FastAPI owns the main thread.
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)