Spaces:
Running
Running
from fastapi import APIRouter, status, Depends, BackgroundTasks, HTTPException | |
from fastapi.responses import JSONResponse | |
from src.utils.logger import logger | |
from src.agents.role_play.func import create_agents | |
from pydantic import BaseModel, Field | |
from typing import List, Dict, Any, Optional | |
from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id | |
import json | |
import os | |
import uuid | |
from datetime import datetime | |
# Router object for the handlers in this module, configured with the "/ai"
# prefix and the "AI" tag.
router = APIRouter(
    prefix="/ai",
    tags=["AI"],
)
class RoleplayRequest(BaseModel):
    # Request body consumed by the `roleplay` handler: the user's message,
    # the session it belongs to, and the scenario the agent is built from.
    query: str = Field(
        ...,
        description="User's query for the AI agent",
    )
    session_id: str = Field(
        ...,
        description="Session ID for tracking user interactions",
    )
    scenario: Dict[str, Any] = Field(
        ...,
        description="The scenario for the roleplay",
    )
class SessionRequest(BaseModel):
    # Request body that identifies a single session; used by `get_messages`.
    session_id: str = Field(
        ...,
        description="Session ID to perform operations on",
    )
class CreateSessionRequest(BaseModel):
    # Request body for `create_new_session`; only a display name is required.
    name: str = Field(
        ...,
        description="Name for the new session",
    )
class UpdateSessionRequest(BaseModel):
    # Request body for `update_session`.
    # NOTE(review): `update_session` receives the session ID as its own
    # `session_id` argument and never reads this model's `session_id`; the
    # field is still required by validation.
    session_id: str = Field(
        ...,
        description="Session ID to update",
    )
    name: str = Field(
        ...,
        description="New name for the session",
    )
