Yu Chen
custom csv upload logic
0641d27
"""FastAPI entrypoint for the ClassLens backend (v2 - 3-step workflow)."""
from __future__ import annotations
from pathlib import Path
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, Depends, File, Form, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from .database import (
init_database,
create_session,
get_sessions,
get_session,
update_session_status,
get_parsed_data,
save_prompt,
get_prompts,
get_all_prompts,
get_prompt,
update_prompt,
delete_prompt,
save_report,
get_reports,
get_report,
save_answer_grid,
get_answer_grid,
delete_parsed_data,
ANSWER_GRID_DATA_TYPE,
)
from .auth import get_current_user, register_user, login_user
from .file_processor import process_uploaded_files
from .answer_grid import from_dict as grid_from_dict, seed_from_parsed, to_dict as grid_to_dict
from .parsers import AUTO, list_parsers
from .report_generator import generate_student_report, build_student_markdown
# Directory holding the built frontend (index.html, assets/, favicon.ico);
# it sits next to this package, i.e. one level above this file's directory.
STATIC_DIR = Path(__file__).parent.parent.joinpath("static")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: make sure the database is ready before serving.

    ``init_database()`` is awaited once at startup; control is then handed
    back to FastAPI for the life of the process. Nothing runs on shutdown.
    """
    # --- startup ---
    await init_database()
    # --- application runs here; no shutdown work follows ---
    yield
# The ASGI application; `lifespan` initializes the database on startup.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="ClassLens API",
    description="AI-powered exam analysis for teachers (v2)",
)

# Fully open CORS policy: every origin, method and header is allowed.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True means
# Starlette's CORSMiddleware echoes back the caller's Origin on credentialed
# requests (verify for the installed Starlette version). If auth ever relies on
# cookies, any site could make authenticated calls; consider an explicit origin
# list in production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# =============================================================================
# Auth Endpoints
# =============================================================================
# Request body shared by POST /api/auth/register and POST /api/auth/login.
class AuthRequest(BaseModel):
    email: str  # plain str: no email-format validation happens at this layer
    password: str  # passed as-is to register_user / login_user
@app.post("/api/auth/register")
async def api_register(body: AuthRequest):
    """Create an account and immediately issue an access token for it.

    Response: ``{"token": <access token>, "user": {"id", "email", "display_name"}}``.
    The token's ``sub`` claim is the new user's id rendered as a string.
    """
    new_user = await register_user(body.email, body.password)
    # Imported locally; this endpoint is the only one here that mints tokens.
    from .auth import create_access_token

    access_token = create_access_token({"sub": str(new_user["id"])})
    public_profile = {field: new_user[field] for field in ("id", "email", "display_name")}
    return {"token": access_token, "user": public_profile}
@app.post("/api/auth/login")
async def api_login(body: AuthRequest):
    """Verify credentials; whatever ``login_user`` returns is passed through unchanged."""
    return await login_user(body.email, body.password)
@app.get("/api/auth/me")
async def api_me(user=Depends(get_current_user)):
    """Return the public profile fields (id, email, display_name) of the caller."""
    return {field: user[field] for field in ("id", "email", "display_name")}
# =============================================================================
# Session Endpoints
# =============================================================================
# Request body for POST /api/sessions; the title is optional.
class CreateSessionRequest(BaseModel):
    title: str = "Untitled Session"  # used when the client omits "title"
@app.post("/api/sessions")
async def api_create_session(body: CreateSessionRequest, user=Depends(get_current_user)):
    """Create a session owned by the caller and return the stored record."""
    new_id = await create_session(user["id"], body.title)
    return await get_session(new_id)
@app.get("/api/sessions")
async def api_list_sessions(user=Depends(get_current_user)):
    """List the caller's sessions, wrapped as ``{"sessions": [...]}``."""
    owned = await get_sessions(user["id"])
    return {"sessions": owned}
@app.get("/api/sessions/{session_id}")
async def api_get_session(session_id: int, user=Depends(get_current_user)):
    """Return one of the caller's sessions.

    Uses the shared ``_session_or_404`` ownership check. A missing session and
    one owned by another user both raise HTTPException 404 "Session not found".
    """
    return await _session_or_404(session_id, user)
# =============================================================================
# File Upload & Processing Endpoints
# =============================================================================
@app.get("/api/parsers")
async def api_list_parsers():
    """Expose the available parser backends so the upload UI can offer a choice."""
    available = list_parsers()
    return {"parsers": available}
