Spaces:
Sleeping
Sleeping
| from fastapi import ( | |
| APIRouter, | |
| status, | |
| Depends, | |
| BackgroundTasks, | |
| HTTPException, | |
| File, | |
| UploadFile, | |
| Form, | |
| ) | |
| from fastapi.responses import JSONResponse | |
| from src.utils.logger import logger | |
| from src.agents.role_play.func import create_agents | |
| from pydantic import BaseModel, Field | |
| from typing import List, Dict, Any, Optional | |
| from src.agents.role_play.scenarios import get_scenarios, get_scenario_by_id | |
| import json | |
| import os | |
| import uuid | |
| from datetime import datetime | |
| import base64 | |
# Routes registered on this router are served under the "/ai" path prefix
# and grouped under the "AI" tag in the generated OpenAPI docs.
router = APIRouter(tags=["AI"], prefix="/ai")
class RoleplayRequest(BaseModel):
    # Request schema for one roleplay turn; all three fields are required.
    # NOTE(review): the visible `roleplay` handler reads multipart Form fields
    # instead of this model — confirm whether anything still uses it.
    query: str = Field(default=..., description="User's query for the AI agent")
    session_id: str = Field(
        default=..., description="Session ID for tracking user interactions"
    )
    scenario: Dict[str, Any] = Field(
        default=..., description="The scenario for the roleplay"
    )
class SessionRequest(BaseModel):
    # Request body that only identifies a session (e.g. for fetching messages).
    session_id: str = Field(
        default=..., description="Session ID to perform operations on"
    )
class CreateSessionRequest(BaseModel):
    # Request body for creating a session; the server assigns the id.
    name: str = Field(default=..., description="Name for the new session")
class UpdateSessionRequest(BaseModel):
    # Request body for renaming a session; both fields are required.
    session_id: str = Field(default=..., description="Session ID to update")
    name: str = Field(default=..., description="New name for the session")
# --- Session persistence helpers -------------------------------------------
# Session metadata is kept as a JSON list in this file. The path is relative,
# so it resolves against the process's current working directory.
SESSIONS_FILE = "sessions.json"
def load_sessions() -> List[Dict[str, Any]]:
    """Load the persisted session list from ``SESSIONS_FILE``.

    Returns:
        The stored list of session dicts, or ``[]`` when the file does not
        exist, cannot be read or parsed, or does not contain a JSON list.
        Every failure except a missing file is logged.

    NOTE(review): returning ``[]`` for an unreadable/corrupt file means the
    next ``save_sessions`` call replaces its contents; consider backing the
    bad file up before that happens.
    """
    try:
        # EAFP: opening directly avoids the exists()/open() race.
        with open(SESSIONS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # Nothing has been saved yet.
        return []
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Error loading sessions: {str(e)}")
        return []
    if not isinstance(data, list):
        # Callers append to and iterate over the result, so a non-list
        # payload (e.g. a JSON object) would crash them later.
        logger.error(
            f"Error loading sessions: expected a JSON list, got {type(data).__name__}"
        )
        return []
    return data
def save_sessions(sessions: List[Dict[str, Any]]):
    """Persist *sessions* to ``SESSIONS_FILE`` as indented UTF-8 JSON.

    The JSON is written to a temporary sibling file and then moved over
    ``SESSIONS_FILE`` with ``os.replace``. A crash or serialization error
    mid-write therefore leaves the previous file intact instead of a
    truncated one (which a later load would not be able to parse).
    Values that are not JSON-serializable are stringified via ``default=str``.

    Errors are logged, not raised (best-effort persistence).
    """
    tmp_path = f"{SESSIONS_FILE}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(sessions, f, ensure_ascii=False, indent=2, default=str)
        os.replace(tmp_path, SESSIONS_FILE)
    except Exception as e:
        logger.error(f"Error saving sessions: {str(e)}")
        # Best effort: don't leave a half-written temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def create_session(name: str) -> Dict[str, Any]:
    """Build a new session record, append it to the store, and return it.

    The record gets a random UUID4 string ``id``, the given *name*, a naive
    local-time ISO-8601 ``created_at``, ``last_message=None`` and
    ``message_count=0``.
    """
    new_session = dict(
        id=str(uuid.uuid4()),
        name=name,
        created_at=datetime.now().isoformat(),
        last_message=None,
        message_count=0,
    )
    all_sessions = load_sessions()
    all_sessions.append(new_session)
    save_sessions(all_sessions)
    return new_session
def get_session_by_id(session_id: str) -> Optional[Dict[str, Any]]:
    """Return the first stored session whose ``id`` equals *session_id*, else ``None``."""
    for candidate in load_sessions():
        if candidate["id"] == session_id:
            return candidate
    return None
def update_session_last_message(session_id: str, message: str) -> bool:
    """Store *message* as the session's ``last_message`` and bump ``message_count``.

    The sessions file is rewritten only when a session with *session_id*
    exists.

    Returns:
        True if the session was found and updated, False otherwise.
    """
    sessions = load_sessions()
    target = next((s for s in sessions if s["id"] == session_id), None)
    if target is None:
        # Nothing changed, so skip the write. If load_sessions() returned []
        # because the file could not be read, saving here would replace every
        # stored session with an empty list.
        return False
    target["last_message"] = message
    target["message_count"] = target.get("message_count", 0) + 1
    save_sessions(sessions)
    return True
def delete_session_by_id(session_id: str) -> bool:
    """Remove every stored session with *session_id*; return whether any was removed.

    The sessions file is rewritten only when something was actually deleted.
    """
    existing = load_sessions()
    remaining = [entry for entry in existing if entry["id"] != session_id]
    if len(remaining) == len(existing):
        return False
    save_sessions(remaining)
    return True
async def list_scenarios():
    """Return the result of ``get_scenarios()`` as a JSON response."""
    scenarios = get_scenarios()
    return JSONResponse(content=scenarios)
# Lower-cased upload file extension -> MIME type sent to the model.
_AUDIO_MIME_TYPES = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "ogg": "audio/ogg",
    "webm": "audio/webm",
    "m4a": "audio/mp4",
}


