# Run_code_api/src/apis/routes/lesson_route.py
# Author: ABAO77 — commit 61e4b1e ("feat: evaluation when end"), 15.2 kB
from fastapi import (
APIRouter,
status,
Depends,
BackgroundTasks,
HTTPException,
File,
UploadFile,
Form,
)
from fastapi.responses import JSONResponse, StreamingResponse
from src.utils.logger import logger
from src.services.tts_service import tts_service
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional
from src.agents.lesson_practice_2.flow import lesson_practice_2_agent
from src.apis.models.lesson_models import Lesson, LessonResponse, LessonDetailResponse
import json
import os
import uuid
from datetime import datetime
import base64
import asyncio
# Router shared by every endpoint in this module: all paths below are mounted
# under the "/lesson" prefix and tagged "AI" for the generated OpenAPI docs.
router = APIRouter(
    tags=["AI"],
    prefix="/lesson",
)
# Payload describing one lesson-practice turn: the lesson context plus the
# student's query and session. NOTE(review): no route visible in this module
# consumes this model (the chat endpoints take multipart Form fields) — confirm
# whether it is used elsewhere. Documented with comments rather than a class
# docstring because pydantic would publish a docstring as the schema description.
class LessonPracticeRequest(BaseModel):
    # --- lesson context ---
    unit: str = Field(default=..., description="Unit of the lesson")
    vocabulary: list = Field(default=..., description="Vocabulary for the lesson")
    key_structures: list = Field(
        default=..., description="Key structures for the lesson"
    )
    practice_questions: list = Field(
        default=..., description="Practice questions for the lesson"
    )
    # Optional; every other field is required (default=...).
    student_level: str = Field(
        default="beginner", description="Student's level of English"
    )
    # --- the turn itself ---
    query: str = Field(default=..., description="User query for the lesson")
    session_id: str = Field(default=..., description="Session ID for the lesson")
def load_lessons_from_file() -> List[Lesson]:
    """Load lessons from ``data/lessons.json`` (two directories above this file).

    Best-effort: this function never raises. A missing or unreadable file, or
    a top-level JSON value that is not an array, yields ``[]``. Entries that
    cannot be turned into a ``Lesson`` are logged and skipped individually, so
    one malformed entry does not discard the valid ones.

    Returns:
        List[Lesson]: The successfully parsed lessons, in file order.
    """
    lessons_file_path = os.path.join(
        os.path.dirname(__file__), "..", "..", "data", "lessons.json"
    )
    try:
        if not os.path.exists(lessons_file_path):
            logger.warning(f"Lessons file not found at {lessons_file_path}")
            return []
        with open(lessons_file_path, "r", encoding="utf-8") as file:
            lessons_data = json.load(file)
    except Exception as e:
        # Deliberate best-effort boundary: callers treat "no lessons" as empty.
        logger.error(f"Error loading lessons: {str(e)}")
        return []

    # Iterating a JSON object would walk its keys (strings), which is never
    # what the file is meant to contain.
    if not isinstance(lessons_data, list):
        logger.error(
            "Error loading lessons: expected a JSON array, got "
            f"{type(lessons_data).__name__}"
        )
        return []

    lessons: List[Lesson] = []
    for lesson_data in lessons_data:
        # Compute the id for logging up front: calling .get() on a non-dict
        # entry inside the except handler would itself raise and abort the
        # whole load.
        lesson_id = (
            lesson_data.get("id", "unknown")
            if isinstance(lesson_data, dict)
            else "unknown"
        )
        try:
            lessons.append(Lesson(**lesson_data))
        except Exception as e:
            logger.error(f"Error parsing lesson {lesson_id}: {str(e)}")
    return lessons
@router.get("/all", response_model=LessonResponse)
async def get_all_lessons():
    """
    Return every lesson available on disk.

    Returns:
        LessonResponse: All loaded lessons plus their count in ``total``.

    Raises:
        HTTPException: 500 if building the response fails. Problems reading
            the lessons file are already absorbed by ``load_lessons_from_file``,
            which returns an empty list instead of raising.
    """
    try:
        all_lessons = load_lessons_from_file()
        payload = LessonResponse(lessons=all_lessons, total=len(all_lessons))
    except Exception as exc:
        logger.error(f"Error retrieving lessons: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lessons",
        )
    return payload
@router.get("/{lesson_id}", response_model=LessonDetailResponse)
async def get_lesson_by_id(lesson_id: str):
    """
    Look up a single lesson by its identifier.

    Args:
        lesson_id (str): Value compared for equality against each lesson's ``id``.

    Returns:
        LessonDetailResponse: Wrapper around the first matching lesson.

    Raises:
        HTTPException: 404 when no lesson has that id; 500 on any other failure.
    """
    try:
        found = None
        for candidate in load_lessons_from_file():
            if candidate.id == lesson_id:
                found = candidate
                break
        if not found:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Lesson with ID '{lesson_id}' not found",
            )
        return LessonDetailResponse(lesson=found)
    except HTTPException:
        # Re-raise the 404 untouched so the generic handler below does not
        # convert it into a 500.
        raise
    except Exception as exc:
        logger.error(f"Error retrieving lesson {lesson_id}: {str(exc)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve lesson",
        )
