# Gradio UI not currently working.
import gradio as gr
from fastapi import FastAPI
from langserve import add_routes
from langgraph.graph import StateGraph, START, END
from typing import Optional, Dict, Any
from typing_extensions import TypedDict
from pydantic import BaseModel
from gradio_client import Client
import uvicorn
import os
from datetime import datetime
import logging
from contextlib import asynccontextmanager
import threading
from langchain_core.runnables import RunnableLambda

from utils import getconfig

config = getconfig("params.cfg")
RETRIEVER = config.get("retriever", "RETRIEVER")
GENERATOR = config.get("generator", "GENERATOR")
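# Assumed layout of params.cfg for the two lookups above. A configparser-style
# get(section, option) signature is assumed, so the section names and the
# placeholder values below are illustrative, not the real endpoints:
#
#   [retriever]
#   RETRIEVER = org/retriever-space-or-url
#
#   [generator]
#   GENERATOR = org/generator-space-or-url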

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# Models
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
    metadata: Optional[Dict[str, Any]]

class ChatFedInput(TypedDict):
    query: str
    reports_filter: Optional[str]
    sources_filter: Optional[str] 
    subtype_filter: Optional[str]
    year_filter: Optional[str]
    session_id: Optional[str]
    user_id: Optional[str]

class ChatFedOutput(TypedDict):
    result: str
    metadata: Dict[str, Any]

class ChatUIInput(BaseModel):
    text: str

# Graph node functions
def retrieve_node(state: GraphState) -> GraphState:
    start_time = datetime.now()
    logger.info(f"Retrieval: {state['query'][:50]}...")
    
    try:
        client = Client(RETRIEVER)
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter", ""),
            sources_filter=state.get("sources_filter", ""),
            subtype_filter=state.get("subtype_filter", ""),
            year_filter=state.get("year_filter", ""),
            api_name="/retrieve"
        )
        
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata") or {}  # tolerate an explicit None (field is Optional)
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True
        })
        
        return {"context": context, "metadata": metadata}
        
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        
        metadata = state.get("metadata") or {}
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e)
        })
        return {"context": "", "metadata": metadata}
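# Note: the gradio_client calls in retrieve_node (and generate_node below)
# assume the upstream Spaces expose "/retrieve" and "/generate" endpoints
# with matching keyword arguments; api_name selects the named Gradio API route.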

def generate_node(state: GraphState) -> GraphState:
    start_time = datetime.now()
    logger.info(f"Generation: {state['query'][:50]}...")
    
    try:
        client = Client(GENERATOR)
        result = client.predict(
            query=state["query"],
            context=state["context"],
            api_name="/generate"
        )
        
        duration = (datetime.now() - start_time).total_seconds()
        metadata = state.get("metadata") or {}
        metadata.update({
            "generation_duration": duration,
            "result_length": len(result) if result else 0,
            "generation_success": True
        })
        
        return {"result": result, "metadata": metadata}
        
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Generation failed: {str(e)}")
        
        metadata = state.get("metadata") or {}
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e)
        })
        return {"result": f"Error: {str(e)}", "metadata": metadata}

# Build and compile the retrieve -> generate graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
compiled_graph = workflow.compile()
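# Minimal invocation sketch for the compiled graph (the query is illustrative;
# process_query_core below builds the full initial state for real requests):
#
#   _state = compiled_graph.invoke({
#       "query": "What were the key findings in 2024?",
#       "context": "", "result": "",
#       "reports_filter": "", "sources_filter": "",
#       "subtype_filter": "", "year_filter": "",
#       "metadata": {},
#   })
#   print(_state["result"])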

def process_query_core(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    session_id: Optional[str] = None,
    user_id: Optional[str] = None,
    return_metadata: bool = False
):
    start_time = datetime.now()
    if not session_id:
        session_id = f"session_{start_time.strftime('%Y%m%d_%H%M%S')}"
    
    try:
        initial_state = {
            "query": query,
            "context": "",
            "result": "",
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "metadata": {
                "session_id": session_id,
                "user_id": user_id,
                "start_time": start_time.isoformat()
            }
        }
        
        final_state = compiled_graph.invoke(initial_state)
        total_duration = (datetime.now() - start_time).total_seconds()
        
        final_metadata = final_state.get("metadata", {})
        final_metadata.update({
            "total_duration": total_duration,
            "end_time": datetime.now().isoformat(),
            "pipeline_success": True
        })
        
        if return_metadata:
            return {"result": final_state["result"], "metadata": final_metadata}
        else:
            return final_state["result"]
        
    except Exception as e:
        total_duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Pipeline failed: {str(e)}")
        
        if return_metadata:
            error_metadata = {
                "session_id": session_id,
                "total_duration": total_duration,
                "pipeline_success": False,
                "error": str(e)
            }
            return {"result": f"Error: {str(e)}", "metadata": error_metadata}
        else:
            return f"Error: {str(e)}"

def process_query_gradio(query: str, reports_filter: str = "", sources_filter: str = "", 
                        subtype_filter: str = "", year_filter: str = "") -> str:
    return process_query_core(
        query=query,
        reports_filter=reports_filter,
        sources_filter=sources_filter,
        subtype_filter=subtype_filter,
        year_filter=year_filter,
        session_id=f"gradio_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        return_metadata=False
    )

def chatui_adapter(data: Any) -> str:
    try:
        # Handle both dict and Pydantic model input
        if hasattr(data, 'text'):
            text = data.text
        elif isinstance(data, dict) and 'text' in data:
            text = data['text']
        else:
            logger.error(f"Unexpected input structure: {data}")
            return "Error: Invalid input format. Expected 'text' field."
        
        result = process_query_core(
            query=text,
            session_id=f"chatui_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            return_metadata=False
        )
        return result
    except Exception as e:
        logger.error(f"ChatUI error: {str(e)}")
        return f"Error: {str(e)}"

def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
    result = process_query_core(
        query=input_data["query"],
        reports_filter=input_data.get("reports_filter", ""),
        sources_filter=input_data.get("sources_filter", ""),
        subtype_filter=input_data.get("subtype_filter", ""),
        year_filter=input_data.get("year_filter", ""),
        session_id=input_data.get("session_id"),
        user_id=input_data.get("user_id"),
        return_metadata=True
    )
    return ChatFedOutput(result=result["result"], metadata=result["metadata"])

# Not currently working: Hugging Face Spaces exposes only a single external port,
# so the Gradio UI cannot be served alongside the FastAPI server.
def create_gradio_interface():
    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("MCP endpoints available at `/gradio_api/mcp/sse`")
        
        with gr.Row():
            with gr.Column():
                query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
                reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
                sources_filter_input = gr.Textbox(label="Sources Filter", placeholder="e.g., internal")
                subtype_filter_input = gr.Textbox(label="Subtype Filter", placeholder="e.g., financial")
                year_filter_input = gr.Textbox(label="Year Filter", placeholder="e.g., 2024")
                submit_btn = gr.Button("Submit", variant="primary")
            
            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)
        
        submit_btn.click(
            fn=process_query_gradio,
            inputs=[query_input, reports_filter_input, sources_filter_input, subtype_filter_input, year_filter_input],
            outputs=output
        )
    
    return demo

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("ChatFed Orchestrator starting up...")
    yield
    logger.info("Orchestrator shutting down...")

app = FastAPI(
    title="ChatFed Orchestrator",
    version="1.0.0",
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None
)

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/")
async def root():
    return {
        "message": "ChatFed Orchestrator API",
        "endpoints": {
            "health": "/health",
            "chatfed": "/chatfed",
            "chatfed-ui-stream": "/chatfed-ui-stream"
        }
    }

# LangServe routes (these are the main endpoints)
add_routes(
    app,
    RunnableLambda(process_query_langserve),
    path="/chatfed",
    input_type=ChatFedInput,
    output_type=ChatFedOutput
)
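
# Example request against the LangServe route above (LangServe also exposes
# /batch and /stream under the same path; host, port, and query are illustrative):
#
#   curl -X POST http://localhost:7860/chatfed/invoke \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"query": "What were the key findings in 2024?"}}'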

add_routes(
    app,
    RunnableLambda(chatui_adapter),
    path="/chatfed-ui-stream",
    input_type=ChatUIInput,
    output_type=str,
    enable_feedback_endpoint=True,
    enable_public_trace_link_endpoint=True,
)

def run_gradio_server():
    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        mcp_server=True,   
        show_error=True,
        share=False,
        quiet=True
    )

if __name__ == "__main__":
    gradio_thread = threading.Thread(target=run_gradio_server, daemon=True)
    gradio_thread.start()
    logger.info("Gradio MCP server started on port 7861")
    
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    
    logger.info(f"Starting FastAPI server on {host}:{port}")
    
    uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)