Spaces:
Sleeping
Sleeping
| from fastapi import ( | |
| APIRouter, | |
| status, | |
| Depends, | |
| BackgroundTasks, | |
| HTTPException, | |
| File, | |
| UploadFile, | |
| Form, | |
| ) | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from src.utils.logger import logger | |
| from src.services.tts_service import tts_service | |
| from pydantic import BaseModel, Field | |
| from typing import List, Dict, Any, Optional | |
| from src.agents.lesson_practice_2.flow import lesson_practice_2_agent | |
| from src.apis.models.lesson_models import Lesson, LessonResponse, LessonDetailResponse | |
| import json | |
| import os | |
| import uuid | |
| from datetime import datetime | |
| import base64 | |
| import asyncio | |
| router = APIRouter(prefix="/lesson", tags=["AI"]) | |
class LessonPracticeRequest(BaseModel):
    # Request payload for a lesson practice interaction.
    # Every field is required except ``student_level``, which falls back to
    # "beginner" when the client omits it.

    # Lesson identification and content.
    unit: str = Field(default=..., description="Unit of the lesson")
    vocabulary: list = Field(default=..., description="Vocabulary for the lesson")
    key_structures: list = Field(
        default=..., description="Key structures for the lesson"
    )
    practice_questions: list = Field(
        default=..., description="Practice questions for the lesson"
    )

    # Learner context.
    student_level: str = Field(
        default="beginner", description="Student's level of English"
    )

    # The actual user turn and the session it belongs to.
    query: str = Field(default=..., description="User query for the lesson")
    session_id: str = Field(default=..., description="Session ID for the lesson")
# Helper function to load lessons from JSON file
def load_lessons_from_file() -> List[Lesson]:
    """Load lessons from the ``data/lessons.json`` file.

    The path is resolved relative to this module: two directories up, then
    ``data/lessons.json``. Entries that cannot be converted into a ``Lesson``
    are logged and skipped, so one bad record does not hide the others.

    Returns:
        The successfully parsed lessons. An empty list is returned when the
        file is missing, cannot be read or decoded as JSON, or does not
        contain a JSON array at the top level.
    """
    lessons_file_path = os.path.join(
        os.path.dirname(__file__), "..", "..", "data", "lessons.json"
    )

    # Open directly and react to a missing file, instead of checking for
    # existence first and racing against the filesystem.
    try:
        with open(lessons_file_path, "r", encoding="utf-8") as file:
            lessons_data = json.load(file)
    except FileNotFoundError:
        logger.warning(f"Lessons file not found at {lessons_file_path}")
        return []
    except Exception as e:
        logger.error(f"Error loading lessons: {str(e)}")
        return []

    if not isinstance(lessons_data, list):
        logger.error(
            "Error loading lessons: expected a JSON array of lessons, "
            f"got {type(lessons_data).__name__}"
        )
        return []

    # Convert to Lesson objects
    lessons = []
    for lesson_data in lessons_data:
        # Resolve the id before parsing: a non-dict entry has no ``.get``, and
        # calling it inside the error handler would raise again and abort the
        # whole load, discarding lessons that were already parsed.
        lesson_id = (
            lesson_data.get("id", "unknown")
            if isinstance(lesson_data, dict)
            else "unknown"
        )
        try:
            lessons.append(Lesson(**lesson_data))
        except Exception as e:
            logger.error(f"Error parsing lesson {lesson_id}: {str(e)}")
    return lessons
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `router` elsewhere.
async def get_all_lessons():
    """
    List every lesson that can be loaded from the lessons file.

    Returns:
        LessonResponse: The loaded lessons together with their count

    Raises:
        HTTPException: 500 when loading or building the response fails
    """
    try:
        available = load_lessons_from_file()
        payload = LessonResponse(lessons=available, total=len(available))
    except Exception as exc:
        logger.error(f"Error retrieving lessons: {exc}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lessons",
        )
    return payload
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `router` elsewhere.
async def get_lesson_by_id(lesson_id: str):
    """
    Get a specific lesson by ID

    Args:
        lesson_id (str): The unique identifier of the lesson

    Returns:
        LessonDetailResponse: Contains the lesson details

    Raises:
        HTTPException: 404 when no loaded lesson has the given ID,
            500 when loading or building the response fails
    """
    try:
        lessons = load_lessons_from_file()
        # Find the first lesson with the specified ID
        lesson = next(
            (candidate for candidate in lessons if candidate.id == lesson_id),
            None,
        )
        # Compare against None explicitly: "not found" must not be confused
        # with a lesson object that happens to evaluate as falsy.
        if lesson is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Lesson with ID '{lesson_id}' not found",
            )
        return LessonDetailResponse(lesson=lesson)
    except HTTPException:
        # Let the 404 above through untouched instead of turning it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error retrieving lesson {lesson_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lesson",
        ) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `router` elsewhere.