def _audio_mime_type(filename: Optional[str]) -> str:
    """Map *filename*'s extension to an audio MIME type, defaulting to ``audio/wav``."""
    extension = filename.split(".")[-1].lower() if filename else "wav"
    return _AUDIO_MIME_TYPES.get(extension, "audio/wav")


async def _read_audio_content(audio_file: UploadFile) -> Optional[Dict[str, Any]]:
    """Read *audio_file* into a base64 ``audio`` content part; None if it is empty.

    Raises:
        HTTPException: 400 if the upload cannot be read or encoded.
    """
    try:
        audio_data = await audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        mime_type = _audio_mime_type(audio_file.filename)
    except Exception as e:
        logger.error(f"Error processing audio file: {str(e)}")
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {str(e)}"
        ) from e
    if not audio_data:
        # An empty upload carries no audio; don't forward empty base64 to the model.
        return None
    return {
        "type": "audio",
        "source_type": "base64",
        "data": audio_base64,
        "mime_type": mime_type,
    }


async def roleplay(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    scenario: str = Form(
        ..., description="The scenario for the roleplay as JSON string"
    ),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send one user turn (text and/or audio) to the roleplay agent.

    Args:
        session_id: Passed to the agent as ``configurable.thread_id`` and used
            to update the stored session's last message.
        scenario: JSON-encoded scenario object used to build the agent.
        text_message: Optional text part of the turn.
        audio_file: Optional audio upload, forwarded base64-encoded with a
            MIME type guessed from its file extension.

    Returns:
        ``JSONResponse`` with ``{"response": <content of the agent's last message>}``.

    Raises:
        HTTPException: 400 for missing or invalid input, 500 if the agent call fails.
    """
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    try:
        scenario_dict = json.loads(scenario)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid scenario JSON format") from e
    if not scenario_dict:
        raise HTTPException(status_code=400, detail="Scenario not provided")
    if not isinstance(scenario_dict, dict):
        # Lists, strings or numbers would otherwise only fail later inside
        # create_agents and surface as an opaque 500.
        raise HTTPException(status_code=400, detail="Scenario must be a JSON object")

    message_content: List[Dict[str, Any]] = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        audio_part = await _read_audio_content(audio_file)
        if audio_part is not None:
            message_content.append(audio_part)
    if not message_content:
        # Only an empty audio upload was sent, so there is nothing to forward.
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    message = {"role": "user", "content": message_content}
    try:
        response = await create_agents(scenario_dict).ainvoke(
            {"messages": [message]},
            {"configurable": {"thread_id": session_id}},
        )
        # Audio-only turns are recorded with a placeholder instead of text.
        # NOTE(review): this is synchronous file I/O inside an async handler.
        update_session_last_message(session_id, text_message or "[Audio message]")
        ai_response = response["messages"][-1].content
        logger.info(f"AI response: {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in roleplay: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e
def _serialize_message(msg: Any) -> Dict[str, Any]:
    """Convert one stored agent message into a ``{role, content, timestamp}`` dict.

    - Objects with both ``content`` and ``type`` use ``type`` as the role.
    - Objects with only ``content`` get role ``"human"`` when their class
      name contains ``"Human"``, and ``"ai"`` otherwise.
    - Anything else is stringified with role ``"unknown"``.
    """
    if not hasattr(msg, "content"):
        return {"role": "unknown", "content": str(msg), "timestamp": None}
    if hasattr(msg, "type"):
        role = msg.type
    else:
        # Every object has __class__, so only the class name needs checking.
        role = "human" if "Human" in msg.__class__.__name__ else "ai"
    return {
        "role": role,
        "content": msg.content,
        "timestamp": getattr(msg, "timestamp", None),
    }


async def get_messages(request: SessionRequest):
    """Return all messages stored in the agent state for ``request.session_id``.

    Returns:
        ``JSONResponse`` with ``{"session_id", "messages", "total_messages"}``.
        ``messages`` is empty when the thread has no stored state.

    Raises:
        HTTPException: 500 if the agent state cannot be read or serialized.
    """
    try:
        # NOTE(review): `roleplay` builds its agent with create_agents(scenario_dict);
        # confirm create_agents() accepts no arguments and shares the checkpointer
        # that stored this thread, otherwise the history will always be empty.
        agent = create_agents()
        current_state = agent.get_state(
            {"configurable": {"thread_id": request.session_id}}
        )
        values = current_state.values if current_state else None
        raw_messages = values["messages"] if values and "messages" in values else []
        messages = [_serialize_message(msg) for msg in raw_messages]
        return JSONResponse(
            content={
                "session_id": request.session_id,
                "messages": messages,
                "total_messages": len(messages),
            }
        )
    except Exception as e:
        logger.error(
            f"Error getting messages for session {request.session_id}: {str(e)}"
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to get messages: {str(e)}"
        ) from e
async def get_sessions():
    """Respond with every stored session as ``{"sessions": [...]}``.

    Raises:
        HTTPException: 500 if loading or building the response fails.
    """
    try:
        payload = {"sessions": load_sessions()}
        return JSONResponse(content=payload)
    except Exception as exc:
        logger.error(f"Error getting sessions: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Failed to get sessions: {str(exc)}")
async def create_new_session(request: CreateSessionRequest):
    """Create a session named ``request.name`` and respond with ``{"session": ...}``.

    Raises:
        HTTPException: 500 if creating the session or the response fails.
    """
    try:
        return JSONResponse(content={"session": create_session(request.name)})
    except Exception as exc:
        logger.error(f"Error creating session: {str(exc)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {str(exc)}"
        )
async def get_session(session_id: str):
    """Respond with ``{"session": ...}`` for the session matching *session_id*.

    Raises:
        HTTPException: 404 if no (truthy) session matches, 500 on other errors.
    """
    try:
        found = get_session_by_id(session_id)
        if found:
            return JSONResponse(content={"session": found})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the deliberate 404 through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Error getting session {session_id}: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"Failed to get session: {str(exc)}")
async def update_session(session_id: str, request: UpdateSessionRequest):
    """Rename the session identified by the *session_id* argument.

    Args:
        session_id: Id of the session to rename.
        request: Body carrying the new ``name``. NOTE(review): its
            ``session_id`` field is not consulted; the *session_id*
            argument decides which session is updated.

    Returns:
        ``JSONResponse`` with ``{"session": <updated session record>}``.

    Raises:
        HTTPException: 404 if no such session exists, 500 on unexpected errors.
    """
    try:
        sessions = load_sessions()
        target = next((s for s in sessions if s["id"] == session_id), None)
        if target is None:
            raise HTTPException(status_code=404, detail="Session not found")
        target["name"] = request.name
        save_sessions(sessions)
        # Return the record just written rather than re-reading the file. A
        # re-read costs a second load and, if another writer changed or deleted
        # the session in between, could return stale data or None with a 200.
        return JSONResponse(content={"session": target})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating session {session_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to update session: {str(e)}"
        ) from e
async def delete_session(session_id: str):
    """Delete the session matching *session_id* and confirm with a message.

    Raises:
        HTTPException: 404 if nothing was deleted, 500 on other errors.
    """
    try:
        if delete_session_by_id(session_id):
            return JSONResponse(content={"message": "Session deleted successfully"})
        raise HTTPException(status_code=404, detail="Session not found")
    except HTTPException:
        # Let the deliberate 404 through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Error deleting session {session_id}: {str(exc)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {str(exc)}"
        )