@app.post("/api/sessions/{session_id}/upload")
async def api_upload_files(
    session_id: int,
    data_type: str = Form(...),
    files: list[UploadFile] = File(...),
    description: str = Form(""),
    model: str = Form("gpt-5.4"),
    parser: str = Form(AUTO),
    user=Depends(get_current_user),
):
    """Upload and process files for a session.

    ``data_type`` must be 'questions', 'answers', 'student_answers' or
    'teacher_answers'. The files go to ``process_uploaded_files`` together with
    the optional ``description``, the ``model`` name and the ``parser`` backend
    (see ``GET /api/parsers``).

    On success the session is marked "processed", and any previously saved
    answer grid is deleted because it was built from the old inputs.

    Returns:
        ``{"status": "ok", "data_type": data_type, "data": <processed result>}``

    Raises:
        HTTPException 404: session missing or owned by another user.
        HTTPException 400: unknown ``data_type``, or a ``ValueError`` raised while
            processing.
        HTTPException 500: any other failure while processing. An HTTPException
            raised inside the processing step is propagated with its own status.
    """
    await _session_or_404(session_id, user)
    if data_type not in ("questions", "student_answers", "teacher_answers", "answers"):
        raise HTTPException(
            status_code=400,
            detail="data_type must be 'questions', 'answers', 'student_answers', or 'teacher_answers'",
        )
    try:
        structured = await process_uploaded_files(
            files,
            data_type,
            session_id,
            description=description,
            model=model,
            parser=parser,
        )
        await update_session_status(session_id, "processed")
        # Invalidate any previously-confirmed answer grid since inputs changed
        await delete_parsed_data(session_id, ANSWER_GRID_DATA_TYPE)
    except HTTPException:
        # Already a deliberate HTTP error; don't re-wrap it as a generic 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing failed: {e}") from e
    return {"status": "ok", "data_type": data_type, "data": structured}
@app.get("/api/sessions/{session_id}/parsed-data")
async def api_get_parsed_data(session_id: int, user=Depends(get_current_user)):
    """Return all parsed-data rows of one of the caller's sessions, grouped by type.

    Response shape: ``{"parsed_data": {<data_type>: [row, ...], ...}}``. Within
    each group, rows keep the order in which ``get_parsed_data`` returned them.

    Raises:
        HTTPException 404: session missing or owned by another user.
    """
    await _session_or_404(session_id, user)
    grouped: dict = {}
    for item in await get_parsed_data(session_id):
        grouped.setdefault(item["data_type"], []).append(item)
    return {"parsed_data": grouped}
# =============================================================================
# Prompt Endpoints
# =============================================================================
# Request body for creating (POST) and replacing (PUT) a saved prompt.
class PromptRequest(BaseModel):
    name: str  # label stored via save_prompt / update_prompt
    content: str  # prompt text stored via save_prompt / update_prompt
@app.get("/api/prompts")
async def api_list_prompts(user=Depends(get_current_user)):
    """List only the prompts saved by the caller."""
    mine = await get_prompts(user["id"])
    return {"prompts": mine}
@app.get("/api/prompts/all")
async def api_list_all_prompts(user=Depends(get_current_user)):
    """List all prompts from all users (read others' prompts).

    ``user`` is only there to require authentication; results are not filtered by it.
    """
    everything = await get_all_prompts()
    return {"prompts": everything}
@app.post("/api/prompts")
async def api_save_prompt(body: PromptRequest, user=Depends(get_current_user)):
    """Store a new prompt owned by the caller and return the saved record."""
    new_id = await save_prompt(user["id"], body.name, body.content)
    return await get_prompt(new_id)