async def search_lessons_by_unit(unit_name: str):
    """
    Search lessons by unit name (case-insensitive partial match)

    Args:
        unit_name (str): Part of the unit name to search for

    Returns:
        LessonResponse: Contains list of matching lessons and their count

    Raises:
        HTTPException: 500 when loading or filtering the lessons fails
    """
    try:
        lessons = load_lessons_from_file()
        # Lower-case the search term once rather than once per lesson.
        needle = unit_name.lower()
        # Filter lessons by unit name (case-insensitive partial match)
        matching_lessons = [
            lesson for lesson in lessons if needle in lesson.unit.lower()
        ]
        return LessonResponse(lessons=matching_lessons, total=len(matching_lessons))
    except Exception as e:
        logger.error(f"Error searching lessons by unit '{unit_name}': {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to search lessons",
        ) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `router` elsewhere.
async def chat(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with Practice and Teaching agents

    Args:
        session_id (str): Passed to the agent as the configurable ``thread_id``
        lesson_data (str): JSON object with the lesson fields (``unit``,
            ``vocabulary``, ``key_structures``, ``practice_questions``,
            ``student_level``); missing keys fall back to defaults
        text_message (Optional[str]): Text part of the user message
        audio_file (Optional[UploadFile]): Audio part of the user message,
            forwarded to the agent base64-encoded

    Returns:
        JSONResponse: ``{"response": <content of the last agent message>}``

    Raises:
        HTTPException: 400 when neither text nor audio is provided, when
            ``lesson_data`` is not valid JSON, is empty or is not a JSON
            object, or when the audio file cannot be processed;
            500 when the agent invocation fails
    """
    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Parse lesson data from JSON string
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid lesson_data JSON format"
        ) from e
    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")
    # Valid JSON is not necessarily an object. Reject lists, strings and
    # numbers here so the ``.get`` lookups below cannot fail and be reported
    # as a 500 for what is really a client error.
    if not isinstance(lesson_dict, dict):
        raise HTTPException(
            status_code=400, detail="lesson_data must be a JSON object"
        )

    # Prepare message content
    message_content = []

    # Handle text input
    if text_message:
        message_content.append({"type": "text", "text": text_message})

    # Handle audio input
    if audio_file:
        try:
            # Read audio file content and convert it to base64
            audio_bytes = await audio_file.read()
            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")

            # Determine mime type from the text after the last dot of the
            # filename; unknown or missing extensions fall back to WAV.
            file_extension = (
                audio_file.filename.rpartition(".")[2].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type_map = {
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
                "ogg": "audio/ogg",
                "webm": "audio/webm",
                "m4a": "audio/mp4",
            }
            mime_type = mime_type_map.get(file_extension, "audio/wav")

            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

    # Create message in the required format
    message = {"role": "user", "content": message_content}

    try:
        response = await lesson_practice_2_agent().ainvoke(
            {
                "messages": [message],
                "unit": lesson_dict.get("unit", ""),
                "vocabulary": lesson_dict.get("vocabulary", []),
                "key_structures": lesson_dict.get("key_structures", []),
                "practice_questions": lesson_dict.get("practice_questions", []),
                "student_level": lesson_dict.get("student_level", "beginner"),
            },
            {"configurable": {"thread_id": session_id}},
        )
        # Extract AI response content
        ai_response = response["messages"][-1].content
        logger.info(f"AI response (v2): {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in lesson practice v2: {str(e)}")
        # NOTE(review): the raw exception text is returned to the client;
        # confirm that exposing internal error details is acceptable.
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm it is registered on `router` elsewhere.
async def chat_stream(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
    audio: bool = Form(False, description="Whether to return TTS audio response"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with streaming response

    The response body is a sequence of ``data: <json>`` events separated by
    blank lines. Each event has a ``type``: ``message_chunk`` for streamed
    agent output, one final ``completion`` (optionally carrying TTS audio),
    or ``error`` if the agent stream fails.

    Args:
        session_id (str): Passed to the agent as the configurable ``thread_id``
        lesson_data (str): JSON object with the lesson fields (``unit``,
            ``vocabulary``, ``key_structures``, ``practice_questions``,
            ``student_level``); missing keys fall back to defaults
        text_message (Optional[str]): Text part of the user message
        audio_file (Optional[UploadFile]): Audio part of the user message,
            forwarded to the agent base64-encoded
        audio (bool): When true, the accumulated text is sent to the TTS
            service and the result is attached to the completion event

    Returns:
        StreamingResponse: Event stream as described above

    Raises:
        HTTPException: 400 when neither text nor audio is provided, when
            ``lesson_data`` is not valid JSON, is empty or is not a JSON
            object, or when the audio file cannot be processed
    """
    logger.info(f"Received streaming lesson practice v2 request: {session_id}")

    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )

    # Parse lesson data from JSON string
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid lesson_data JSON format"
        ) from e
    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")
    # Valid JSON is not necessarily an object. Reject it up front with a 400
    # rather than letting ``.get`` fail inside the generator, where it could
    # only surface as a mid-stream error event.
    if not isinstance(lesson_dict, dict):
        raise HTTPException(
            status_code=400, detail="lesson_data must be a JSON object"
        )

    # Prepare message content
    message_content = []

    # Handle text input
    if text_message:
        message_content.append({"type": "text", "text": text_message})

    # Handle audio input
    if audio_file:
        try:
            # Read audio file content and convert it to base64
            upload_bytes = await audio_file.read()
            audio_base64 = base64.b64encode(upload_bytes).decode("utf-8")

            # Determine mime type from the text after the last dot of the
            # filename; unknown or missing extensions fall back to WAV.
            file_extension = (
                audio_file.filename.rpartition(".")[2].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type_map = {
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
                "ogg": "audio/ogg",
                "webm": "audio/webm",
                "m4a": "audio/mp4",
            }
            mime_type = mime_type_map.get(file_extension, "audio/wav")

            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

    # Create message in the required format
    message = {"role": "user", "content": message_content}

    async def generate_stream():
        """Generator function for streaming responses"""
        # Text pieces collected for TTS; joined once after the stream ends.
        spoken_parts = []
        try:
            input_graph = {
                "messages": [message],
                "unit": lesson_dict.get("unit", ""),
                "vocabulary": lesson_dict.get("vocabulary", []),
                "key_structures": lesson_dict.get("key_structures", []),
                "practice_questions": lesson_dict.get("practice_questions", []),
                "student_level": lesson_dict.get("student_level", "beginner"),
            }
            config = {"configurable": {"thread_id": session_id}}

            async for event in lesson_practice_2_agent().astream(
                input=input_graph,
                stream_mode=["messages"],
                config=config,
                subgraphs=True,
            ):
                _, event_type, message_chunk = event
                if event_type != "messages":
                    continue

                # message_chunk may be a tuple whose first element is the
                # actual message chunk; otherwise use it directly.
                if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
                    actual_message = message_chunk[0]
                else:
                    actual_message = message_chunk
                content = getattr(actual_message, "content", "")
                if not content:
                    continue

                # Only plain strings are collected for TTS. Any other content
                # shape is still streamed to the client below, but is not
                # spoken, and no longer aborts the stream with a TypeError.
                if isinstance(content, str):
                    spoken_parts.append(content)

                # Create SSE-formatted response
                response_data = {
                    "type": "message_chunk",
                    "content": content,
                    "metadata": {
                        "agent": getattr(actual_message, "name", "unknown"),
                        "id": getattr(actual_message, "id", ""),
                        "usage_metadata": getattr(
                            actual_message, "usage_metadata", {}
                        ),
                    },
                }
                yield f"data: {json.dumps(response_data)}\n\n"
                # Small delay to prevent overwhelming the client
                await asyncio.sleep(0.01)

            # Generate TTS audio if requested
            accumulated_content = "".join(spoken_parts)
            tts_payload = None
            if audio and accumulated_content.strip():
                # TTS is best-effort: a failure is logged and the completion
                # event is still sent, just without audio.
                try:
                    logger.info(
                        f"Generating TTS for lesson v2 content: {len(accumulated_content)} chars"
                    )
                    audio_result = await tts_service.text_to_speech(accumulated_content)
                    if audio_result:
                        tts_payload = {
                            "audio_data": audio_result["audio_data"],
                            "mime_type": audio_result["mime_type"],
                            "format": audio_result["format"],
                        }
                        logger.info("Lesson v2 TTS audio generated successfully")
                    else:
                        logger.warning("Lesson v2 TTS generation failed")
                except Exception as tts_error:
                    logger.error(f"Lesson v2 TTS generation error: {str(tts_error)}")

            # Send completion signal with optional audio
            completion_data = {"type": "completion", "content": "", "audio": tts_payload}
            yield f"data: {json.dumps(completion_data)}\n\n"
        except Exception as e:
            logger.error(f"Error in streaming lesson practice v2: {str(e)}")
            # NOTE(review): the raw exception text is sent to the client;
            # confirm that exposing internal error details is acceptable.
            error_data = {"type": "error", "content": str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_stream(),
        # media_type now agrees with the explicit Content-Type header below
        # instead of contradicting it with "text/plain".
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
        },
    )