# Session management helper functions
# Sessions are persisted as a JSON list in this file. The path is relative, so
# it resolves against the process's current working directory.
SESSIONS_FILE: str = "sessions.json"
def load_sessions() -> List[Dict[str, Any]]:
    """Load the persisted sessions from ``SESSIONS_FILE``.

    Returns:
        The list stored in the file, or an empty list when the file does not
        exist, cannot be read or parsed, or does not contain a JSON list.
        This function never raises; failures other than a missing file are
        logged.
    """
    try:
        # Open directly instead of checking for existence first: the file may
        # disappear between the check and the open.
        with open(SESSIONS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No sessions have been saved yet; this is not an error.
        return []
    except Exception as e:
        logger.error(f"Error loading sessions: {str(e)}")
        return []

    # Callers append to the result and iterate it as a list of dicts; a file
    # holding any other JSON value would break them later with an unrelated
    # error, so reject it here.
    if not isinstance(data, list):
        logger.error(
            f"Error loading sessions: expected a JSON list, got {type(data).__name__}"
        )
        return []
    return data
def save_sessions(sessions: List[Dict[str, Any]]):
    """Persist ``sessions`` to ``SESSIONS_FILE`` as indented UTF-8 JSON.

    The data is written to a sibling temporary file first and then moved over
    the real file with ``os.replace``. A failure while serialising or writing
    therefore leaves the previous contents intact instead of a truncated or
    half-written file.

    Args:
        sessions: The complete session list to store. Values that ``json``
            cannot encode natively are stringified (``default=str``).

    Errors are logged and not raised (best-effort save), so callers are not
    told when persistence fails.
    """
    tmp_path = f"{SESSIONS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(sessions, f, ensure_ascii=False, indent=2, default=str)
        # Swap the fully written file into place in a single step.
        os.replace(tmp_path, SESSIONS_FILE)
    except Exception as e:
        logger.error(f"Error saving sessions: {str(e)}")
        # Best-effort cleanup of a partial temp file; it may not exist if the
        # open itself failed, which is fine.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def create_session(name: str) -> Dict[str, Any]:
    """Build a new session record, add it to the stored list, and return it.

    Args:
        name: Display name for the session.

    Returns:
        The new record: a random UUID4 string as ``id``, the given ``name``,
        the local creation time in ISO format, no last message, and a message
        count of zero.
    """
    new_session: Dict[str, Any] = dict(
        id=str(uuid.uuid4()),
        name=name,
        created_at=datetime.now().isoformat(),
        last_message=None,
        message_count=0,
    )
    # Persist the existing sessions with the new record at the end.
    save_sessions([*load_sessions(), new_session])
    return new_session
def get_session_by_id(session_id: str) -> Optional[Dict[str, Any]]:
    """Return the first stored session whose ``id`` equals ``session_id``.

    Returns ``None`` when no stored session has that ID.
    """
    for candidate in load_sessions():
        if candidate["id"] == session_id:
            return candidate
    return None
def update_session_last_message(session_id: str, message: str):
    """Record ``message`` as the latest message of a session.

    Sets ``last_message`` and increments ``message_count`` (treated as 0 when
    absent) on the first stored session whose ``id`` matches.

    Args:
        session_id: ID of the session to update.
        message: Text to store as the session's last message.

    The session file is rewritten only when a matching session was found,
    which matches how ``delete_session_by_id`` behaves. An unknown
    ``session_id`` is a no-op.
    """
    sessions = load_sessions()
    for session in sessions:
        # `.get` so a stored record without an "id" key is skipped rather than
        # aborting the update with a KeyError.
        if session.get("id") == session_id:
            session["last_message"] = message
            session["message_count"] = session.get("message_count", 0) + 1
            save_sessions(sessions)
            return
def delete_session_by_id(session_id: str) -> bool:
    """Remove every stored session whose ``id`` equals ``session_id``.

    Returns:
        ``True`` if at least one session was removed and the remaining list
        was saved; ``False`` if nothing matched, in which case the file is not
        rewritten.
    """
    stored = load_sessions()
    remaining = [entry for entry in stored if entry["id"] != session_id]
    # Filtering can only shrink the list, so equal lengths mean no match.
    if len(remaining) == len(stored):
        return False
    save_sessions(remaining)
    return True
async def list_scenarios():
    """Return whatever ``get_scenarios()`` provides as a JSON response."""
    scenarios = get_scenarios()
    return JSONResponse(content=scenarios)
async def roleplay(request: RoleplayRequest):
    """Send a message to the roleplay agent and return its reply.

    Builds an agent for ``request.scenario`` with ``create_agents``, invokes
    it with the user's query using ``request.session_id`` as the
    ``thread_id``, and records the query as the session's last message.

    Args:
        request: The query, session ID, and scenario for this turn.

    Returns:
        A JSON response whose body is the ``content`` of the last message in
        the agent's response.

    Raises:
        HTTPException: 400 when the scenario is empty; 500 when the agent
            call or the response handling fails.
    """
    scenario = request.scenario
    if not scenario:
        raise HTTPException(status_code=400, detail="Scenario not provided")

    # Same log-and-500 handling as the other handlers in this module, so a
    # failing agent call or an empty "messages" list is logged with its
    # session ID instead of surfacing as an unlogged bare 500.
    try:
        response = await create_agents(scenario).ainvoke(
            {
                "messages": [request.query],
            },
            {"configurable": {"thread_id": request.session_id}},
        )
        # Update session with last message
        update_session_last_message(request.session_id, request.query)
        return JSONResponse(content=response["messages"][-1].content)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            f"Error in roleplay for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to process roleplay message: {str(e)}"
        )
def _message_to_dict(msg: Any) -> Dict[str, Any]:
    """Convert one stored message object into a role/content/timestamp dict."""
    if not hasattr(msg, "content"):
        # Fallback for unexpected message format
        return {"role": "unknown", "content": str(msg), "timestamp": None}

    if hasattr(msg, "type"):
        role = msg.type
    else:
        # No explicit type: infer the role from the class name.
        role = "human" if "Human" in msg.__class__.__name__ else "ai"

    return {
        "role": role,
        "content": msg.content,
        "timestamp": getattr(msg, "timestamp", None),
    }


