Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| import warnings | |
| from dotenv import load_dotenv | |
| from fastapi import Depends, FastAPI, File, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from starlette.requests import Request | |
| load_dotenv() | |
| from app.runtime_env import apply_runtime_env_defaults | |
| apply_runtime_env_defaults() | |
| os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1" | |
| warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") | |
| warnings.filterwarnings("ignore", category=UserWarning, module="albumentations") | |
| from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging | |
| setup_application_logging() | |
| # Routers (after logging) | |
| from app.dependencies import get_current_user_id | |
| from app.ocr_local_file import ocr_from_local_image_path | |
| from app.routers import auth, sessions, solve | |
| from agents.ocr_agent import OCRAgent | |
| from app.routers.solve import get_orchestrator | |
| from app.job_poll import normalize_job_row_for_client | |
| from app.supabase_client import get_supabase | |
| from app.websocket_manager import register_websocket_routes | |
| logger = logging.getLogger("app.main") | |
| _access = logging.getLogger(ACCESS_LOGGER_NAME) | |
| app = FastAPI(title="Visual Math Solver API v5.1") | |
| async def access_log_middleware(request: Request, call_next): | |
| """LOG_LEVEL=info/debug: mọi request; warning: chỉ 4xx/5xx; error: chỉ 4xx/5xx ở mức error.""" | |
| start = time.perf_counter() | |
| response = await call_next(request) | |
| ms = (time.perf_counter() - start) * 1000 | |
| mode = get_log_level() | |
| method = request.method | |
| path = request.url.path | |
| status = response.status_code | |
| if mode in ("debug", "info"): | |
| _access.info("%s %s -> %s (%.0fms)", method, path, status, ms) | |
| elif mode == "warning": | |
| if status >= 500: | |
| _access.error("%s %s -> %s (%.0fms)", method, path, status, ms) | |
| elif status >= 400: | |
| _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms) | |
| elif mode == "error": | |
| if status >= 400: | |
| _access.error("%s %s -> %s", method, path, status) | |
| return response | |
| from worker.celery_app import BROKER_URL | |
| _broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL | |
| if get_log_level() in ("debug", "info"): | |
| logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail) | |
| else: | |
| logger.warning( | |
| "App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=[ | |
| "http://localhost:3000", | |
| "http://127.0.0.1:3000", | |
| "http://localhost:3005", | |
| ], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| app.include_router(auth.router) | |
| app.include_router(sessions.router) | |
| app.include_router(solve.router) | |
| register_websocket_routes(app) | |
| def get_ocr_agent() -> OCRAgent: | |
| """Same OCR instance as the solve pipeline (no duplicate model load).""" | |
| return get_orchestrator().ocr_agent | |
| supabase_client = get_supabase() | |
| def read_root(): | |
| return {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"} | |
| async def upload_ocr( | |
| file: UploadFile = File(...), | |
| _user_id=Depends(get_current_user_id), | |
| ): | |
| """OCR upload: requires authenticated user.""" | |
| temp_path = f"temp_{uuid.uuid4()}.png" | |
| with open(temp_path, "wb") as buffer: | |
| buffer.write(await file.read()) | |
| try: | |
| text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent()) | |
| return {"text": text} | |
| finally: | |
| if os.path.exists(temp_path): | |
| os.remove(temp_path) | |
| async def get_job_status( | |
| job_id: str, | |
| user_id=Depends(get_current_user_id), | |
| ): | |
| """Retrieve job status (can be used for polling if WS fails). Owner-only.""" | |
| response = supabase_client.table("jobs").select("*").eq("id", job_id).execute() | |
| if not response.data: | |
| raise HTTPException(status_code=404, detail="Job not found") | |
| job = response.data[0] | |
| if job.get("user_id") is not None and str(job["user_id"]) != str(user_id): | |
| raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.") | |
| # Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs) | |
| return normalize_job_row_for_client(job) | |