@app.get("/api/prompts/{prompt_id}")
async def api_get_prompt(prompt_id: int, user=Depends(get_current_user)):
    """Return a prompt by id, regardless of who owns it.

    There is deliberately no owner check, matching ``GET /api/prompts/all``,
    which lets any authenticated user read every prompt.
    """
    found = await get_prompt(prompt_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Prompt not found")
@app.put("/api/prompts/{prompt_id}")
async def api_update_prompt(prompt_id: int, body: PromptRequest, user=Depends(get_current_user)):
    """Replace the name and content of one of the caller's prompts.

    A prompt owned by someone else gets the same 404 as a missing one.
    Returns the prompt as re-read after the update.
    """
    existing = await get_prompt(prompt_id)
    is_owner = bool(existing) and existing["user_id"] == user["id"]
    if not is_owner:
        raise HTTPException(status_code=404, detail="Prompt not found")
    await update_prompt(prompt_id, body.name, body.content)
    refreshed = await get_prompt(prompt_id)
    return refreshed
@app.delete("/api/prompts/{prompt_id}")
async def api_delete_prompt(prompt_id: int, user=Depends(get_current_user)):
    """Delete one of the caller's prompts.

    A prompt owned by someone else gets the same 404 as a missing one.
    """
    existing = await get_prompt(prompt_id)
    is_owner = bool(existing) and existing["user_id"] == user["id"]
    if not is_owner:
        raise HTTPException(status_code=404, detail="Prompt not found")
    await delete_prompt(prompt_id)
    return {"status": "deleted"}
# =============================================================================
# Answer Grid Endpoints (canonical data for report generation)
# =============================================================================
# Client-confirmed answer grid posted to POST /api/sessions/{session_id}/answer-grid.
# Pydantic only checks the outer shape here; the content is validated by
# grid_from_dict, whose ValueError the endpoint turns into a 400.
class AnswerGridPayload(BaseModel):
    total_questions: int
    official_answers: list[Optional[str]]  # presumably one entry per question, None = no official answer — TODO confirm
    students: list[dict]  # per-student records; schema enforced by grid_from_dict
    questions: list[dict]  # per-question records; schema enforced by grid_from_dict
async def _session_or_404(session_id: int, user: dict) -> dict:
    """Fetch a session that belongs to ``user``, or raise HTTPException 404.

    A missing session and someone else's session produce the same 404, so the
    response does not reveal which session IDs exist.
    """
    found = await get_session(session_id)
    owned = bool(found) and found["user_id"] == user["id"]
    if not owned:
        raise HTTPException(status_code=404, detail="Session not found")
    return found
@app.get("/api/sessions/{session_id}/answer-grid")
async def api_get_answer_grid(session_id: int, user=Depends(get_current_user)):
    """Return the session's answer grid for review.

    If a grid was already confirmed, it is returned unchanged with
    ``is_confirmed: True``. Otherwise a draft is seeded from the 'questions',
    'student_answers' and 'teacher_answers' parsed data and returned with
    ``is_confirmed: False``. An empty dict stands in for any type not uploaded,
    and the draft is not saved.
    """
    await _session_or_404(session_id, user)

    confirmed = await get_answer_grid(session_id)
    if confirmed:
        return {"grid": confirmed, "is_confirmed": True}

    # One slot per seed input. If a data_type appears more than once, the last
    # row returned by get_parsed_data wins. Other data types (e.g. 'answers')
    # are not used for seeding.
    seeds: dict = {"questions": {}, "student_answers": {}, "teacher_answers": {}}
    for row in await get_parsed_data(session_id):
        if row["data_type"] in seeds:
            seeds[row["data_type"]] = row["structured_data"]

    draft = seed_from_parsed(
        seeds["questions"], seeds["student_answers"], seeds["teacher_answers"]
    )
    return {"grid": grid_to_dict(draft), "is_confirmed": False}
@app.post("/api/sessions/{session_id}/answer-grid")
async def api_save_answer_grid(
    session_id: int, body: AnswerGridPayload, user=Depends(get_current_user)
):
    """Validate and persist the grid the user confirmed.

    The payload is round-tripped through ``grid_from_dict`` / ``grid_to_dict``.
    What gets saved and echoed back is whatever ``grid_to_dict`` produces, not
    the raw request body.

    Raises:
        HTTPException 404: session missing or not owned by the caller.
        HTTPException 400: ``grid_from_dict`` rejected the payload (ValueError).
    """
    await _session_or_404(session_id, user)
    try:
        parsed_grid = grid_from_dict(body.model_dump())
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    normalized = grid_to_dict(parsed_grid)
    await save_answer_grid(session_id, normalized)
    return {"grid": normalized, "is_confirmed": True}
# =============================================================================
# Report Generation Endpoints
# =============================================================================
# Request body for POST /api/sessions/{session_id}/preview-student-prompt.
class PreviewStudentPromptRequest(BaseModel):
    student_index: int  # 0-based position in grid.students; range-checked by the endpoint
# Request body for POST /api/sessions/{session_id}/generate-student-report.
class GenerateStudentReportRequest(BaseModel):
    student_index: int  # 0-based position in grid.students; range-checked by the endpoint
    model: str = "gpt-5.4"  # model name forwarded to generate_student_report
async def _require_confirmed_grid(session_id: int):
    """Load the session's confirmed answer grid as a grid object.

    Raises:
        HTTPException 400: no grid has been confirmed yet, or the stored grid
            fails ``grid_from_dict`` validation (ValueError).
    """
    stored = await get_answer_grid(session_id)
    if stored:
        try:
            return grid_from_dict(stored)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=f"Invalid saved grid: {exc}")
    raise HTTPException(
        status_code=400,
        detail="Answer grid not confirmed. Please confirm it in step 2 first.",
    )
@app.post("/api/sessions/{session_id}/preview-student-prompt")
async def api_preview_student_prompt(
    session_id: int, body: PreviewStudentPromptRequest, user=Depends(get_current_user)
):
    """Show the markdown prompt that would be sent to the LLM for one student.

    Requires a confirmed answer grid. ``student_index`` must fall within
    ``grid.students``, otherwise HTTPException 400 is raised.
    """
    await _session_or_404(session_id, user)
    grid = await _require_confirmed_grid(session_id)
    idx = body.student_index
    if not 0 <= idx < len(grid.students):
        raise HTTPException(status_code=400, detail=f"Invalid student index: {idx}")

    preview = build_student_markdown(grid, idx)
    return {
        "student_name": preview.student_name,
        "total_questions": preview.total_questions,
        "wrong_count": preview.wrong_count,
        "markdown_prompt": preview.markdown_prompt,
    }