async def get_messages(request: SessionRequest):
    """Get all messages from a conversation session.

    Reads the agent state for ``request.session_id`` (used as the
    ``thread_id``) and converts each entry under the state's ``"messages"``
    key into a dict with ``role``, ``content`` and ``timestamp``.

    Args:
        request: Identifies the session whose messages are requested.

    Returns:
        A JSON response with ``session_id``, ``messages`` and
        ``total_messages``. The list is empty when there is no state, the
        state has no values, or the values hold no ``"messages"`` key.

    Raises:
        HTTPException: 500 when reading or converting the state fails.
    """
    try:
        # NOTE(review): `roleplay` passes a scenario to create_agents() while
        # this call passes none; confirm the factory supports that and reads
        # the same stored state.
        agent = create_agents()
        current_state = agent.get_state(
            {"configurable": {"thread_id": request.session_id}}
        )

        messages = []
        if (
            current_state
            and current_state.values
            and "messages" in current_state.values
        ):
            messages = [
                _message_to_dict(msg) for msg in current_state.values["messages"]
            ]

        return JSONResponse(
            content={
                "session_id": request.session_id,
                "messages": messages,
                "total_messages": len(messages),
            }
        )
    except Exception as e:
        logger.error(
            f"Error getting messages for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(status_code=500, detail=f"Failed to get messages: {str(e)}")
async def get_sessions():
    """Return every stored session under the ``"sessions"`` key.

    Raises:
        HTTPException: 500 if loading or building the response fails.
    """
    try:
        payload = {"sessions": load_sessions()}
        return JSONResponse(content=payload)
    except Exception as e:
        reason = str(e)
        logger.error(f"Error getting sessions: {reason}")
        raise HTTPException(status_code=500, detail=f"Failed to get sessions: {reason}")
async def create_new_session(request: CreateSessionRequest):
    """Create a session named ``request.name`` and return it as JSON.

    Raises:
        HTTPException: 500 if creating the session or building the response
            fails.
    """
    try:
        created = create_session(request.name)
        payload = {"session": created}
        return JSONResponse(content=payload)
    except Exception as e:
        reason = str(e)
        logger.error(f"Error creating session: {reason}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create session: {reason}",
        )
async def get_session(session_id: str):
    """Return the stored session with the given ID as JSON.

    Raises:
        HTTPException: 404 if no such session exists; 500 on any other
            failure.
    """
    try:
        found = get_session_by_id(session_id)
        if found:
            return JSONResponse(content={"session": found})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the 404 above through instead of converting it into a 500.
        raise
    except Exception as e:
        reason = str(e)
        logger.error(f"Error getting session {session_id}: {reason}")
        raise HTTPException(status_code=500, detail=f"Failed to get session: {reason}")
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Rename the stored session identified by ``session_id``.

    Only ``request.name`` is applied; the target is selected by the
    ``session_id`` argument. After saving, the session is read back from
    storage and returned.

    Raises:
        HTTPException: 404 if no such session exists; 500 on any other
            failure.
    """
    try:
        stored = load_sessions()
        # First record with a matching ID, or None when there is none.
        target = next(
            (entry for entry in stored if entry["id"] == session_id),
            None,
        )
        if target is None:
            raise HTTPException(status_code=404, detail="Session not found")

        target["name"] = request.name
        save_sessions(stored)

        refreshed = get_session_by_id(session_id)
        return JSONResponse(content={"session": refreshed})
    except HTTPException:
        # Let the 404 above through instead of converting it into a 500.
        raise
    except Exception as e:
        reason = str(e)
        logger.error(f"Error updating session {session_id}: {reason}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update session: {reason}",
        )
async def delete_session(session_id: str):
    """Delete the stored session with the given ID.

    Raises:
        HTTPException: 404 if no such session exists; 500 on any other
            failure.
    """
    try:
        if delete_session_by_id(session_id):
            return JSONResponse(content={"message": "Session deleted successfully"})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the 404 above through instead of converting it into a 500.
        raise
    except Exception as e:
        reason = str(e)
        logger.error(f"Error deleting session {session_id}: {reason}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete session: {reason}",
        )