Spaces:
Running
Running
| """ | |
| FastAPI application — REST API chuẩn production cho Study Group Assistant. | |
| Khởi động: | |
| uvicorn src.api:app --host 0.0.0.0 --port 8000 --reload | |
| """ | |
| import logging | |
| import os | |
| import uvicorn | |
| import tempfile | |
| import uuid | |
| from asyncio import get_running_loop | |
| from concurrent.futures import ThreadPoolExecutor | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timezone | |
| from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| from src.core import final_answer | |
| from src.qdrant_store import get_custom_prompt, save_custom_prompt | |
| from src.redis_client import redis_client | |
| logger = logging.getLogger(__name__) | |
| _executor = ThreadPoolExecutor() | |
| # ── Lifespan ────────────────────────────────────────────────────────────────── | |
| async def lifespan(app: FastAPI): | |
| logger.info("Study Group Assistant API starting up.") | |
| yield | |
| _executor.shutdown(wait=False) | |
| logger.info("Study Group Assistant API shut down.") | |
| # ── App ─────────────────────────────────────────────────────────────────────── | |
| app = FastAPI( | |
| title="Study Group Assistant API", | |
| description=( | |
| "AI agent giúp nhóm học tập tóm tắt hội thoại, " | |
| "tra cứu lịch trình và quản lý ghi nhớ." | |
| ), | |
| version="1.0.0", | |
| lifespan=lifespan, | |
| docs_url="/docs", | |
| redoc_url="/redoc", | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # ── Middlewares ─────────────────────────────────────────────────────────────── | |
| async def attach_request_id(request: Request, call_next): | |
| """Gắn X-Request-ID vào mỗi request để dễ trace log.""" | |
| request_id = str(uuid.uuid4()) | |
| request.state.request_id = request_id | |
| response = await call_next(request) | |
| response.headers["X-Request-ID"] = request_id | |
| return response | |
| async def log_requests(request: Request, call_next): | |
| """Log method, path và status code của mỗi request.""" | |
| response = await call_next(request) | |
| logger.info( | |
| "%s %s → %d [rid=%s]", | |
| request.method, | |
| request.url.path, | |
| response.status_code, | |
| getattr(request.state, "request_id", "-"), | |
| ) | |
| return response | |
| # ── Pydantic models ─────────────────────────────────────────────────────────── | |
| class ChatRequest(BaseModel): | |
| conversation_id: str = Field(..., description="ID cuộc hội thoại DM") | |
| sender_id: str = Field(..., description="ID hoặc tên người gửi") | |
| query: str = Field(..., description="Câu hỏi hoặc yêu cầu") | |
| model_config = { | |
| "json_schema_extra": { | |
| "example": { | |
| "conversation_id": "98996225-512c-4491-96a2-bc71552328ca", | |
| "sender_id": "@Hoang", | |
| "query": "Tóm tắt cuộc trò chuyện hôm nay", | |
| } | |
| } | |
| } | |
| class ChatResponse(BaseModel): | |
| answer: str = Field(..., description="Câu trả lời từ agent") | |
| processing_time: str = Field(..., description="Thời gian xử lý, ví dụ '1.23s'") | |
| conversation_id: str | |
| sender_id: str | |
| class HealthComponent(BaseModel): | |
| status: str = Field(..., description="'ok' | 'degraded' | 'down'") | |
| detail: str = "" | |
| class HealthResponse(BaseModel): | |
| status: str = Field(..., description="'ok' | 'degraded'") | |
| timestamp: str | |
| components: dict[str, HealthComponent] | |
| class ErrorDetail(BaseModel): | |
| error: str | |
| detail: str = "" | |
| request_id: str = "" | |
| class CustomPromptRequest(BaseModel): | |
| user_id: str = Field(..., description="ID người dùng") | |
| prompt: str = Field(..., description="Nội dung custom prompt") | |
| model_config = { | |
| "json_schema_extra": { | |
| "example": { | |
| "user_id": "@Hoang", | |
| "prompt": "Luôn trả lời ngắn gọn trong 3 câu. Dùng bullet point khi liệt kê.", | |
| } | |
| } | |
| } | |
| class CustomPromptResponse(BaseModel): | |
| success: bool | |
| user_id: str | |
| prompt: str | |
| # ── Helper ──────────────────────────────────────────────────────────────────── | |
| def _request_id(request: Request) -> str: | |
| return getattr(request.state, "request_id", "") | |
| def _utcnow() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
| # ── Routes ──────────────────────────────────────────────────────────────────── | |
| async def root(): | |
| return { | |
| "service": "Study Group Assistant API", | |
| "version": "1.0.0", | |
| "docs": "/docs", | |
| "health": "/health", | |
| } | |
| async def health(): | |
| redis_ok = redis_client.ping() | |
| return HealthResponse( | |
| status="ok" if redis_ok else "degraded", | |
| timestamp=_utcnow(), | |
| components={ | |
| "redis": HealthComponent( | |
| status="ok" if redis_ok else "down", | |
| detail="Connected" if redis_ok else "Connection failed — using local fallback", | |
| ), | |
| "agent": HealthComponent(status="ok"), | |
| }, | |
| ) | |
| async def chat(request: Request, body: ChatRequest): | |
| """ | |
| Gửi query đến agent, nhận câu trả lời và thời gian xử lý. | |
| Agent sẽ tự động: | |
| - Phân loại yêu cầu (trả lời trực tiếp hoặc tra cứu hội thoại) | |
| - Gọi các tool phù hợp (tóm tắt, lịch trình, ghi nhớ, web...) | |
| - Tổng hợp kết quả thành câu trả lời tự nhiên | |
| """ | |
| loop = get_running_loop() | |
| try: | |
| answer, elapsed = await loop.run_in_executor( | |
| _executor, | |
| lambda: final_answer(body.conversation_id, body.sender_id, body.query), | |
| ) | |
| except ValueError as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, | |
| detail=str(e), | |
| ) | |
| except Exception as e: | |
| logger.exception( | |
| "Unhandled error in POST /api/v1/chat [rid=%s]", _request_id(request) | |
| ) | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail="Lỗi xử lý nội bộ. Vui lòng thử lại.", | |
| ) | |
| return ChatResponse( | |
| answer=answer, | |
| processing_time=elapsed, | |
| conversation_id=body.conversation_id, | |
| sender_id=body.sender_id, | |
| ) | |
| async def chat_with_pdf( | |
| request: Request, | |
| conversation_id: str = Form(..., description="ID cuộc hội thoại DM"), | |
| sender_id: str = Form(..., description="ID hoặc tên người gửi"), | |
| query: str = Form(..., description="Câu hỏi hoặc yêu cầu về nội dung PDF"), | |
| file: UploadFile = File(..., description="File PDF cần xử lý"), | |
| ): | |
| if not file.filename.lower().endswith(".pdf"): | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail="Chỉ chấp nhận file PDF.", | |
| ) | |
| tmp_path = None | |
| try: | |
| with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: | |
| tmp.write(await file.read()) | |
| tmp_path = tmp.name | |
| loop = get_running_loop() | |
| answer, elapsed = await loop.run_in_executor( | |
| _executor, | |
| lambda: final_answer(conversation_id, sender_id, query, pdf_path=tmp_path), | |
| ) | |
| except ValueError as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, | |
| detail=str(e), | |
| ) | |
| except Exception: | |
| logger.exception( | |
| "Unhandled error in POST /api/v1/chat_with_pdf [rid=%s]", _request_id(request) | |
| ) | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail="Lỗi xử lý nội bộ. Vui lòng thử lại.", | |
| ) | |
| finally: | |
| if tmp_path and os.path.exists(tmp_path): | |
| os.remove(tmp_path) | |
| return ChatResponse( | |
| answer=answer, | |
| processing_time=elapsed, | |
| conversation_id=conversation_id, | |
| sender_id=sender_id, | |
| ) | |
| async def set_custom_prompt(request: Request, body: CustomPromptRequest): | |
| """ | |
| Lưu hoặc cập nhật custom prompt của người dùng lên Qdrant. | |
| Prompt này sẽ được tự động inject vào system prompt khi user đó gửi query. | |
| """ | |
| loop = get_running_loop() | |
| ok = await loop.run_in_executor( | |
| _executor, | |
| lambda: save_custom_prompt(body.user_id, body.prompt), | |
| ) | |
| if not ok: | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail="Không thể lưu custom prompt. Kiểm tra cấu hình QDRANT_URL.", | |
| ) | |
| return CustomPromptResponse(success=True, user_id=body.user_id, prompt=body.prompt) | |
| async def get_user_custom_prompt(user_id: str, request: Request): | |
| loop = get_running_loop() | |
| prompt = await loop.run_in_executor( | |
| _executor, | |
| lambda: get_custom_prompt(user_id), | |
| ) | |
| if prompt is None: | |
| raise HTTPException( | |
| status_code=status.HTTP_404_NOT_FOUND, | |
| detail=f"Không tìm thấy custom prompt cho user '{user_id}'.", | |
| ) | |
| return CustomPromptResponse(success=True, user_id=user_id, prompt=prompt) | |
| _IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif", ".bmp"} | |
| async def chat_with_image( | |
| request: Request, | |
| conversation_id: str = Form(..., description="ID cuộc hội thoại DM"), | |
| sender_id: str = Form(..., description="ID hoặc tên người gửi"), | |
| query: str = Form(..., description="Câu hỏi hoặc yêu cầu về nội dung ảnh"), | |
| file: UploadFile = File(..., description="File ảnh cần xử lý"), | |
| ): | |
| ext = os.path.splitext(file.filename.lower())[1] | |
| if ext not in _IMAGE_EXTENSIONS: | |
| raise HTTPException( | |
| status_code=status.HTTP_400_BAD_REQUEST, | |
| detail=f"Chỉ chấp nhận ảnh: {', '.join(_IMAGE_EXTENSIONS)}.", | |
| ) | |
| tmp_path = None | |
| try: | |
| with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: | |
| tmp.write(await file.read()) | |
| tmp_path = tmp.name | |
| loop = get_running_loop() | |
| answer, elapsed = await loop.run_in_executor( | |
| _executor, | |
| lambda: final_answer(conversation_id, sender_id, query, image_path=tmp_path), | |
| ) | |
| except ValueError as e: | |
| raise HTTPException( | |
| status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, | |
| detail=str(e), | |
| ) | |
| except Exception: | |
| logger.exception( | |
| "Unhandled error in POST /api/v1/chat_with_image [rid=%s]", _request_id(request) | |
| ) | |
| raise HTTPException( | |
| status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
| detail="Lỗi xử lý nội bộ. Vui lòng thử lại.", | |
| ) | |
| finally: | |
| if tmp_path and os.path.exists(tmp_path): | |
| os.remove(tmp_path) | |
| return ChatResponse( | |
| answer=answer, | |
| processing_time=elapsed, | |
| conversation_id=conversation_id, | |
| sender_id=sender_id, | |
| ) | |
| # ── Exception handlers ──────────────────────────────────────────────────────── | |
| async def not_found(request: Request, exc): | |
| return JSONResponse( | |
| status_code=404, | |
| content=ErrorDetail( | |
| error="Not Found", | |
| detail=f"Endpoint '{request.url.path}' không tồn tại.", | |
| request_id=_request_id(request), | |
| ).model_dump(), | |
| ) | |
| async def method_not_allowed(request: Request, exc): | |
| return JSONResponse( | |
| status_code=405, | |
| content=ErrorDetail( | |
| error="Method Not Allowed", | |
| detail=f"Method '{request.method}' không được hỗ trợ tại '{request.url.path}'.", | |
| request_id=_request_id(request), | |
| ).model_dump(), | |
| ) | |
| async def internal_error(request: Request, exc): | |
| return JSONResponse( | |
| status_code=500, | |
| content=ErrorDetail( | |
| error="Internal Server Error", | |
| detail="Đã xảy ra lỗi không mong muốn.", | |
| request_id=_request_id(request), | |
| ).model_dump(), | |
| ) | |
| if __name__ == "__main__": | |
| uvicorn.run("src.api:app", host="127.0.0.1", port=8000, reload=True) | |