from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from typing import Dict, List, Optional, Union, Any
from pydantic import BaseModel, Field
from datetime import datetime
import logging
import json
import os
from dotenv import load_dotenv
from dify_client_python.dify_client import models
from sse_starlette.sse import EventSourceResponse
import httpx
from json_parser import SSEParser
from logger_config import setup_logger
from response_formatter import ResponseFormatter
import traceback

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class AgentOutput(BaseModel):
    """Structured output from agent processing"""
    thought_content: str
    observation: Optional[str]
    tool_outputs: List[Dict]
    citations: List[Dict]
    metadata: Dict
    raw_response: str


class AgentRequest(BaseModel):
    """Enhanced request model with additional parameters"""
    query: str
    conversation_id: Optional[str] = None
    stream: bool = True
    inputs: Dict = {}
    files: List = []
    user: str = "default_user"
    response_mode: str = "streaming"
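
# Illustrative request body accepted by AgentRequest (values below are
# documentation placeholders, not taken from a real deployment):
# {
#     "query": "What does the knowledge base say about onboarding?",
#     "conversation_id": null,
#     "stream": true,
#     "inputs": {},
#     "files": [],
#     "user": "default_user",
#     "response_mode": "streaming"
# }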


class AgentProcessor:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.api_base = "https://rag-engine.go-yamamoto.com/v1"
        self.formatter = ResponseFormatter()
        self.client = httpx.AsyncClient(timeout=60.0)
        self.logger = setup_logger("agent_processor")

    async def log_request_details(
        self,
        request: AgentRequest,
        start_time: datetime
    ) -> None:
        """Log detailed request information"""
        self.logger.debug(
            "Request details:\n"
            f"Query: {request.query}\n"
            f"User: {request.user}\n"
            f"Conversation ID: {request.conversation_id}\n"
            f"Stream mode: {request.stream}\n"
            f"Start time: {start_time}\n"
            f"Inputs: {request.inputs}\n"
            f"Files: {len(request.files)} files attached"
        )

    async def log_error(
        self,
        error: Exception,
        context: Optional[Dict] = None
    ) -> None:
        """Log detailed error information"""
        error_msg = (
            f"Error type: {type(error).__name__}\n"
            f"Error message: {str(error)}\n"
            f"Stack trace:\n{traceback.format_exc()}\n"
        )
        if context:
            error_msg += f"Context:\n{json.dumps(context, indent=2)}"
        self.logger.error(error_msg)

    async def cleanup(self):
        """Cleanup method to properly close client"""
        await self.client.aclose()

    async def process_stream(self, request: AgentRequest):
        start_time = datetime.now()
        await self.log_request_details(request, start_time)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "Accept": "text/event-stream"
        }
        chat_request = {
            "query": request.query,
            "inputs": request.inputs,
            "response_mode": "streaming" if request.stream else "blocking",
            "user": request.user,
            "conversation_id": request.conversation_id,
            "files": request.files
        }

        async def event_generator():
            parser = SSEParser()
            citations = []
            metadata = {}
            try:
                async with self.client.stream(
                    "POST",
                    f"{self.api_base}/chat-messages",
                    headers=headers,
                    json=chat_request
                ) as response:
                    self.logger.debug(
                        f"Stream connection established\n"
                        f"Status: {response.status_code}\n"
                        f"Headers: {dict(response.headers)}"
                    )

                    buffer = ""
                    async for line in response.aiter_lines():
                        if not line.strip():
                            continue
                        self.logger.debug(f"Raw SSE line: {line}")

                        # Log terminal-friendly output and capture citations and
                        # metadata from the upstream "message_end" event.
                        if "data:" in line:
                            try:
                                data = line.split("data:", 1)[1].strip()
                                parsed = json.loads(data)
                                if parsed.get("event") == "message_end":
                                    citations = parsed.get("retriever_resources", [])
                                    metadata = parsed.get("metadata", {})
                                    self.logger.debug(
                                        f"Message end event:\n"
                                        f"Citations: {citations}\n"
                                        f"Metadata: {metadata}"
                                    )
                                formatted = self.format_terminal_output(
                                    parsed,
                                    citations=citations,
                                    metadata=metadata
                                )
                                if formatted:
                                    self.logger.info(formatted)
                            except Exception as e:
                                await self.log_error(
                                    e,
                                    {"line": line, "event": "parse_data"}
                                )

                        # Accumulate lines until a complete event is buffered, then
                        # forward the cleaned XML to the client as an SSE frame.
                        buffer += line + "\n"
                        if line.startswith("data:") or buffer.strip().endswith("}"):
                            try:
                                processed_response = parser.parse_sse_event(buffer)
                                if processed_response and isinstance(processed_response, dict):
                                    cleaned_response = self.clean_response(processed_response)
                                    if cleaned_response:
                                        xml_content = cleaned_response.get("content", "")
                                        yield f"data: {xml_content}\n\n"
                            except Exception as parse_error:
                                await self.log_error(
                                    parse_error,
                                    {"buffer": buffer, "event": "process_buffer"}
                                )
                                error_xml = (
                                    f"<agent_response>"
                                    f"<error>{str(parse_error)}</error>"
                                    f"</agent_response>"
                                )
                                yield f"data: {error_xml}\n\n"
                            finally:
                                buffer = ""
            except httpx.ConnectError as e:
                await self.log_error(e, {"event": "connection_error"})
                error_xml = (
                    f"<agent_response>"
                    f"<error>Connection error: {str(e)}</error>"
                    f"</agent_response>"
                )
                yield f"data: {error_xml}\n\n"
            except Exception as e:
                await self.log_error(e, {"event": "stream_error"})
                error_xml = (
                    f"<agent_response>"
                    f"<error>Streaming error: {str(e)}</error>"
                    f"</agent_response>"
                )
                yield f"data: {error_xml}\n\n"
            finally:
                end_time = datetime.now()
                duration = (end_time - start_time).total_seconds()
                self.logger.info(f"Request completed in {duration:.2f} seconds")

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*"
            }
        )
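
    # Note on the wire format produced by event_generator above: every chunk the
    # client receives follows the SSE convention "data: <payload>\n\n", where the
    # payload is the XML string built by ResponseFormatter, or an
    # "<agent_response><error>...</error></agent_response>" block when parsing or
    # streaming fails.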

    def format_terminal_output(
        self,
        response: Dict,
        citations: List[Dict] = None,
        metadata: Dict = None
    ) -> Optional[str]:
        """Format response for terminal output"""
        event_type = response.get("event")

        if event_type == "agent_thought":
            thought = response.get("thought", "")
            observation = response.get("observation", "")
            terminal_output, _ = self.formatter.format_thought(
                thought,
                observation,
                citations=citations,
                metadata=metadata
            )
            return terminal_output
        elif event_type == "agent_message":
            message = response.get("answer", "")
            terminal_output, _ = self.formatter.format_message(message)
            return terminal_output
        elif event_type == "error":
            error = response.get("error", "Unknown error")
            terminal_output, _ = self.formatter.format_error(error)
            return terminal_output
        return None

    def clean_response(self, response: Dict) -> Optional[Dict]:
        """Clean and transform the response for frontend consumption"""
        try:
            event_type = response.get("event")
            if not event_type:
                return None

            # Handle different event types
            if event_type == "agent_thought":
                thought = response.get("thought", "")
                observation = response.get("observation", "")
                _, xml_output = self.formatter.format_thought(thought, observation)
                return {
                    "type": "thought",
                    "content": xml_output
                }
            elif event_type == "agent_message":
                message = response.get("answer", "")
                _, xml_output = self.formatter.format_message(message)
                return {
                    "type": "message",
                    "content": xml_output
                }
            elif event_type == "error":
                error = response.get("error", "Unknown error")
                _, xml_output = self.formatter.format_error(error)
                return {
                    "type": "error",
                    "content": xml_output
                }
            return None
        except Exception as e:
            logger.error(f"Error cleaning response: {str(e)}")
            return None


# Initialize FastAPI app
app = FastAPI()
agent_processor = None

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# The decorators below are required for FastAPI to register these handlers;
# without them, agent_processor is never initialized or cleaned up.
@app.on_event("startup")
async def startup_event():
    global agent_processor
    api_key = os.getenv("DIFY_API_KEY", "app-kVHTrZzEmFXEBfyXOi4rro7M")
    agent_processor = AgentProcessor(api_key=api_key)


@app.on_event("shutdown")
async def shutdown_event():
    global agent_processor
    if agent_processor:
        await agent_processor.cleanup()


# The "/agent" route path is an assumption; adjust it to whatever path the
# frontend client actually calls.
@app.post("/agent")
async def process_agent_request(request: AgentRequest):
    try:
        logger.info(f"Processing agent request: {request.query}")
        return await agent_processor.process_stream(request)
    except Exception as e:
        logger.error(f"Error in agent request processing: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    try:
        response = await call_next(request)
        return response
    except Exception as e:
        logger.error(f"Unhandled error: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error occurred"}
        )


# Add host and port parameters to the launch
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=8224,
        reload=True
    )
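
# Example invocation, a sketch only: the "/agent" path is the assumption noted
# above, and DIFY_API_KEY must be valid for the upstream Dify instance.
#
#   curl -N -X POST http://localhost:8224/agent \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hello", "user": "default_user", "stream": true}'
#
# The -N/--no-buffer flag makes curl print SSE chunks as they arrive.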