@app.post("/api/sessions/{session_id}/generate-student-report")
async def api_generate_student_report(
    session_id: int, body: GenerateStudentReportRequest, user=Depends(get_current_user)
):
    """Generate, store and return an HTML report for one student.

    Requires a confirmed answer grid. An out-of-range ``student_index`` raises
    HTTPException 400. Any failure while generating or saving becomes a 500.
    The report is saved with ``None`` as the second ``save_report`` argument.
    """
    await _session_or_404(session_id, user)
    grid = await _require_confirmed_grid(session_id)
    idx = body.student_index
    if not 0 <= idx < len(grid.students):
        raise HTTPException(status_code=400, detail=f"Invalid student index: {idx}")

    try:
        html = await generate_student_report(grid, idx, body.model)
        report_id = await save_report(session_id, None, html)
        student_name = grid.students[idx].name
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Report generation failed: {str(e)}")
    return {"report_id": report_id, "html_content": html, "student_name": student_name}
@app.get("/api/sessions/{session_id}/reports")
async def api_list_reports(session_id: int, user=Depends(get_current_user)):
    """List the reports stored for one of the caller's sessions.

    Uses the shared ``_session_or_404`` ownership check.

    Raises:
        HTTPException 404: session missing or owned by another user.
    """
    await _session_or_404(session_id, user)
    return {"reports": await get_reports(session_id)}
@app.get("/api/reports/{report_id}")
async def api_get_report(report_id: int, user=Depends(get_current_user)):
    """Return a single report, but only to the owner of the session it belongs to.

    Reports are stored per session (``save_report(session_id, ...)``), so access
    is checked through the owning session. Previously any authenticated user
    could read any report by id.

    A missing report and a report in another user's session both raise the same
    404, so report IDs cannot be probed.
    """
    report = await get_report(report_id)
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    # NOTE(review): assumes the report row exposes its owning "session_id"
    # (it is saved with one); confirm against database.get_report.
    owner_session_id = report.get("session_id")
    session = await get_session(owner_session_id) if owner_session_id is not None else None
    if not session or session["user_id"] != user["id"]:
        raise HTTPException(status_code=404, detail="Report not found")
    return report
# =============================================================================
# Health Check
# =============================================================================
@app.get("/api/health")
@app.get("/health")
async def health_check():
    """Liveness probe, reachable at both /api/health and /health."""
    return dict(status="healthy", service="ClassLens", version="2.0.0")
# =============================================================================
# Static File Serving (Production)
# =============================================================================
if STATIC_DIR.exists() and (STATIC_DIR / "index.html").exists():
    # A built frontend is present: serve it from this app.
    if (STATIC_DIR / "assets").exists():
        app.mount("/assets", StaticFiles(directory=STATIC_DIR / "assets"), name="assets")

    @app.get("/favicon.ico")
    async def favicon():
        """Serve STATIC_DIR/favicon.ico, or an empty 404 if it is absent."""
        favicon_path = STATIC_DIR / "favicon.ico"
        if favicon_path.exists():
            return FileResponse(favicon_path)
        return Response(status_code=404)

    @app.get("/")
    async def serve_spa():
        """Serve the SPA entry page."""
        return FileResponse(STATIC_DIR / "index.html")

    @app.get("/{full_path:path}")
    async def serve_spa_routes(full_path: str):
        """Serve a file from STATIC_DIR if one matches, otherwise the SPA entry page.

        Unmatched ``api/...`` paths and ``health`` get a bare 404 instead of
        index.html. Only files that resolve to a location inside STATIC_DIR are
        ever read from disk. ``../`` segments, or an absolute path (e.g. from a
        request starting with ``//``), would otherwise let ``STATIC_DIR / full_path``
        point anywhere on the filesystem; such paths fall back to index.html.
        """
        if full_path.startswith("api/") or full_path == "health":
            return Response(status_code=404)
        static_root = STATIC_DIR.resolve()
        file_path = (static_root / full_path).resolve()
        if file_path.is_relative_to(static_root) and file_path.is_file():
            return FileResponse(file_path)
        return FileResponse(STATIC_DIR / "index.html")
else:
    # No built frontend: the root endpoint just describes the API.
    @app.get("/")
    async def root():
        """Describe the API when no frontend build is present."""
        return {
            "name": "ClassLens API",
            "version": "2.0.0",
            "mode": "development",
            "description": "AI-powered exam analysis for teachers",
        }