@router.get("/search/unit/{unit_name}", response_model=LessonResponse)
async def search_lessons_by_unit(unit_name: str):
    """
    Search lessons by unit name (case-insensitive partial match)

    Args:
        unit_name (str): Substring to look for in each lesson's ``unit``.

    Returns:
        LessonResponse: Matching lessons (in file order) and their count.

    Raises:
        HTTPException: 500 if filtering or building the response fails.
    """
    try:
        lessons = load_lessons_from_file()
        # Normalize the query once rather than once per lesson.
        needle = unit_name.lower()
        matching_lessons = [
            lesson for lesson in lessons if needle in lesson.unit.lower()
        ]
        return LessonResponse(lessons=matching_lessons, total=len(matching_lessons))
    except Exception as e:
        logger.error(f"Error searching lessons by unit '{unit_name}': {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to search lessons",
        )
@router.post("/chat")
async def chat(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with Practice and Teaching agents

    Args:
        session_id: Conversation identifier, passed to the agent as its
            ``thread_id``.
        lesson_data: JSON object string. Its ``unit``, ``vocabulary``,
            ``key_structures``, ``practice_questions`` and ``student_level``
            keys are forwarded to the agent (defaults apply when absent).
        text_message: Optional text typed by the student.
        audio_file: Optional recording, sent base64-encoded. Its MIME type is
            chosen from the file extension, defaulting to ``audio/wav``.

    Returns:
        JSONResponse: ``{"response": <content of the agent's last message>}``.

    Raises:
        HTTPException: 400 when neither input is supplied, ``lesson_data`` is
            not a non-empty JSON object, or the audio cannot be read; 500 when
            the agent invocation fails.
    """
    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )
    # Parse lesson data from JSON string
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid lesson_data JSON format"
        ) from e
    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")
    # Valid JSON that is not an object (a list, string or number) has no .get();
    # reject it as a client error instead of letting it surface as a 500 below.
    if not isinstance(lesson_dict, dict):
        raise HTTPException(
            status_code=400, detail="lesson_data must be a JSON object"
        )

    # Build the multimodal content list for the user message
    message_content: List[Dict[str, Any]] = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        try:
            audio_data = await audio_file.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
            # Determine mime type based on file extension
            file_extension = (
                audio_file.filename.split(".")[-1].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type_map = {
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
                "ogg": "audio/ogg",
                "webm": "audio/webm",
                "m4a": "audio/mp4",
            }
            mime_type = mime_type_map.get(file_extension, "audio/wav")
            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

    message = {"role": "user", "content": message_content}
    try:
        response = await lesson_practice_2_agent().ainvoke(
            {
                "messages": [message],
                "unit": lesson_dict.get("unit", ""),
                "vocabulary": lesson_dict.get("vocabulary", []),
                "key_structures": lesson_dict.get("key_structures", []),
                "practice_questions": lesson_dict.get("practice_questions", []),
                "student_level": lesson_dict.get("student_level", "beginner"),
            },
            {"configurable": {"thread_id": session_id}},
        )
        # The reply is the content of the last message in the returned state
        ai_response = response["messages"][-1].content
        logger.info(f"AI response (v2): {ai_response}")
        return JSONResponse(content={"response": ai_response})
    except Exception as e:
        logger.error(f"Error in lesson practice v2: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@router.post("/chat/stream", status_code=status.HTTP_200_OK)
async def chat_stream(
    session_id: str = Form(
        ..., description="Session ID for tracking user interactions"
    ),
    lesson_data: str = Form(..., description="The lesson data as JSON string"),
    text_message: Optional[str] = Form(None, description="Text message from user"),
    audio_file: Optional[UploadFile] = File(None, description="Audio file from user"),
    audio: bool = Form(False, description="Whether to return TTS audio response"),
):
    """Send a message (text or audio) to the lesson practice v2 agent with streaming response

    The reply is sent as Server-Sent Events. Each frame is
    ``data: <json>\\n\\n``, where ``json.type`` is one of:

    * ``message_chunk`` - one non-empty piece of generated text plus metadata;
    * ``completion`` - end of stream; ``audio`` holds TTS output when ``audio``
      was requested and synthesis succeeded, otherwise ``null``;
    * ``error`` - the agent or streaming failed; ``content`` holds the message.

    Raises:
        HTTPException: 400 for invalid input (no text/audio, ``lesson_data``
            not a non-empty JSON object, unreadable audio). These are raised
            before streaming starts, so they carry a real 400 status.
    """
    logger.info(f"Received streaming lesson practice v2 request: {session_id}")
    # Validate that at least one input is provided
    if not text_message and not audio_file:
        raise HTTPException(
            status_code=400, detail="Either text_message or audio_file must be provided"
        )
    # Parse lesson data from JSON string
    try:
        lesson_dict = json.loads(lesson_data)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail="Invalid lesson_data JSON format"
        ) from e
    if not lesson_dict:
        raise HTTPException(status_code=400, detail="Lesson data not provided")
    # A non-object JSON value would otherwise fail inside the generator, after
    # the 200 status has been sent, and reach the client only as an error event.
    if not isinstance(lesson_dict, dict):
        raise HTTPException(
            status_code=400, detail="lesson_data must be a JSON object"
        )

    # Build the multimodal content list for the user message
    message_content: List[Dict[str, Any]] = []
    if text_message:
        message_content.append({"type": "text", "text": text_message})
    if audio_file:
        try:
            uploaded_audio = await audio_file.read()
            audio_base64 = base64.b64encode(uploaded_audio).decode("utf-8")
            # Determine mime type based on file extension
            file_extension = (
                audio_file.filename.split(".")[-1].lower()
                if audio_file.filename
                else "wav"
            )
            mime_type_map = {
                "wav": "audio/wav",
                "mp3": "audio/mpeg",
                "ogg": "audio/ogg",
                "webm": "audio/webm",
                "m4a": "audio/mp4",
            }
            mime_type = mime_type_map.get(file_extension, "audio/wav")
            message_content.append(
                {
                    "type": "audio",
                    "source_type": "base64",
                    "data": audio_base64,
                    "mime_type": mime_type,
                }
            )
        except Exception as e:
            logger.error(f"Error processing audio file: {str(e)}")
            raise HTTPException(
                status_code=400, detail=f"Error processing audio file: {str(e)}"
            ) from e

    message = {"role": "user", "content": message_content}
    input_graph = {
        "messages": [message],
        "unit": lesson_dict.get("unit", ""),
        "vocabulary": lesson_dict.get("vocabulary", []),
        "key_structures": lesson_dict.get("key_structures", []),
        "practice_questions": lesson_dict.get("practice_questions", []),
        "student_level": lesson_dict.get("student_level", "beginner"),
    }
    config = {"configurable": {"thread_id": session_id}}

    def _chunk_text(content: Any) -> str:
        """Flatten a message chunk's ``content`` into plain text.

        Chat-model message content may be a string or a list of parts, where
        each part is a string or a dict such as ``{"type": "text", "text": ...}``.
        Non-text parts are dropped. Without this, concatenating a list onto the
        accumulated string would raise ``TypeError`` and abort the stream.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces: List[str] = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    pieces.append(str(part.get("text", "")))
            return "".join(pieces)
        return ""

    async def generate_stream():
        """Yield SSE frames: message chunks, then a completion or error event."""
        accumulated_content = ""  # full reply text, used for TTS at the end
        try:
            async for event in lesson_practice_2_agent().astream(
                input=input_graph,
                stream_mode=["messages"],
                config=config,
                subgraphs=True,
            ):
                # With subgraphs=True and a list stream_mode, each event is a
                # (namespace, mode, payload) triple.
                _, event_type, message_chunk = event
                if event_type != "messages":
                    continue
                # In "messages" mode the payload is a (message_chunk, metadata)
                # tuple; fall back to treating it as the message itself.
                if isinstance(message_chunk, tuple) and len(message_chunk) > 0:
                    actual_message = message_chunk[0]
                else:
                    actual_message = message_chunk
                content = _chunk_text(getattr(actual_message, "content", ""))
                if not content:
                    continue
                accumulated_content += content
                response_data = {
                    "type": "message_chunk",
                    "content": content,
                    "metadata": {
                        "agent": getattr(actual_message, "name", "unknown"),
                        "id": getattr(actual_message, "id", ""),
                        "usage_metadata": getattr(
                            actual_message, "usage_metadata", {}
                        ),
                    },
                }
                yield f"data: {json.dumps(response_data)}\n\n"
                # Small delay to prevent overwhelming the client
                await asyncio.sleep(0.01)

            # Generate TTS audio if requested; failures only drop the audio.
            tts_audio = None
            if audio and accumulated_content.strip():
                try:
                    logger.info(
                        f"Generating TTS for lesson v2 content: {len(accumulated_content)} chars"
                    )
                    audio_result = await tts_service.text_to_speech(accumulated_content)
                    if audio_result:
                        tts_audio = {
                            "audio_data": audio_result["audio_data"],
                            "mime_type": audio_result["mime_type"],
                            "format": audio_result["format"],
                        }
                        logger.info("Lesson v2 TTS audio generated successfully")
                    else:
                        logger.warning("Lesson v2 TTS generation failed")
                except Exception as tts_error:
                    logger.error(f"Lesson v2 TTS generation error: {str(tts_error)}")

            # Send completion signal with optional audio
            completion_data = {"type": "completion", "content": "", "audio": tts_audio}
            yield f"data: {json.dumps(completion_data)}\n\n"
        except Exception as e:
            # Headers are already sent, so report the failure in-band.
            logger.error(f"Error in streaming lesson practice v2: {str(e)}")
            error_data = {"type": "error", "content": str(e)}
            yield f"data: {json.dumps(error_data)}\n\n"

    # Declare the SSE type via media_type instead of a conflicting
    # media_type="text/plain" plus an overriding Content-Type header.
    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )