# CHATFED_ORCHESTRATOR
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any, List
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client, file
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda
import tempfile

from utils import getconfig

config = getconfig("params.cfg")

RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
MAX_CONTEXT_CHARS = int(config.get("general", "MAX_CONTEXT_CHARS"))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
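
# The params.cfg layout sketched below is inferred from the config.get(...)
# calls above (configparser-style sections and keys). The URLs are the same
# fallbacks used in this file; the MAX_CONTEXT_CHARS value is only an
# assumed example:
#
#   [retriever]
#   RETRIEVER = https://giz-chatfed-retriever.hf.space
#
#   [generator]
#   GENERATOR = https://giz-chatfed-generator.hf.space
#
#   [ingestor]
#   INGESTOR = https://mtyrrell-chatfed-ingestor.hf.space
#
#   [general]
#   MAX_CONTEXT_CHARS = 20000
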
# Models
class GraphState(TypedDict):
    query: str
    context: str
    ingestor_context: str
    result: str
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
    file_content: Optional[bytes]
    filename: Optional[str]
    metadata: Optional[Dict[str, Any]]

class ChatFedInput(TypedDict):
    query: str
    reports_filter: Optional[str]
    sources_filter: Optional[str]
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]
    file_content: Optional[bytes]
    filename: Optional[str]

class ChatFedOutput(TypedDict):
    result: str
    metadata: Dict[str, Any]

class ChatUIInput(BaseModel):
    text: str

# Module functions
def ingest_node(state: GraphState) -> GraphState:
    """Process file through ingestor if file is provided."""
    start_time = datetime.now()

    # If no file provided, skip this step
    if not state.get("file_content") or not state.get("filename"):
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": state.get("metadata", {})}

    logger.info(f"Ingesting file: {state['filename']}")
    try:
        client = Client(INGESTOR)
        # Write the upload to a temporary file so gradio_client can send it
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
            tmp_file.write(state["file_content"])
            tmp_file_path = tmp_file.name
        try:
            # Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
            ingestor_context = client.predict(
                file(tmp_file_path),
                api_name="/ingest"
            )
            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
            # Surface error strings returned by the ingestor as exceptions
            if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
                raise Exception(ingestor_context)
        finally:
            # Clean up temporary file
            os.unlink(tmp_file_path)

        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
            "ingestion_success": True
        })
        return {"ingestor_context": ingestor_context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e)
        })
        return {"ingestor_context": "", "metadata": metadata}

def retrieve_node(state: GraphState) -> GraphState:
    """Fetch supporting context from the retriever service."""
    start_time = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}...")
    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True
        })
        return {"context": context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}

def generate_node(state: GraphState) -> GraphState:
    """Combine retrieved and ingested context, then call the generator service."""
    start_time = datetime.now()
    logger.info(f"Generation: {state['query'][:50]}...")
    try:
        # Combine retriever context with ingestor context
        retrieved_context = state.get("context", "")
        ingestor_context = state.get("ingestor_context", "")

        # Limit context size to prevent token overflow (MAX_CONTEXT_CHARS is loaded at module level)
        combined_context = ""
        if ingestor_context and retrieved_context:
            # Split the budget evenly between the two sources, truncating as needed
            ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS // 2]
            retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS // 2]
            combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
        elif ingestor_context:
            combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:MAX_CONTEXT_CHARS]}"
        elif retrieved_context:
            combined_context = retrieved_context[:MAX_CONTEXT_CHARS]

        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=combined_context,
            api_name="/generate"
        )
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "result_length": len(result) if result else 0,
            "combined_context_length": len(combined_context),
            "generation_success": True
        })
        return {"result": result, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        metadata = state.get("metadata", {})
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}

# Assemble the pipeline graph: ingest -> retrieve -> generate
workflow = StateGraph(GraphState)
workflow.add_node("ingest", ingest_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "ingest")
workflow.add_edge("ingest", "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()
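
# A minimal sketch of invoking the compiled graph directly (values are
# hypothetical; process_query_core below builds the same initial state):
#
#   state = compiled_graph.invoke({
#       "query": "Summarize the 2024 findings",
#       "context": "", "ingestor_context": "", "result": "",
#       "reports_filter": "", "sources_filter": "", "subtype_filter": "",
#       "year_filter": "2024", "file_content": None, "filename": None,
#       "metadata": {},
#   })
#   print(state["result"])
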
def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    return_metadata: bool = False
):
    """Run the full ingest/retrieve/generate pipeline for a single query."""
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
    try:
        initial_state = {
            "query": query,
            "context": "",
            "ingestor_context": "",
            "result": "",
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "file_content": file_content,
            "filename": filename,
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat(),
                "has_file_attachment": file_content is not None
            }
        }
        final_state = compiled_graph.invoke(initial_state)
        total_duration = (datetime.now() - start_time).total_seconds()
        final_metadata = final_state.get("metadata", {})
        final_metadata.update({
            "total_duration": total_duration,
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        else:
            return final_state["result"]
    except Exception as e:
        total_duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Pipeline failed: {str(e)}")
        if return_metadata:
            error_metadata = {
                "session_id": session_id,
                "total_duration": total_duration,
                "pipeline_success": False,
                "error": str(e)
            }
            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
        else:
            return f"Error: {str(e)}"
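
# Example programmatic use (hypothetical query and filter values; with
# return_metadata=True the call returns a dict with "result" and "metadata"
# keys instead of the bare answer string):
#
#   answer = process_query_core(
#       query="What does the annual report say about emissions?",
#       year_filter="2024",
#   )
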
def process_query_gradio(query: str, file_upload, reports_filter: str = "", sources_filter: str = "",
                         subtype_filter: str = "", year_filter: str = "") -> str:
    """Gradio interface function with file upload support."""
    file_content = None
    filename = None
    if file_upload is not None:
        try:
            with open(file_upload.name, 'rb') as f:
                file_content = f.read()
            filename = os.path.basename(file_upload.name)
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except Exception as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            return f"Error reading file: {str(e)}"
    return process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        file_content=file_content,
        filename=filename,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )

def chatui_adapter(data) -> str:
    """Adapter for ChatUI payloads: accepts a dict or Pydantic model with a 'text' field."""
    try:
        # Handle both dict and Pydantic model input
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."
        result = process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
        )
        return result
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        return f"Error: {str(e)}"

def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    result = process_query_core(
        query=input_data["query"],
        reports_filter=input_data.get("reports_filter", ""),
        sources_filter=input_data.get("sources_filter", ""),
        subtype_filter=input_data.get("subtype_filter", ""),
        year_filter=input_data.get("year_filter", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        file_content=input_data.get("file_content"),
        filename=input_data.get("filename"),
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])

def create_gradio_interface():
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                file_input = gr.File(label="Upload Document (PDF/DOCX)", file_types=[".pdf", ".docx"])
                with gr.Accordion("Filters (Optional)", open=False):
                    reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                    sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                    subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                    year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=15, show_copy_button=True)
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, file_input, reports_filter_input, sources_filter_input,
                    subtype_filter_input, year_filter_input],
            outputs=output
        )
    return demo

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")

app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream",
            "chatfed-with-file": "/chatfed-with-file"
        }
    }

# Additional endpoint for file uploads via API
@app.post("/chatfed-with-file")
async def chatfed_with_file(
    query: str = Form(...),
    file: Optional[UploadFile] = File(None),
    reports_filter: Optional[str] = Form(""),
    sources_filter: Optional[str] = Form(""),
    subtype_filter: Optional[str] = Form(""),
    year_filter: Optional[str] = Form(""),
    session_id: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None)
):
    """Endpoint for queries with optional file attachments."""
    file_content = None
    filename = None
    if file:
        file_content = await file.read()
        filename = file.filename
    result = process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        file_content=file_content,
        filename=filename,
        session_id=session_id,
        user_id=user_id,
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])
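
# Example call to the file-upload endpoint (assumes the server is running
# locally on the default port; report.pdf is a placeholder filename):
#
#   curl -X POST http://localhost:7860/chatfed-with-file \
#        -F "query=Summarize this document" \
#        -F "file=@report.pdf" \
#        -F "year_filter=2024"
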
# LangServe routes (these are the main endpoints)
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)

add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)
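
# LangServe exposes invoke/batch/stream variants under each path; a minimal
# sketch of calling the /chatfed route (hypothetical query and filter values):
#
#   curl -X POST http://localhost:7860/chatfed/invoke \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"query": "What changed in 2024?", "year_filter": "2024"}}'
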
def run_gradio_server():
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        mcp_server=True,
        show_error=True,
        share=False,
        quiet=True
    )

if __name__ == "__main__":
    # Run the Gradio MCP server in a background thread alongside FastAPI
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    logger.info(f"Starting FastAPI server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)