diff --git "a/main.py" "b/main.py" new file mode 100644--- /dev/null +++ "b/main.py" @@ -0,0 +1,13157 @@ +""" +MathPulse AI - FastAPI Backend +AI-powered math tutoring backend using Hugging Face models. +- meta-llama/Llama-3.1-8B-Instruct for chat, learning paths, insights, and quiz generation + (via Hugging Face Inference API) +- facebook/bart-large-mnli for student risk classification +- Multi-method verification system for math accuracy +- AI-powered Quiz Maker with Bloom's Taxonomy integration +- Symbolic math calculator via SymPy +- Analytics and automation engine modules + +Auto-deployed to HuggingFace Spaces via GitHub Actions. +""" + +import os +import io +import re +import json +import ast +import math +import html +import hashlib +import logging +import traceback +import urllib.parse +import random +import secrets +import string +from contextlib import asynccontextmanager +from typing import List, Optional, Dict, Any, Set, Tuple, Iterator, AsyncIterator, Sequence, cast, cast +from collections import Counter, defaultdict +from threading import Lock + +# STARTUP VALIDATION - Run before anything else to prevent restart loops +try: + from startup_validation import run_all_validations + run_all_validations() # Exits with error if any critical check fails +except ImportError as e: + # If startup_validation module is not found, log warning but continue + # This can happen if the module wasn't properly deployed + print(f"⚠️ Warning: startup_validation module not found: {e}") + print(" Continuing without startup validation checks...") +except BaseException as e: + # Do not crash the container on validation issues unless explicitly configured. + strict_startup_validation = os.getenv("STARTUP_VALIDATION_STRICT", "false").strip().lower() in {"1", "true", "yes", "on"} + if strict_startup_validation: + raise + print(f"⚠️ Warning: startup validation failed but strict mode is disabled: {e}") + print(" Continuing startup to avoid restart-loop crash.") + +from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Request, Form +from fastapi.encoders import jsonable_encoder +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from pydantic import BaseModel, Field, field_validator +from starlette.middleware.base import BaseHTTPMiddleware +import asyncio +import time +import uuid +import sys +from datetime import datetime, timezone, timedelta +import tempfile +import subprocess +import requests as http_requests +import httpx +import uvicorn +from services.inference_client import ( + InferenceRequest, create_default_client, + get_model_for_task, get_current_runtime_config, +) +from services.deterministic_cache import DeterministicResponseCache +from services.logging_utils import log_model_call +from services.email_service import create_email_service_from_env, EmailMessagePayload +from services.user_provisioning_service import ( + AdminCreateUserInput, + CreateUserAndNotifyResult, + UserProvisioningError, + UserProvisioningService, +) +from routes.rag_routes import router as rag_router +from routes.admin_model_routes import router as admin_model_router +from routes.diagnostic import router as diagnostic_router +from rag.curriculum_rag import ( + build_analysis_curriculum_context, + build_lesson_prompt, + build_lesson_query, + retrieve_curriculum_context, + summarize_retrieval_confidence, +) + +try: + import firebase_admin # type: ignore[import-not-found] + from firebase_admin import auth as firebase_auth # type: ignore[import-not-found] + from firebase_admin import firestore as firebase_firestore # type: ignore[import-not-found] + HAS_FIREBASE_ADMIN = True +except Exception: + firebase_admin = None # type: ignore[assignment] + firebase_auth = None # type: ignore[assignment] + firebase_firestore = None # type: ignore[assignment] + HAS_FIREBASE_ADMIN = False + +try: + from google.oauth2 import id_token as google_id_token # type: ignore[import-not-found] + from google.auth.transport import requests as google_auth_requests # type: ignore[import-not-found] + HAS_GOOGLE_AUTH = True +except Exception: + google_id_token = None # type: ignore[assignment] + google_auth_requests = None # type: ignore[assignment] + HAS_GOOGLE_AUTH = False + +# Event-driven automation engine +from automation_engine import ( + automation_engine, + DiagnosticCompletionPayload, + QuizSubmissionPayload, + StudentEnrollmentPayload, + DataImportPayload, + ContentUpdatePayload, + AutomationResult, +) + +# ML-powered analytics module +from analytics import ( + # Request/Response models + CompetencyAnalysisRequest, + CompetencyAnalysisResponse, + TopicRecommendationRequest, + TopicRecommendationResponse, + EnhancedRiskPrediction, + EnhancedRiskRequest, + RiskTrainRequest, + RiskTrainResponse, + CalibrateDifficultyRequest, + CalibrateDifficultyResponse, + AdaptiveQuizRequest as AdaptiveQuizSelectRequest, + StudentSummaryResponse, + ClassInsightsRequest, + ClassInsightsResponse, + MockDataRequest, + RefreshCacheResponse, + # Core functions + compute_competency_analysis, + predict_risk_enhanced, + train_risk_model, + calibrate_question_difficulty, + select_adaptive_quiz, + recommend_topics, + get_student_summary, + get_class_insights, + generate_mock_student_data, + refresh_all_caches, + # Helpers + fetch_student_quiz_history, + fetch_topic_dependencies, + store_competency_analysis, + RISK_MODEL_PATH, + COMPETENCY_THRESHOLDS, + MIN_QUIZ_ATTEMPTS_FOR_COMPETENCY, +) + +# ─── Configuration ───────────────────────────────────────────── + +# ─── Configuration ───────────────────────────────────────────── + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("mathpulse") + +# Lazy-initialized inference client to avoid startup blocking +_inference_client = None +_inference_client_lock = Lock() + +def get_inference_client(): + """Lazy-initialize inference client on first use.""" + global _inference_client + if _inference_client is None: + with _inference_client_lock: + if _inference_client is None: + logger.info("🔧 Initializing InferenceClient...") + firestore_client = None + if HAS_FIREBASE_ADMIN and _firebase_ready: + try: + firestore_client = firebase_firestore.client() + except Exception: + pass + _inference_client = create_default_client(firestore_client=firestore_client) + logger.info("✅ InferenceClient initialized") + return _inference_client + +HF_TOKEN = os.environ.get( + "HF_TOKEN", + os.environ.get("HUGGING_FACE_API_TOKEN", os.environ.get("HUGGINGFACE_API_TOKEN", "")), +) # Kept for HF Space deployment / dataset push only; AI inference uses DEEPSEEK_API_KEY + +# Grade 11-12 tutoring default model. Can be overridden via INFERENCE_MODEL_ID or INFERENCE_CHAT_MODEL_ID. +HF_MATH_MODEL_ID = os.getenv("INFERENCE_CHAT_MODEL_ID") or os.getenv("INFERENCE_MODEL_ID") or os.getenv("HF_MATH_MODEL_ID", "deepseek-chat") + +# Alias kept so automation_engine.py (which imports CHAT_MODEL) keeps working. +CHAT_MODEL = HF_MATH_MODEL_ID + +# Dedicated quiz model override. When empty, routing.task_model_map decides quiz model. +HF_QUIZ_MODEL_ID = (os.getenv("HF_QUIZ_MODEL_ID", "").strip() or None) +HF_QUIZ_JSON_REPAIR_MODEL_ID = os.getenv("HF_QUIZ_JSON_REPAIR_MODEL_ID", "deepseek-chat") + +RISK_MODEL = CHAT_MODEL +VERIFICATION_SAMPLES = 3 # Number of samples for self-consistency checking +ENABLE_DEV_ENDPOINTS = os.getenv("ENABLE_DEV_ENDPOINTS", "false").strip().lower() in {"1", "true", "yes", "on"} +UPLOAD_MAX_BYTES = int(os.getenv("UPLOAD_MAX_BYTES", str(5 * 1024 * 1024))) +UPLOAD_MAX_ROWS = int(os.getenv("UPLOAD_MAX_ROWS", "2000")) +UPLOAD_MAX_COLS = int(os.getenv("UPLOAD_MAX_COLS", "60")) +UPLOAD_MAX_PDF_PAGES = int(os.getenv("UPLOAD_MAX_PDF_PAGES", "20")) +UPLOAD_RATE_LIMIT_PER_MIN = int(os.getenv("UPLOAD_RATE_LIMIT_PER_MIN", "12")) +UPLOAD_MAX_FILES_PER_REQUEST = int(os.getenv("UPLOAD_MAX_FILES_PER_REQUEST", "8")) +ADMIN_USERS_QUERY_TIMEOUT_SECONDS = float(os.getenv("ADMIN_USERS_QUERY_TIMEOUT_SECONDS", "18")) +ADMIN_USERS_MAX_SCAN_DOCS = int(os.getenv("ADMIN_USERS_MAX_SCAN_DOCS", "5000")) +IMPORT_RETENTION_DAYS = int(os.getenv("IMPORT_RETENTION_DAYS", "180")) +ENABLE_IMPORT_GROUNDED_QUIZ = os.getenv("ENABLE_IMPORT_GROUNDED_QUIZ", "true").strip().lower() in {"1", "true", "yes", "on"} +ENABLE_IMPORT_GROUNDED_LESSON = os.getenv("ENABLE_IMPORT_GROUNDED_LESSON", "true").strip().lower() in {"1", "true", "yes", "on"} +ENABLE_IMPORT_GROUNDED_FEEDBACK_EVENTS = os.getenv("ENABLE_IMPORT_GROUNDED_FEEDBACK_EVENTS", "true").strip().lower() in {"1", "true", "yes", "on"} +ENFORCE_LEGIT_SOURCES_FOR_LESSONS = os.getenv("ENFORCE_LEGIT_SOURCES_FOR_LESSONS", "true").strip().lower() in {"1", "true", "yes", "on"} +ENABLE_ASYNC_GENERATION = os.getenv("ENABLE_ASYNC_GENERATION", "true").strip().lower() in {"1", "true", "yes", "on"} +ENABLE_LLM_RISK_RECOMMENDATIONS = os.getenv("ENABLE_LLM_RISK_RECOMMENDATIONS", "true").strip().lower() in {"1", "true", "yes", "on"} +ENABLE_RAG_ANALYSIS_CONTEXT = os.getenv("ENABLE_RAG_ANALYSIS_CONTEXT", "true").strip().lower() in {"1", "true", "yes", "on"} +ASYNC_TASK_TTL_SECONDS = int(os.getenv("ASYNC_TASK_TTL_SECONDS", "3600")) +ASYNC_TASK_MAX_ITEMS = int(os.getenv("ASYNC_TASK_MAX_ITEMS", "400")) +LESSON_SOURCE_MIN_TEXT_LENGTH = int(os.getenv("LESSON_SOURCE_MIN_TEXT_LENGTH", "240")) +LESSON_SOURCE_MIN_TOPICS = int(os.getenv("LESSON_SOURCE_MIN_TOPICS", "1")) +LESSON_VALIDATION_MIN_SCORE = float(os.getenv("LESSON_VALIDATION_MIN_SCORE", "0.7")) +ENABLE_FALLBACK_FIREBASE_TOKEN_VERIFY = os.getenv("ENABLE_FALLBACK_FIREBASE_TOKEN_VERIFY", "true").strip().lower() in {"1", "true", "yes", "on"} +FIREBASE_AUTH_PROJECT_ID = os.getenv("FIREBASE_AUTH_PROJECT_ID", "mathpulse-ai-2026").strip() +FIREBASE_SERVICE_ACCOUNT_JSON = os.getenv("FIREBASE_SERVICE_ACCOUNT_JSON", "").strip() +FIREBASE_SERVICE_ACCOUNT_FILE = os.getenv("FIREBASE_SERVICE_ACCOUNT_FILE", "").strip() +FIREBASE_AUTH_PROJECT_ALLOWLIST: Set[str] = { + value.strip() + for value in os.getenv("FIREBASE_AUTH_PROJECT_ALLOWLIST", "").split(",") + if value.strip() +} +CHAT_MAX_NEW_TOKENS = max(256, int(os.getenv("CHAT_MAX_NEW_TOKENS", "8192"))) +CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC = max(5, int(os.getenv("CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC", "90"))) +CHAT_STREAM_TOTAL_TIMEOUT_SEC = max( + CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC, + int(os.getenv("CHAT_STREAM_TOTAL_TIMEOUT_SEC", "900")), +) +CHAT_STREAM_CONTINUATION_ENABLED = os.getenv( + "CHAT_STREAM_CONTINUATION_ENABLED", + "true", +).strip().lower() in {"1", "true", "yes", "on"} +CHAT_STREAM_CONTINUATION_MAX_ROUNDS = max( + 0, + int(os.getenv("CHAT_STREAM_CONTINUATION_MAX_ROUNDS", "2")), +) +CHAT_STREAM_CONTINUATION_MIN_NEW_CHARS = max( + 1, + int(os.getenv("CHAT_STREAM_CONTINUATION_MIN_NEW_CHARS", "24")), +) +CHAT_STREAM_CONTINUATION_TAIL_CHARS = max( + 80, + int(os.getenv("CHAT_STREAM_CONTINUATION_TAIL_CHARS", "900")), +) +CHAT_STREAM_COMPLETION_MODE_DEFAULT = os.getenv( + "CHAT_STREAM_COMPLETION_MODE_DEFAULT", + "auto", +).strip().lower() +if CHAT_STREAM_COMPLETION_MODE_DEFAULT not in {"auto", "marker", "none"}: + CHAT_STREAM_COMPLETION_MODE_DEFAULT = "auto" + +DETERMINISTIC_CACHE_ENABLED = os.getenv("DETERMINISTIC_CACHE_ENABLED", "true").strip().lower() in {"1", "true", "yes", "on"} +DETERMINISTIC_CACHE_MAX_ENTRIES = max(100, int(os.getenv("DETERMINISTIC_CACHE_MAX_ENTRIES", "1200"))) +DETERMINISTIC_CACHE_REDIS_URL = os.getenv("DETERMINISTIC_CACHE_REDIS_URL", "").strip() or None + +VERIFY_SOLUTION_CACHE_TTL_SECONDS = max(30, int(os.getenv("VERIFY_SOLUTION_CACHE_TTL_SECONDS", "900"))) +PREDICT_RISK_CACHE_TTL_SECONDS = max(30, int(os.getenv("PREDICT_RISK_CACHE_TTL_SECONDS", "600"))) +LEARNING_PATH_CACHE_TTL_SECONDS = max(30, int(os.getenv("LEARNING_PATH_CACHE_TTL_SECONDS", "300"))) +DAILY_INSIGHT_CACHE_TTL_SECONDS = max(30, int(os.getenv("DAILY_INSIGHT_CACHE_TTL_SECONDS", "180"))) + +ALLOWED_UPLOAD_EXTENSIONS: Set[str] = {".csv", ".xlsx", ".xls", ".pdf"} +ALLOWED_UPLOAD_MIME_TYPES: Set[str] = { + "text/csv", + "application/csv", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "application/pdf", + "application/octet-stream", +} +ALLOWED_COURSE_MATERIAL_EXTENSIONS: Set[str] = {".pdf", ".docx", ".txt"} +ALLOWED_COURSE_MATERIAL_MIME_TYPES: Set[str] = { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "text/plain", + "application/octet-stream", +} + +VALID_ROLES: Set[str] = {"student", "teacher", "admin"} +ALL_APP_ROLES: Set[str] = {"student", "teacher", "admin"} +TEACHER_OR_ADMIN: Set[str] = {"teacher", "admin"} +ADMIN_ONLY: Set[str] = {"admin"} + +PUBLIC_PATHS: Set[str] = { + "/", + "/health", + "/docs", + "/redoc", + "/openapi.json", +} +PUBLIC_API_PATHS: Set[str] = { + "/api/quiz/topics", + "/api/rag/health", +} + +ROLE_POLICIES: Dict[str, Set[str]] = { + "/api/chat": ALL_APP_ROLES, + "/api/chat/stream": ALL_APP_ROLES, + "/api/verify-solution": ALL_APP_ROLES, + "/api/predict-risk": TEACHER_OR_ADMIN, + "/api/predict-risk/batch": TEACHER_OR_ADMIN, + "/api/learning-path": ALL_APP_ROLES, + "/api/analytics/daily-insight": TEACHER_OR_ADMIN, + "/api/upload/class-records": TEACHER_OR_ADMIN, + "/api/upload/class-records/risk-refresh/recent": TEACHER_OR_ADMIN, + "/api/import/student-accounts/preview": TEACHER_OR_ADMIN, + "/api/import/student-accounts/commit": TEACHER_OR_ADMIN, + "/api/admin/users": ADMIN_ONLY, + "/api/admin/users/bulk-action": ADMIN_ONLY, + "/api/upload/course-materials": TEACHER_OR_ADMIN, + "/api/upload/course-materials/recent": TEACHER_OR_ADMIN, + "/api/course-materials/topics": TEACHER_OR_ADMIN, + "/api/quiz/generate": TEACHER_OR_ADMIN, + "/api/quiz/generate-async": TEACHER_OR_ADMIN, + "/api/quiz/preview": TEACHER_OR_ADMIN, + "/api/lesson/generate": TEACHER_OR_ADMIN, + "/api/lesson/generate-async": TEACHER_OR_ADMIN, + "/api/rag/lesson": TEACHER_OR_ADMIN, + "/api/rag/generate-problem": TEACHER_OR_ADMIN, + "/api/rag/analysis-context": TEACHER_OR_ADMIN, + "/api/feedback/import-grounded": TEACHER_OR_ADMIN, + "/api/feedback/import-grounded/summary": TEACHER_OR_ADMIN, + "/api/import-grounded/access-audit": TEACHER_OR_ADMIN, + "/api/quiz/student-competency": TEACHER_OR_ADMIN, + "/api/calculator/evaluate": ALL_APP_ROLES, + "/api/diagnostic/generate": ALL_APP_ROLES, + "/api/diagnostic/submit": ALL_APP_ROLES, + "/api/student/competency-analysis": TEACHER_OR_ADMIN, + "/api/risk/train-model": ADMIN_ONLY, + "/api/predict-risk/enhanced": TEACHER_OR_ADMIN, + "/api/quiz/calibrate-difficulty": TEACHER_OR_ADMIN, + "/api/quiz/adaptive-select": TEACHER_OR_ADMIN, + "/api/learning/recommend-topics": TEACHER_OR_ADMIN, + "/api/analytics/student-summary": ALL_APP_ROLES, + "/api/analytics/class-insights": TEACHER_OR_ADMIN, + "/api/analytics/refresh-cache": ADMIN_ONLY, + "/api/testing/reset-data": ALL_APP_ROLES, + "/api/ops/inference-metrics": ADMIN_ONLY, + "/api/hf/monitoring": ADMIN_ONLY, + "/api/dev/generate-mock-data": ADMIN_ONLY, + "/api/analytics/config": TEACHER_OR_ADMIN, + "/api/analytics/imported-class-overview": TEACHER_OR_ADMIN, + "/api/analytics/topic-mastery": TEACHER_OR_ADMIN, + "/api/automation/diagnostic-completed": ADMIN_ONLY, + "/api/automation/quiz-submitted": ADMIN_ONLY, + "/api/automation/student-enrolled": ADMIN_ONLY, + "/api/automation/data-imported": ADMIN_ONLY, + "/api/automation/content-updated": ADMIN_ONLY, + "/api/admin/model-config": ADMIN_ONLY, + "/api/admin/model-config/profile": ADMIN_ONLY, + "/api/admin/model-config/override": ADMIN_ONLY, + "/api/admin/model-config/reset": ADMIN_ONLY, +} + +if not os.getenv("DEEPSEEK_API_KEY"): + logger.warning( + "DEEPSEEK_API_KEY is not set. AI features will fail. " + "Set the DEEPSEEK_API_KEY environment variable." + ) + +deterministic_response_cache = DeterministicResponseCache( + enabled=DETERMINISTIC_CACHE_ENABLED, + max_entries=DETERMINISTIC_CACHE_MAX_ENTRIES, + redis_url=DETERMINISTIC_CACHE_REDIS_URL, + logger=logger, +) + + +def _set_cache_response_header(response: Optional[Response], hit: bool) -> None: + if response is None: + return + response.headers["X-Cache"] = "HIT" if hit else "MISS" + +# ─── FastAPI App ─────────────────────────────────────────────── + + +@asynccontextmanager +async def app_lifespan(_app: FastAPI) -> AsyncIterator[None]: + """Initialize and tear down backend services for app lifespan.""" + logger.info("⚙️ Initializing backend services...") + _init_firebase_admin() + + # Pre-initialize inference client at startup to avoid first-request latency spike + logger.info("🔧 Pre-initializing InferenceClient...") + try: + get_inference_client() + logger.info("✅ InferenceClient pre-initialized at startup") + except Exception as e: + logger.warning(f"⚠️ Failed to pre-initialize InferenceClient: {e}") + + active_model = os.getenv("HF_MODEL_ID", "deepseek-chat") + try: + from rag.vectorstore_loader import get_vectorstore_health + health = get_vectorstore_health() + logger.info( + "RAG vectorstore ready: %d chunks | subjects: %s | model: %s", + health["chunkCount"], + list(health["subjects"].keys()), + active_model, + ) + if health["chunkCount"] == 0: + logger.warning( + "RAG vectorstore is EMPTY. Run: python backend/scripts/ingest_curriculum.py" + ) + if "235B" in active_model: + logger.info( + "Production model active: %s — sequential inference only (--max-num-seqs 1)", + active_model, + ) + except Exception as exc: + logger.error("RAG vectorstore warm-up failed: %s", exc) + + logger.info(f"✅ MathPulse AI backend ready at http://0.0.0.0:7860") + logger.info(f" - INFERENCE_PROVIDER: {os.getenv('INFERENCE_PROVIDER', 'deepseek')}") + logger.info(f" - INFERENCE_MODEL_ID: {os.getenv('INFERENCE_MODEL_ID', HF_MATH_MODEL_ID)}") + logger.info(f" - INFERENCE_CHAT_MODEL_ID: {os.getenv('INFERENCE_CHAT_MODEL_ID', HF_MATH_MODEL_ID)}") + logger.info( + f" - INFERENCE_CHAT_STRICT_MODEL_ONLY: " + f"{os.getenv('INFERENCE_CHAT_STRICT_MODEL_ONLY', 'true')}" + ) + logger.info( + f" - INFERENCE_ENFORCE_LOCK_MODEL: " + f"{os.getenv('INFERENCE_ENFORCE_LOCK_MODEL', 'true')}" + ) + logger.info(f" - DEEPSEEK_API_KEY set: {'yes' if os.getenv('DEEPSEEK_API_KEY') else 'no'}") + + try: + yield + finally: + await _close_hf_async_http_client() + + +app = FastAPI( + title="MathPulse AI API", + description="AI-powered math tutoring and student analytics backend", + version="1.0.0", + lifespan=app_lifespan, +) + +logger.info("🚀 FastAPI app created, startup sequence beginning...") + + +class AuthenticatedUser(BaseModel): + uid: str + email: Optional[str] = None + role: str + claims: Dict[str, Any] = Field(default_factory=dict) + + +_firebase_ready = False +_role_cache: Dict[str, Dict[str, Any]] = {} +_ROLE_CACHE_TTL_SECONDS = 60 +_firestore_role_lookup_warning_emitted = False +_firestore_audit_lookup_warning_emitted = False +_firestore_topics_lookup_warning_emitted = False +_rate_limit_buckets: Dict[str, List[float]] = {} +_async_tasks: Dict[str, Dict[str, Any]] = {} +_async_tasks_lock = Lock() +_account_import_previews: Dict[str, Dict[str, Any]] = {} +_account_import_previews_lock = Lock() +ACCOUNT_IMPORT_PREVIEW_TTL_SECONDS = int(os.getenv("ACCOUNT_IMPORT_PREVIEW_TTL_SECONDS", "1200")) +FIRESTORE_SERVER_TIMESTAMP: Any = getattr(cast(Any, firebase_firestore), "SERVER_TIMESTAMP", None) +FIRESTORE_QUERY_DESCENDING: Any = getattr(getattr(cast(Any, firebase_firestore), "Query", None), "DESCENDING", "DESCENDING") + + +def _snapshot_to_dict(snapshot: Any) -> Dict[str, Any]: + data = snapshot.to_dict() if hasattr(snapshot, "to_dict") else {} + return data if isinstance(data, dict) else {} + + +def _snapshot_exists(snapshot: Any) -> bool: + return bool(getattr(snapshot, "exists", False)) + + +def _is_adc_missing_error(err: Exception) -> bool: + message = str(err).lower() + return ( + "default credentials were not found" in message + or "application default credentials" in message + ) + + +def _is_auth_user_not_found_error(err: Exception) -> bool: + message = str(err).lower() + return ( + "no user record" in message + or "user-not-found" in message + or "not found" in message + ) + + +def _init_firebase_admin() -> None: + global _firebase_ready + if _firebase_ready: + return + if not HAS_FIREBASE_ADMIN: + logger.warning("firebase-admin is not available; protected API endpoints will reject requests.") + return + + try: + if not firebase_admin._apps: # type: ignore[attr-defined] + init_options: Dict[str, Any] = {} + credentials_obj: Optional[Any] = None + if FIREBASE_AUTH_PROJECT_ID: + init_options["projectId"] = FIREBASE_AUTH_PROJECT_ID + + if FIREBASE_SERVICE_ACCOUNT_JSON: + service_account_payload = json.loads(FIREBASE_SERVICE_ACCOUNT_JSON) + credentials_obj = cast(Any, firebase_admin).credentials.Certificate(service_account_payload) + elif FIREBASE_SERVICE_ACCOUNT_FILE: + credentials_obj = cast(Any, firebase_admin).credentials.Certificate(FIREBASE_SERVICE_ACCOUNT_FILE) + + if credentials_obj and init_options: + firebase_admin.initialize_app(credentials_obj, options=init_options) # type: ignore[union-attr] + elif credentials_obj: + firebase_admin.initialize_app(credentials_obj) # type: ignore[union-attr] + elif init_options: + firebase_admin.initialize_app(options=init_options) # type: ignore[union-attr] + else: + firebase_admin.initialize_app() # type: ignore[union-attr] + _firebase_ready = True + if FIREBASE_AUTH_PROJECT_ID: + logger.info(f"Firebase Admin SDK initialized for API auth verification (projectId={FIREBASE_AUTH_PROJECT_ID})") + else: + logger.info("Firebase Admin SDK initialized for API auth verification") + except Exception as e: + logger.warning( + "Firebase Admin SDK init failed: %s. Configure FIREBASE_SERVICE_ACCOUNT_JSON " + "or FIREBASE_SERVICE_ACCOUNT_FILE, or set GOOGLE_APPLICATION_CREDENTIALS.", + e, + ) + + +def _get_role_from_firestore(uid: str) -> Optional[str]: + now = time.time() + cached = _role_cache.get(uid) + if cached and now - float(cached.get("ts", 0)) < _ROLE_CACHE_TTL_SECONDS: + return cached.get("role") + + if not (_firebase_ready and firebase_firestore): + return None + + try: + doc = cast(Any, firebase_firestore.client().collection("users").document(uid).get()) + role = _snapshot_to_dict(doc).get("role") if _snapshot_exists(doc) else None + if isinstance(role, str): + _role_cache[uid] = {"role": role, "ts": now} + return role + except Exception as e: + _warn_firestore_role_lookup_once(f"Failed to resolve role from Firestore for {uid}: {e}") + + return None + + +def _warn_firestore_role_lookup_once(message: str) -> None: + global _firestore_role_lookup_warning_emitted + if _firestore_role_lookup_warning_emitted: + return + logger.warning(message) + _firestore_role_lookup_warning_emitted = True + + +def _warn_firestore_audit_lookup_once(message: str) -> None: + global _firestore_audit_lookup_warning_emitted + if _firestore_audit_lookup_warning_emitted: + return + logger.warning(message) + _firestore_audit_lookup_warning_emitted = True + + +def _warn_firestore_topics_lookup_once(message: str) -> None: + global _firestore_topics_lookup_warning_emitted + if _firestore_topics_lookup_warning_emitted: + return + logger.warning(message) + _firestore_topics_lookup_warning_emitted = True + + +def _extract_role_from_firestore_rest_payload(payload: Dict[str, Any]) -> Optional[str]: + fields = payload.get("fields") + if not isinstance(fields, dict): + return None + role_field = fields.get("role") + if not isinstance(role_field, dict): + return None + + # Firestore REST encodes field values by type, e.g. {"stringValue": "teacher"}. + role_value = role_field.get("stringValue") + if isinstance(role_value, str): + normalized = role_value.strip().lower() + if normalized in VALID_ROLES: + return normalized + return None + + +def _get_role_from_firestore_rest(uid: str, firebase_id_token: str, project_id: str) -> Optional[str]: + if not uid or not firebase_id_token or not project_id: + return None + + role_doc_url = ( + "https://firestore.googleapis.com/v1/" + f"projects/{project_id}/databases/(default)/documents/users/{uid}" + "?mask.fieldPaths=role" + ) + + try: + response = http_requests.get( + role_doc_url, + headers={"Authorization": f"Bearer {firebase_id_token}"}, + timeout=5, + ) + if response.status_code == 200: + payload = cast(Dict[str, Any], response.json()) + return _extract_role_from_firestore_rest_payload(payload) + if response.status_code not in {401, 403, 404}: + logger.info( + "Firestore REST role lookup returned non-success status %s for uid=%s", + response.status_code, + uid, + ) + except Exception as rest_error: + logger.info(f"Firestore REST role lookup failed for {uid}: {rest_error}") + + return None + + +def _resolve_user_role(decoded: Dict[str, Any], firebase_id_token: Optional[str] = None) -> str: + role_claim = decoded.get("role") + if isinstance(role_claim, str) and role_claim in VALID_ROLES: + return role_claim + + uid = _extract_uid_from_claims(decoded) + if uid: + firestore_role = _get_role_from_firestore(uid) + if isinstance(firestore_role, str) and firestore_role in VALID_ROLES: + return firestore_role + + if firebase_id_token: + token_project_id = str(decoded.get("aud", "")).strip() + project_id = FIREBASE_AUTH_PROJECT_ID or token_project_id + firestore_rest_role = _get_role_from_firestore_rest(uid, firebase_id_token, project_id) + if isinstance(firestore_rest_role, str) and firestore_rest_role in VALID_ROLES: + _role_cache[uid] = {"role": firestore_rest_role, "ts": time.time()} + return firestore_rest_role + + return "student" + + +def _extract_uid_from_claims(decoded: Dict[str, Any]) -> str: + """Normalize Firebase token user identifier across verifier implementations.""" + for key in ("uid", "user_id", "sub"): + value = decoded.get(key) + if isinstance(value, str): + normalized = value.strip() + if normalized: + return normalized + return "" + + +def _parse_bearer_token(authorization: str) -> Optional[str]: + if not authorization: + return None + parts = authorization.strip().split(" ") + if len(parts) != 2 or parts[0].lower() != "bearer": + return None + return parts[1].strip() + + +def _verify_token_with_fallback(token: str) -> Dict[str, Any]: + """Verify Firebase ID token via firebase-admin, then google-auth as fallback.""" + last_error: Optional[Exception] = None + + try: + return cast(Dict[str, Any], firebase_auth.verify_id_token(token)) # type: ignore[union-attr] + except Exception as err: + last_error = err + + if not ENABLE_FALLBACK_FIREBASE_TOKEN_VERIFY: + raise cast(Exception, last_error) + if not HAS_GOOGLE_AUTH: + raise cast(Exception, last_error) + + try: + request_adapter = google_auth_requests.Request() # type: ignore[union-attr] + decoded_raw = google_id_token.verify_firebase_token(token, request_adapter) # type: ignore[union-attr] + decoded = cast(Dict[str, Any], decoded_raw or {}) + if not decoded: + raise ValueError("Fallback Firebase token verification returned empty claims") + + audience = str(decoded.get("aud", "")) + issuer = str(decoded.get("iss", "")) + if audience and issuer != f"https://securetoken.google.com/{audience}": + raise ValueError("Fallback Firebase token verification issuer mismatch") + + if FIREBASE_AUTH_PROJECT_ALLOWLIST and audience not in FIREBASE_AUTH_PROJECT_ALLOWLIST: + raise ValueError("Firebase token project is not in FIREBASE_AUTH_PROJECT_ALLOWLIST") + + logger.info("Firebase token verified via google-auth fallback") + return decoded + except Exception as fallback_err: + logger.warning(f"Fallback Firebase token verification failed: {fallback_err}") + raise cast(Exception, last_error) + + +def get_current_user(request: Request) -> AuthenticatedUser: + user = getattr(request.state, "user", None) + if user is None: + raise HTTPException(status_code=401, detail="Authentication required") + return user + + +def require_student_self_or_staff(request: Request, student_id: str) -> AuthenticatedUser: + user = get_current_user(request) + if user.role in {"teacher", "admin"}: + return user + if user.role == "student" and user.uid == student_id: + return user + raise HTTPException(status_code=403, detail="Insufficient permissions for requested student") + + +def enforce_rate_limit(request: Request, bucket_name: str, limit: int, window_seconds: int) -> None: + user = getattr(request.state, "user", None) + actor_id = user.uid if user else ((request.client.host if request.client else "unknown")) + key = f"{bucket_name}:{actor_id}" + now = time.time() + start = now - window_seconds + hits = [ts for ts in _rate_limit_buckets.get(key, []) if ts >= start] + if len(hits) >= limit: + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded for {bucket_name}. Try again later.", + ) + hits.append(now) + _rate_limit_buckets[key] = hits + + +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _prune_async_tasks(now_ts: Optional[float] = None) -> None: + now_value = now_ts if now_ts is not None else time.time() + expiry_cutoff = now_value - ASYNC_TASK_TTL_SECONDS + expired_task_ids: List[str] = [] + + for task_id, task in list(_async_tasks.items()): + created_iso = str(task.get("createdAt") or "").strip() + if not created_iso: + continue + try: + created_dt = datetime.fromisoformat(created_iso.replace("Z", "+00:00")) + if created_dt.timestamp() < expiry_cutoff: + expired_task_ids.append(task_id) + except ValueError: + continue + + for task_id in expired_task_ids: + _async_tasks.pop(task_id, None) + + if len(_async_tasks) <= ASYNC_TASK_MAX_ITEMS: + return + + overflow = len(_async_tasks) - ASYNC_TASK_MAX_ITEMS + sorted_items = sorted(_async_tasks.items(), key=lambda kv: str(kv[1].get("createdAt") or "")) + for task_id, _ in sorted_items[:overflow]: + _async_tasks.pop(task_id, None) + + +def _create_async_task(owner_uid: str, task_kind: str, payload: Dict[str, Any]) -> str: + task_id = f"task_{uuid.uuid4().hex[:12]}" + with _async_tasks_lock: + _prune_async_tasks() + _async_tasks[task_id] = { + "taskId": task_id, + "taskKind": task_kind, + "ownerUid": owner_uid, + "status": "queued", + "createdAt": _utc_now_iso(), + "startedAt": None, + "completedAt": None, + "progressPercent": 5.0, + "progressStage": "queued", + "progressMessage": "Task queued for background generation.", + "cancelRequested": False, + "request": payload, + "result": None, + "error": None, + } + return task_id + + +def _update_async_task(task_id: str, **updates: Any) -> None: + with _async_tasks_lock: + task = _async_tasks.get(task_id) + if not task: + return + task.update(updates) + + +async def _run_async_task(task_id: str, runner) -> None: + with _async_tasks_lock: + task = _async_tasks.get(task_id) + if not task: + return + if bool(task.get("cancelRequested")): + task["status"] = "cancelled" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "cancelled" + task["progressMessage"] = "Task was cancelled before execution." + task["error"] = {"message": "Task cancelled before execution."} + return + task["status"] = "running" + task["startedAt"] = _utc_now_iso() + task["progressPercent"] = 15.0 + task["progressStage"] = "running" + task["progressMessage"] = "Background generation started." + task["error"] = None + + try: + payload = await runner() + with _async_tasks_lock: + task = _async_tasks.get(task_id) + if not task: + return + if bool(task.get("cancelRequested")): + task["status"] = "cancelled" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "cancelled" + task["progressMessage"] = "Task was cancelled during execution." + task["error"] = {"message": "Task cancelled during execution."} + task["result"] = None + return + task["status"] = "completed" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "completed" + task["progressMessage"] = "Generation completed successfully." + task["result"] = jsonable_encoder(payload) + task["error"] = None + except HTTPException as http_exc: + with _async_tasks_lock: + task = _async_tasks.get(task_id) + if not task: + return + if bool(task.get("cancelRequested")): + task["status"] = "cancelled" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "cancelled" + task["progressMessage"] = "Task was cancelled during execution." + task["error"] = {"message": "Task cancelled during execution."} + task["result"] = None + return + task["status"] = "failed" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "failed" + task["progressMessage"] = "Generation failed while processing the request." + task["error"] = jsonable_encoder(http_exc.detail) + except Exception as exc: + logger.error(f"Async task {task_id} failed: {exc}") + with _async_tasks_lock: + task = _async_tasks.get(task_id) + if not task: + return + if bool(task.get("cancelRequested")): + task["status"] = "cancelled" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "cancelled" + task["progressMessage"] = "Task was cancelled during execution." + task["error"] = {"message": "Task cancelled during execution."} + task["result"] = None + return + task["status"] = "failed" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "failed" + task["progressMessage"] = "Generation failed due to an unexpected error." + task["error"] = {"message": str(exc)} + + +def _start_async_task_in_thread(task_id: str, runner) -> None: + """Run long async generation tasks off the main event loop.""" + asyncio.create_task(asyncio.to_thread(lambda: asyncio.run(_run_async_task(task_id, runner)))) + + +class AuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + request.state.user = None + + if request.method == "OPTIONS" or path in PUBLIC_PATHS: + return await call_next(request) + + if path in PUBLIC_API_PATHS: + return await call_next(request) + + if not path.startswith("/api/"): + return await call_next(request) + + if path == "/api/dev/generate-mock-data" and not ENABLE_DEV_ENDPOINTS: + return JSONResponse(status_code=404, content={"detail": "Not found"}) + + _init_firebase_admin() + if not _firebase_ready: + return JSONResponse( + status_code=503, + content={"detail": "Authentication service unavailable"}, + ) + + token = _parse_bearer_token(request.headers.get("Authorization", "")) + if not token: + return JSONResponse( + status_code=401, + content={"detail": "Missing or invalid Authorization bearer token"}, + ) + + try: + decoded = _verify_token_with_fallback(token) + except Exception as e: + logger.warning(f"Token verification failed for {path}: {e}") + return JSONResponse(status_code=401, content={"detail": "Invalid or expired auth token"}) + + uid = _extract_uid_from_claims(decoded) + if not uid: + return JSONResponse(status_code=401, content={"detail": "Token missing uid"}) + + role = _resolve_user_role(decoded, token) + request.state.user = AuthenticatedUser( + uid=uid, + email=decoded.get("email"), + role=role, + claims=decoded, + ) + + required_roles = ROLE_POLICIES.get(path) + if required_roles and role not in required_roles: + return JSONResponse(status_code=403, content={"detail": "Forbidden for this role"}) + + return await call_next(request) + + +# ─── Middleware: Request ID + Logging + Timeout ──────────────── + +REQUEST_TIMEOUT_SECONDS = 120 # 2 minutes for AI-heavy endpoints + + +class RequestMiddleware(BaseHTTPMiddleware): + """Adds request-ID header, logs requests, and enforces timeouts.""" + + async def dispatch(self, request: Request, call_next) -> Response: + request_id = str(uuid.uuid4())[:8] + start = time.time() + + # Attach request_id for downstream logging + request.state.request_id = request_id + logger.info(f"[{request_id}] {request.method} {request.url.path}") + + try: + response = await asyncio.wait_for( + call_next(request), + timeout=REQUEST_TIMEOUT_SECONDS, + ) + duration = round(time.time() - start, 3) + response.headers["X-Request-ID"] = request_id + response.headers["X-Response-Time"] = f"{duration}s" + logger.info(f"[{request_id}] {response.status_code} in {duration}s") + return response + except asyncio.TimeoutError: + duration = round(time.time() - start, 3) + logger.error(f"[{request_id}] TIMEOUT after {duration}s on {request.url.path}") + return JSONResponse( + status_code=504, + content={ + "detail": f"Request timed out after {REQUEST_TIMEOUT_SECONDS}s", + "requestId": request_id, + }, + headers={"X-Request-ID": request_id}, + ) + except Exception as exc: + duration = round(time.time() - start, 3) + logger.error(f"[{request_id}] Unhandled error after {duration}s: {exc}") + return JSONResponse( + status_code=500, + content={ + "detail": "Internal server error", + "requestId": request_id, + }, + headers={"X-Request-ID": request_id}, + ) + + +app.add_middleware(RequestMiddleware) +app.add_middleware(AuthMiddleware) +app.include_router(rag_router) +app.include_router(admin_model_router) +app.include_router(diagnostic_router) + + +# ─── Global Exception Handler ───────────────────────────────── + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + request_id = getattr(request.state, "request_id", "unknown") + logger.error(f"[{request_id}] HTTPException {exc.status_code}: {exc.detail}") + return JSONResponse( + status_code=exc.status_code, + content={ + "detail": exc.detail, + "status": exc.status_code, + "requestId": request_id, + }, + headers={"X-Request-ID": request_id}, + ) + + +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + request_id = getattr(request.state, "request_id", "unknown") + logger.error(f"[{request_id}] Unhandled: {type(exc).__name__}: {exc}\n{traceback.format_exc()}") + return JSONResponse( + status_code=500, + content={ + "detail": "An unexpected error occurred. Please try again.", + "error": type(exc).__name__, + "requestId": request_id, + }, + headers={"X-Request-ID": request_id}, + ) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# ─── DeepSeek AI Clients ────────────────────────────────────── + +# Zero-shot classification replaced with DeepSeek chat-based classification. +# BART risk model replaced with deepseek-chat structured output. + +from services.ai_client import get_deepseek_client, CHAT_MODEL, REASONER_MODEL, APIError, RateLimitError, APITimeoutError + +_zsc_client_initialized = False + + +def _ensure_deepseek_available() -> None: + """Verify DeepSeek API key is configured.""" + global _zsc_client_initialized + if not _zsc_client_initialized: + try: + get_deepseek_client() + logger.info("DeepSeek client initialized (for all AI tasks)") + _zsc_client_initialized = True + except ValueError: + raise HTTPException( + status_code=500, + detail="DEEPSEEK_API_KEY not configured. Set the DEEPSEEK_API_KEY environment variable.", + ) + + +# ─── HF Serverless Chat Helper (requests-based) ─────────────── + + +def _strip_repetition(text: str, min_chunk: int = 40) -> str: + """Remove repeated blocks from model output (a common issue with smaller LLMs).""" + lines = text.split("\n") + seen_blocks: list[str] = [] + result_lines: list[str] = [] + i = 0 + while i < len(lines): + # Try to match a block of 2-4 lines that repeats + matched = False + for blen in (4, 3, 2): + if i + blen > len(lines): + continue + block = "\n".join(lines[i : i + blen]).strip() + if len(block) < min_chunk: + continue + if block in seen_blocks: + # Skip this repeated block + i += blen + matched = True + break + seen_blocks.append(block) + if not matched: + result_lines.append(lines[i]) + i += 1 + return "\n".join(result_lines).strip() + + +def _build_hf_inference_url(model_id: str) -> str: + return f"https://api.deepseek.com" + + +def _messages_to_inference_prompt(messages: List[Dict[str, str]]) -> str: + parts: List[str] = [] + for msg in messages: + role = (msg.get("role") or "user").strip().lower() + content = (msg.get("content") or "").strip() + if not content: + continue + if role in {"tool", "function"}: + continue + if role == "system": + parts.append(f"SYSTEM:\n{content}") + elif role == "assistant": + parts.append(f"ASSISTANT:\n{content}") + else: + parts.append(f"USER:\n{content}") + + parts.append("ASSISTANT:") + return "\n\n".join(parts) + + +def call_hf_chat( + messages: List[Dict[str, str]], + *, + max_tokens: int = 2048, + temperature: float = 0.2, + top_p: float = 0.9, + repetition_penalty: float = 1.15, + model: Optional[str] = None, + task_type: str = "default", + timeout: Optional[int] = None, +) -> str: + req = InferenceRequest( + messages=messages, + model=model, + task_type=task_type, + max_new_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + timeout_sec=timeout, + ) + text = get_inference_client().generate_from_messages(req) + return _strip_repetition(text) + + +def call_hf_chat_stream( + messages: List[Dict[str, str]], + *, + max_tokens: int = 512, + temperature: float = 0.3, + top_p: float = 0.85, + model: Optional[str] = None, + task_type: str = "chat", + timeout: Optional[int] = None, +) -> Iterator[str]: + """Stream chat deltas from DeepSeek API as text chunks.""" + client = get_inference_client() + effective_task = (task_type or "chat").strip().lower() + + selection_req = InferenceRequest( + messages=messages, + model=model, + task_type=task_type, + max_new_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + timeout_sec=timeout, + ) + selected_model, _ = client._resolve_primary_model(selection_req) + + model_chain = client._model_chain_for_task(effective_task, selected_model) + timeout_sec = timeout or client.interactive_timeout_sec + last_error: Optional[Exception] = None + + ds_client = get_deepseek_client() + + for fallback_depth, model_name in enumerate(model_chain): + start = time.perf_counter() + try: + stream = ds_client.chat.completions.create( + model=model_name, + messages=messages, # type: ignore[arg-type] + stream=True, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + timeout=timeout_sec, + ) + + emitted_any = False + for chunk in stream: + for choice in chunk.choices: # type: ignore[union-attr] + delta = getattr(choice, 'delta', None) + if delta and delta.content: + emitted_any = True + yield delta.content + + if emitted_any: + latency_ms = (time.perf_counter() - start) * 1000 + logger.info( + "✅ DeepSeek stream success: task=%s model=%s latency=%sms", + effective_task, + model_name, + round(latency_ms, 0), + ) + return + raise RuntimeError("Stream ended without content") + + except Exception as exc: + last_error = exc + logger.warning( + "⚠️ Stream attempt failed: task=%s model=%s depth=%s error=%s", + effective_task, + model_name, + fallback_depth, + str(exc)[:180], + ) + + raise last_error or RuntimeError("Streaming failed with empty model chain") + + +HF_BLOCKING_CALL_CONCURRENCY = max(1, int(os.getenv("HF_BLOCKING_CALL_CONCURRENCY", "16"))) +_hf_call_semaphore: Optional[asyncio.Semaphore] = None +_hf_call_semaphore_loop: Optional[asyncio.AbstractEventLoop] = None +HF_ASYNC_MAX_CONNECTIONS = max(4, int(os.getenv("HF_ASYNC_MAX_CONNECTIONS", "64"))) +HF_ASYNC_MAX_KEEPALIVE_CONNECTIONS = max( + 2, + int(os.getenv("HF_ASYNC_MAX_KEEPALIVE_CONNECTIONS", "32")), +) +HF_ASYNC_CONNECT_TIMEOUT_SEC = float(os.getenv("HF_ASYNC_CONNECT_TIMEOUT_SEC", "10.0")) +HF_ASYNC_WRITE_TIMEOUT_SEC = float(os.getenv("HF_ASYNC_WRITE_TIMEOUT_SEC", "30.0")) +HF_ASYNC_POOL_TIMEOUT_SEC = float(os.getenv("HF_ASYNC_POOL_TIMEOUT_SEC", "10.0")) +_hf_async_http_client: Optional[httpx.AsyncClient] = None +_hf_async_http_client_lock: Optional[asyncio.Lock] = None +_hf_async_http_client_loop: Optional[asyncio.AbstractEventLoop] = None + + +def _get_hf_async_http_client_lock() -> asyncio.Lock: + global _hf_async_http_client, _hf_async_http_client_lock, _hf_async_http_client_loop + loop = asyncio.get_running_loop() + if _hf_async_http_client_lock is None or _hf_async_http_client_loop is not loop: + # Test environments can spin up multiple event loops; reset pooled client per loop. + _hf_async_http_client = None + _hf_async_http_client_lock = asyncio.Lock() + _hf_async_http_client_loop = loop + return _hf_async_http_client_lock + + +async def _get_hf_async_http_client() -> httpx.AsyncClient: + global _hf_async_http_client + lock = _get_hf_async_http_client_lock() + async with lock: + if _hf_async_http_client is None or _hf_async_http_client.is_closed: + limits = httpx.Limits( + max_connections=HF_ASYNC_MAX_CONNECTIONS, + max_keepalive_connections=HF_ASYNC_MAX_KEEPALIVE_CONNECTIONS, + ) + _hf_async_http_client = httpx.AsyncClient(http2=True, limits=limits) + assert _hf_async_http_client is not None + return _hf_async_http_client + + +async def _close_hf_async_http_client() -> None: + global _hf_async_http_client + lock = _get_hf_async_http_client_lock() + async with lock: + if _hf_async_http_client is not None: + await _hf_async_http_client.aclose() + _hf_async_http_client = None + + +def _get_hf_call_semaphore() -> asyncio.Semaphore: + global _hf_call_semaphore, _hf_call_semaphore_loop + loop = asyncio.get_running_loop() + if _hf_call_semaphore is None or _hf_call_semaphore_loop is not loop: + _hf_call_semaphore = asyncio.Semaphore(HF_BLOCKING_CALL_CONCURRENCY) + _hf_call_semaphore_loop = loop + return _hf_call_semaphore + + +async def _run_hf_blocking(func, /, *args, **kwargs): + semaphore = _get_hf_call_semaphore() + async with semaphore: + return await asyncio.to_thread(func, *args, **kwargs) + + +def _hf_retry_sleep_seconds(backoff_sec: float, attempt: int) -> float: + jitter_factor = random.uniform(0.9, 1.2) + return backoff_sec * attempt * jitter_factor + + +def _resolve_async_hf_timeout(timeout_sec: int) -> httpx.Timeout: + return httpx.Timeout( + connect=HF_ASYNC_CONNECT_TIMEOUT_SEC, + read=float(timeout_sec), + write=HF_ASYNC_WRITE_TIMEOUT_SEC, + pool=HF_ASYNC_POOL_TIMEOUT_SEC, + ) + + +async def call_hf_chat_async( + messages: List[Dict[str, str]], + *, + max_tokens: int = 2048, + temperature: float = 0.2, + top_p: float = 0.9, + repetition_penalty: float = 1.15, + model: Optional[str] = None, + task_type: str = "default", + timeout: Optional[int] = None, +) -> str: + """Async wrapper for DeepSeek chat completions.""" + return await _run_hf_blocking( + call_hf_chat, + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + model=model, + task_type=task_type, + timeout=timeout, + ) + + +async def call_hf_chat_stream_async( + messages: List[Dict[str, str]], + *, + max_tokens: int = 512, + temperature: float = 0.3, + top_p: float = 0.85, + model: Optional[str] = None, + task_type: str = "chat", + timeout: Optional[int] = None, +) -> AsyncIterator[str]: + """Async streaming wrapper for DeepSeek chat completions.""" + stream_iter = call_hf_chat_stream( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + model=model, + task_type=task_type, + timeout=timeout, + ) + done = object() + + def _next_chunk(): + return next(stream_iter, done) + + while True: + chunk = await _run_hf_blocking(_next_chunk) + if chunk is done: + return + if chunk: + yield str(chunk) + + +def load_local_math_model(model_name: str = "deepseek-chat"): + """Optional local loader — deprecated in favor of DeepSeek API.""" + raise NotImplementedError( + "Local model loading is deprecated. Use DeepSeek API via DEEPSEEK_API_KEY env var." + ) + + +# ─── Math Tutor Prompt & Wrapper ────────────────────────────── + +_GREETING_PATTERN = re.compile(r"^\s*(?:hi|hello|hey|good\s+(?:morning|afternoon|evening))\b") +_THANKS_PATTERN = re.compile(r"\b(?:thanks|thank\s+you|thank\s+u|ty)\b") + +_GREETING_RESPONSES: Tuple[str, ...] = ( + "Hi! I am MathPulse, your math tutor. I can help with algebra, geometry, calculus, and more. What math question would you like to try?", + "Hello! Great to see you. I am here for math topics and step-by-step solutions whenever you are ready.", +) + +_THANKS_RESPONSES: Tuple[str, ...] = ( + "You are very welcome. If you want, send another math question and we can work through it together.", + "Glad I could help. I am here anytime you want to practice more math.", +) + +_NON_MATH_REDIRECT_RESPONSES: Tuple[str, ...] = ( + "That topic is outside my math scope, but I would be happy to help with mathematics like algebra, calculus, geometry, trigonometry, or statistics.", + "I focus on math-only support, so I may not be the best for that request. Share a math question and I will guide you step by step.", + "I am built for math tutoring, so I can best help with mathematical problems and explanations. If you want, ask me any math question next.", +) + +_MATH_SCOPE_KEYWORDS: Set[str] = { + "math", + "mathematics", + "algebra", + "geometry", + "trigonometry", + "calculus", + "statistics", + "probability", + "arithmetic", + "equation", + "inequality", + "function", + "graph", + "slope", + "derivative", + "integral", + "limit", + "matrix", + "determinant", + "fraction", + "percentage", + "ratio", + "polynomial", + "quadratic", + "logarithm", + "exponent", + "angle", + "triangle", + "circle", + "perimeter", + "area", + "volume", + "mean", + "median", + "mode", + "standard deviation", + "solve", + "simplify", + "factor", + "evaluate", + "compute", + "calculate", +} + +_MATH_SCOPE_PATTERNS: Tuple[re.Pattern[str], ...] = ( + re.compile(r"\d+\s*[%+\-*/^=]\s*[-+]?\d*"), + re.compile(r"\b(?:sin|cos|tan|cot|sec|csc|log|ln|sqrt)\s*\(?"), + re.compile(r"\b(?:differentiate|integrate|derive|proof|prove)\b"), + re.compile(r"\b(?:x|y|z)\s*[=+\-*/^]\s*[-+]?\d"), +) + +_CONTINUATION_FOLLOWUP_TOKENS: Set[str] = { + "go", + "continue", + "yes", + "ok", + "next", + "more", +} + +_CONTINUATION_INVITE_PATTERNS: Tuple[re.Pattern[str], ...] = ( + re.compile(r"\bshall\s+we\s+continue\b", re.IGNORECASE), + re.compile(r"\b(?:would|do)\s+you\s+like\s+to\s+continue\b", re.IGNORECASE), + re.compile(r"\b(?:want|need)\s+me\s+to\s+continue\b", re.IGNORECASE), + re.compile(r"\bshould\s+(?:i|we)\s+continue\b", re.IGNORECASE), + re.compile(r"\bcontinue\s*\?\s*$", re.IGNORECASE), + re.compile(r"\b(?:ready\s+for|go\s+to)\s+the\s+next\s+step\b", re.IGNORECASE), + re.compile(r"\bnext\s+step(?:s)?\s*\?\s*$", re.IGNORECASE), + re.compile(r"\bkeep\s+going\s*\?\s*$", re.IGNORECASE), +) + +_CONTINUATION_CONTEXT_CLARIFY_RESPONSE = ( + "I can continue once I know which math problem you mean. " + "Please share the problem again or tell me which step to continue." +) + + +def is_math_related_query(message: str) -> bool: + normalized = (message or "").strip().lower() + if not normalized: + return False + + if any(keyword in normalized for keyword in _MATH_SCOPE_KEYWORDS): + return True + + return any(pattern.search(normalized) for pattern in _MATH_SCOPE_PATTERNS) + + +def _normalize_continuation_followup_token(message: str) -> str: + normalized = re.sub(r"\s+", " ", (message or "").strip().lower()) + return normalized.rstrip(".!?,\u2026") + + +def _is_continuation_followup_token(message: str) -> bool: + followup_token = _normalize_continuation_followup_token(message) + return followup_token in _CONTINUATION_FOLLOWUP_TOKENS + + +def _extract_latest_assistant_message(history: Optional[Sequence[Any]]) -> Optional[str]: + if not history: + return None + + for entry in reversed(history): + role: Optional[Any] = None + content: Optional[Any] = None + + if isinstance(entry, dict): + role = entry.get("role") + content = entry.get("content") + else: + role = getattr(entry, "role", None) + content = getattr(entry, "content", None) + + role_text = str(role or "").strip().lower() + content_text = str(content or "").strip() + if not content_text: + continue + + if role_text in {"assistant", "ai"}: + return content_text + + return None + + +def _extract_latest_user_intent_message(history: Optional[Sequence[Any]]) -> Optional[str]: + if not history: + return None + + for entry in reversed(history): + role: Optional[Any] = None + content: Optional[Any] = None + + if isinstance(entry, dict): + role = entry.get("role") + content = entry.get("content") + else: + role = getattr(entry, "role", None) + content = getattr(entry, "content", None) + + role_text = str(role or "").strip().lower() + content_text = str(content or "").strip() + if not content_text: + continue + + if role_text != "user": + continue + + if _is_continuation_followup_token(content_text): + continue + + return content_text + + return None + + +def _is_contextual_continuation_followup(message: str, history: Optional[Sequence[Any]]) -> bool: + if not _is_continuation_followup_token(message): + return False + + latest_assistant_message = _extract_latest_assistant_message(history) + if not latest_assistant_message: + return False + + return any(pattern.search(latest_assistant_message) for pattern in _CONTINUATION_INVITE_PATTERNS) + + +def _scope_boundary_response_without_continuation(message: str) -> Optional[str]: + normalized = (message or "").strip().lower() + if not normalized: + return random.choice(_NON_MATH_REDIRECT_RESPONSES) + + if is_math_related_query(normalized): + return None + + if _GREETING_PATTERN.search(normalized): + return random.choice(_GREETING_RESPONSES) + + if _THANKS_PATTERN.search(normalized): + return random.choice(_THANKS_RESPONSES) + + return random.choice(_NON_MATH_REDIRECT_RESPONSES) + + +def get_scope_boundary_response(message: str, history: Optional[Sequence[Any]] = None) -> Optional[str]: + boundary_response = _scope_boundary_response_without_continuation(message) + if boundary_response is None: + return None + + if not _is_continuation_followup_token(message): + return boundary_response + + if _is_contextual_continuation_followup(message, history): + return None + + reconstructed_intent = _extract_latest_user_intent_message(history) + if not reconstructed_intent: + return _CONTINUATION_CONTEXT_CLARIFY_RESPONSE + + reconstructed_boundary_response = _scope_boundary_response_without_continuation(reconstructed_intent) + if reconstructed_boundary_response is None: + return None + + if ( + reconstructed_boundary_response in _GREETING_RESPONSES + or reconstructed_boundary_response in _THANKS_RESPONSES + ): + return _CONTINUATION_CONTEXT_CLARIFY_RESPONSE + + return reconstructed_boundary_response + + +def build_math_tutor_prompt(question: str) -> str: + """Build a structured math-tutor prompt for the LLM.""" + return f"""SYSTEM: +You are MathPulse Tutor, a precise and patient math tutor for Filipino senior high school STEM students. +Your job is to: +1) Understand the student's math question (algebra, functions, graphs, trigonometry, analytic geometry, basic calculus, statistics, or word problems). +2) Solve the problem step by step, explaining each transformation in simple language. +3) Show all important equations clearly and avoid skipping algebra steps unless obvious to a Grade 11–12 STEM student. +4) At the end, restate the final answer explicitly (e.g., "Final answer: x = 3"). +5) If the question is ambiguous or missing information, ask a short clarifying question first instead of guessing. +6) If the student makes a mistake, point it out gently, explain why it is wrong, and show the correct method. +7) Never invent new notation or definitions; use standard high-school math notation only. +8) When there are multiple possible methods, briefly mention alternatives but pick one main method and follow it consistently. +9) If the computation is long, summarize intermediate results so the student does not get lost. +10) If the answer depends on approximations, specify whether the result is exact or rounded (and to how many decimal places). + +IMPORTANT - LaTeX Math Formatting: +- Inline math: use $...$ (e.g., $x^2 + y^2 = z^2$) +- Display math: use $$...$$ (e.g., $$\\int_a^b f(x) , dx$$) +- Never use \\( \\) or \\[ \\] delimiters - use $ and $$ instead +- Never use square brackets like [equation] for math - use proper LaTeX +- Always put math in complete sentences + +Speak in clear, concise English. Use short paragraphs. +If the user question is not about math, politely and briefly redirect them to ask a math question. +If the user sends a greeting or thanks, reply warmly, then invite a math question. + +USER: +Student question: +{question} +""" + + +def call_math_tutor_llm(question: str) -> str: + """Convenience wrapper: call the HF serverless model with the MathPulse tutor prompt via chat completions.""" + prompt = build_math_tutor_prompt(question) + messages = [{"role": "user", "content": prompt}] + return call_hf_chat(messages, max_tokens=CHAT_MAX_NEW_TOKENS, temperature=0.2, top_p=0.9, task_type="chat") + + +# ─── Request/Response Models ────────────────────────────────── + + +class ChatMessage(BaseModel): + role: str = Field(..., description="'user' or 'assistant'") + content: str + + +class ChatRequest(BaseModel): + message: str + history: List[ChatMessage] = Field(default_factory=list) + userId: Optional[str] = None + verify: bool = Field(default=False, description="Enable self-consistency verification for math answers") + expectedEndMarker: Optional[str] = Field( + default=None, + description="Optional marker that should appear in the completed answer (for continuation checks).", + ) + completionMode: str = Field( + default="auto", + description="Completion check mode: auto, marker, or none.", + ) + continuationMaxRounds: Optional[int] = Field( + default=None, + ge=0, + le=8, + description="Optional override for max backend continuation rounds.", + ) + + @field_validator("completionMode") + @classmethod + def validate_completion_mode(_cls, value: str) -> str: + normalized = (value or "").strip().lower() + if normalized in {"auto", "marker", "none"}: + return normalized + return "auto" + + +class ChatResponse(BaseModel): + response: str + verified: Optional[bool] = None + confidence: Optional[str] = None + warning: Optional[str] = None + + +class StudentRiskData(BaseModel): + engagementScore: float = Field(..., ge=0, le=100) + avgQuizScore: float = Field(..., ge=0, le=100) + attendance: float = Field(..., ge=0, le=100) + assignmentCompletion: float = Field(..., ge=0, le=100) + + +class RiskPrediction(BaseModel): + riskLevel: str + confidence: float + analysis: dict + risk_level: str + risk_score: float + top_factors: List[str] + + +class BatchRiskRequest(BaseModel): + students: List[StudentRiskData] + + +class LearningPathRequest(BaseModel): + weaknesses: List[str] + gradeLevel: str + learningStyle: Optional[str] = "visual" + subject: Optional[str] = None + + +class LearningPathResponse(BaseModel): + learningPath: str + + +class StudentInsightData(BaseModel): + name: str + engagementScore: float + avgQuizScore: float + attendance: float + riskLevel: str + + +class DailyInsightRequest(BaseModel): + students: List[StudentInsightData] + + +class DailyInsightResponse(BaseModel): + insight: str + + +class VerificationResult(BaseModel): + verified: bool + confidence: str + response: str + warning: Optional[str] = None + + +class CodeVerificationResult(BaseModel): + verified: bool + code: str + output: str + error: Optional[str] = None + + +class LLMJudgeResult(BaseModel): + correct: bool + issues: List[str] + confidence: float + + +class VerifySolutionRequest(BaseModel): + problem: str + solution: str + + +class VerifySolutionResponse(BaseModel): + overall_verified: bool + aggregated_confidence: float + self_consistency: Optional[VerificationResult] = None + code_verification: Optional[CodeVerificationResult] = None + llm_judge: Optional[LLMJudgeResult] = None + warnings: List[str] = Field(default_factory=list) + + +class TestingResetRequest(BaseModel): + role: Optional[str] = Field( + default=None, + description="Optional role hint. Must match the authenticated role when provided.", + ) + lrn: Optional[str] = Field( + default=None, + description="Optional student LRN used by legacy testing datasets.", + ) + + +class TestingResetResponse(BaseModel): + role: str + deletedDocs: int + updatedDocs: int + summary: str + + +# ─── Routes ──────────────────────────────────────────────────── + + +@app.get("/health") +async def health_check(): + chat_model = CHAT_MODEL + try: + routing_client = get_inference_client() + chat_model, _ = routing_client._resolve_primary_model( + InferenceRequest( + messages=[{"role": "user", "content": "health_check"}], + task_type="chat", + ) + ) + except Exception: + pass + + return {"status": "healthy", "models": {"chat": chat_model, "risk": RISK_MODEL}} + + +@app.get("/") +async def root(): + return { + "name": "MathPulse AI API", + "version": "1.0.0", + "docs": "/docs", + "health": "/health", + } + + +# ─── AI Chat Tutor ───────────────────────────────────────────── + + +MATH_TUTOR_SYSTEM_PROMPT = """You are Pulse, MathPulse AI's friendly math tutor for Filipino Senior High School +students. You help students understand and solve problems in General Mathematics, +Business Mathematics, Statistics & Probability, and Finite Mathematics, all aligned +with the DepEd Strengthened SHS Curriculum and SDO Navotas learning modules. + +YOUR BEHAVIOR RULES: +1. PERSONALIZE every response. Address the student by first name occasionally. +2. NEVER give direct answers to quiz or exam items — guide with hints and questions instead. +3. If the student is struggling on a critical gap topic, gently steer them back to + prerequisite concepts before moving forward. +4. Use the SDO Navotas step-by-step method for ALL solutions: + "Given → Formula → Substitute → Compute → Conclude" +5. Always format math using LaTeX: + - Inline: \\( expression \\) + - Block/display: \\[ expression \\] + Never use dollar signs ($) — they break the KaTeX renderer. +6. Use Filipino-friendly English. Mix in occasional Tagalog phrases + (e.g., "Kaya mo yan!", "Subukan natin...") to keep the tone warm. +7. When a student answers a "try_it" problem, evaluate their answer: + - If correct: Celebrate briefly, explain WHY it's correct, then offer a harder challenge. + - If wrong: Say "Good try! Let's check your steps..." then walk through the error. +8. Keep responses concise (max 300 words per message). Use bullet points for steps. +9. If a student asks about a topic outside their current lesson, help but + note: "This is from [topic]. We'll cover this soon in your learning path!" +10. NEVER generate quiz items with answers visible to the student. +11. When you detect the student consistently making the same mistake, + note it clearly: "I noticed you keep forgetting to convert % to decimal first — let's fix that!" + +RESPONSE FORMAT FOR MATH EXPLANATIONS: +1. Quick concept recap (1-2 sentences) +2. Formula (in LaTeX block) +3. Step-by-step solution +4. Final answer with units/peso sign +5. One quick follow-up question to check understanding + +AWARENESS OF FULL CURRICULUM: +You have complete knowledge of all topics in the MathPulse topic registry +(NA-*, BM-*, SP-*, FM1-*, FM2-* topic codes). When a student asks "what's next?" +refer to their suggested_learning_path from the diagnostic result.""" + + +_STREAM_COMPLETION_MODES: Set[str] = {"auto", "marker", "none"} +_EXPECTED_END_MARKER_PATTERNS: Tuple[re.Pattern[str], ...] = ( + re.compile( + r"\b(?:end|finish|stop)\s+with(?:\s+the\s+(?:exact\s+)?(?:marker|text))?\s*[:\-]?\s*([\"'`]?)([A-Za-z0-9_:\-]{2,96})\1", + re.IGNORECASE, + ), + re.compile( + r"\b(?:include|append)\s+(?:the\s+)?marker\s*[:\-]?\s*([\"'`]?)([A-Za-z0-9_:\-]{2,96})\1", + re.IGNORECASE, + ), +) +_EXPECTED_RANGE_PATTERNS: Tuple[re.Pattern[str], ...] = ( + re.compile(r"\bfor\s+([a-zA-Z])\s*=\s*(-?\d+)\s*(?:\.\.|to)\s*(-?\d+)\b", re.IGNORECASE), + re.compile(r"\b([a-zA-Z])\s*=\s*(-?\d+)\s*\.\.\s*(-?\d+)\b", re.IGNORECASE), +) + + +def _normalize_expected_end_marker(marker: Optional[str]) -> Optional[str]: + if marker is None: + return None + + normalized = str(marker).strip() + if not normalized: + return None + + if len(normalized) >= 2 and normalized[0] == normalized[-1] and normalized[0] in {'"', "'", "`"}: + normalized = normalized[1:-1].strip() + + normalized = normalized.rstrip(".,; ") + if not normalized: + return None + + return normalized[:120] + + +def _extract_expected_end_marker_from_prompt(prompt: str) -> Optional[str]: + source = (prompt or "").strip() + if not source: + return None + + for pattern in _EXPECTED_END_MARKER_PATTERNS: + match = pattern.search(source) + if not match: + continue + marker = _normalize_expected_end_marker(match.group(2)) + if marker: + return marker + + return None + + +def _contains_expected_end_marker(answer: str, marker: Optional[str]) -> bool: + marker_value = _normalize_expected_end_marker(marker) + if not marker_value: + return False + return marker_value.lower() in (answer or "").lower() + + +def _extract_expected_prompt_range(prompt: str) -> Optional[Tuple[str, int, int]]: + source = (prompt or "").strip() + if not source: + return None + + for pattern in _EXPECTED_RANGE_PATTERNS: + match = pattern.search(source) + if not match: + continue + + var_name = match.group(1).lower() + start_value = int(match.group(2)) + end_value = int(match.group(3)) + if abs(end_value - start_value) > 4000: + return None + return var_name, start_value, end_value + + return None + + +def _response_covers_expected_range(answer: str, range_spec: Optional[Tuple[str, int, int]]) -> bool: + if range_spec is None: + return True + + var_name, start_value, end_value = range_spec + sequence_pattern = re.compile(rf"\b{re.escape(var_name)}\s*=\s*(-?\d+)\b", re.IGNORECASE) + matches = [int(match) for match in sequence_pattern.findall(answer or "")] + if not matches: + return False + + terminal_value = matches[-1] + if end_value >= start_value: + return terminal_value >= end_value + return terminal_value <= end_value + + +def _response_looks_truncated(answer: str) -> bool: + trimmed = (answer or "").strip() + if not trimmed: + return True + + trailing_signals = [ + r"```[^`]*$", + r"\$\$[^$]*$", + r"\$[^$\n]*$", + r"\\\[[^\]]*$", + r"\\\([^\)]*$", + r"\\boxed\{[^}]*$", + r"(?:Step\s*\d+[:.]?)\s*$", + r"(?:Final\s*Answer[:.]?)\s*$", + ] + if any(re.search(pattern, trimmed, re.IGNORECASE) for pattern in trailing_signals): + return True + + if len(trimmed) >= 96 and re.search( + r"\b(?:and|or|but|because|since|so|then|which|that|where|when|with|for|to|from|of|in|on|at|by)\s*$", + trimmed, + re.IGNORECASE, + ): + return True + + return False + + +def _normalize_stream_completion_mode(mode: Optional[str]) -> str: + candidate = (mode or "").strip().lower() + if candidate in _STREAM_COMPLETION_MODES: + return candidate + if CHAT_STREAM_COMPLETION_MODE_DEFAULT in _STREAM_COMPLETION_MODES: + return CHAT_STREAM_COMPLETION_MODE_DEFAULT + return "auto" + + +def _resolve_stream_continuation_rounds(request: ChatRequest, completion_mode: str) -> int: + if not CHAT_STREAM_CONTINUATION_ENABLED: + return 0 + if completion_mode == "none": + return 0 + if request.continuationMaxRounds is None: + return CHAT_STREAM_CONTINUATION_MAX_ROUNDS + return max(0, min(int(request.continuationMaxRounds), 8)) + + +def _should_continue_stream_response( + *, + prompt: str, + accumulated_answer: str, + completion_mode: str, + expected_end_marker: Optional[str], + expected_range: Optional[Tuple[str, int, int]], +) -> bool: + mode = _normalize_stream_completion_mode(completion_mode) + if mode == "none": + return False + + marker_present = _contains_expected_end_marker(accumulated_answer, expected_end_marker) + if mode == "marker": + return not marker_present + + if expected_end_marker and not marker_present: + return True + + if expected_range and not _response_covers_expected_range(accumulated_answer, expected_range): + return True + + if _response_looks_truncated(accumulated_answer): + return True + + return False + + +def _merge_answer_continuation(base: str, continuation: str) -> str: + base_trimmed = (base or "").strip() + continuation_trimmed = (continuation or "").strip() + + if not base_trimmed: + return continuation_trimmed + if not continuation_trimmed: + return base_trimmed + + max_overlap = min(len(base_trimmed), len(continuation_trimmed), 220) + for overlap in range(max_overlap, 23, -1): + if base_trimmed.endswith(continuation_trimmed[:overlap]): + return f"{base_trimmed}{continuation_trimmed[overlap:]}".strip() + + if base_trimmed.endswith(continuation_trimmed): + return base_trimmed + if continuation_trimmed.startswith(base_trimmed): + return continuation_trimmed + + return f"{base_trimmed}\n\n{continuation_trimmed}".strip() + + +def _build_stream_continuation_prompt(original_question: str, expected_end_marker: Optional[str]) -> str: + lines = [ + "Continue the exact same answer from where it stopped.", + "Do not restart. Do not repeat content that was already sent.", + "Only output the missing continuation.", + f"Original student question: {original_question}", + ] + marker = _normalize_expected_end_marker(expected_end_marker) + if marker: + lines.append(f'Include the exact marker "{marker}" at the very end once complete.') + return "\n".join(lines) + + +@app.post("/api/chat", response_model=ChatResponse) +async def chat_tutor(request: ChatRequest): + """AI Math Tutor powered by Hugging Face Inference routing.""" + try: + boundary_response = get_scope_boundary_response(request.message, request.history) + if boundary_response is not None: + return ChatResponse(response=boundary_response) + + system_prompt = MATH_TUTOR_SYSTEM_PROMPT + + if request.userId and HAS_FIREBASE_ADMIN and firebase_firestore: + try: + db = firebase_firestore.client() + user_doc = db.collection("users").document(request.userId).get() + if user_doc.exists: + user_data = user_doc.to_dict() or {} + diag_id = user_data.get("latestDiagnosticTestId", "") + if diag_id: + diag_doc = db.collection("diagnosticResults").document(request.userId).collection("attempts").document(diag_id).get() + if diag_doc.exists: + diag_data = diag_doc.to_dict() or {} + risk = diag_data.get("riskProfile", {}) + student_context = f""" +STUDENT PROFILE: +Name: {user_data.get('displayName', 'Student')} +Strand: {diag_data.get('strand', 'STEM')} +Weak Domains: {', '.join(risk.get('weak_domains', []))} +Critical Gaps: {', '.join(risk.get('critical_gaps', []))} +Overall Risk Level: {risk.get('overall_risk', 'unknown')} +""" + system_prompt = student_context + "\n" + system_prompt + except Exception as ctx_err: + logger.debug(f"Failed to inject student profile into chat: {ctx_err}") + + try: + curriculum_chunks = retrieve_curriculum_context( + query=request.message[:200], + top_k=2, + ) + if curriculum_chunks: + rag_context = "RELEVANT CURRICULUM REFERENCE:\n" + for chunk in curriculum_chunks: + rag_context += f"[{chunk.get('source_file', '')}] {chunk.get('content', '')[:400]}\n--\n" + system_prompt = rag_context + "\n\n" + system_prompt + except Exception as rag_err: + logger.debug(f"RAG context injection skipped: {rag_err}") + + messages = [{"role": "system", "content": system_prompt}] + + # Add conversation history + for msg in request.history[-10:]: # Keep last 10 messages for context window + messages.append({"role": msg.role, "content": msg.content}) + + # Add current message + messages.append({"role": "user", "content": request.message}) + + # Call HF serverless with retry (handled inside call_hf_chat) + try: + answer = await call_hf_chat_async( + messages, + max_tokens=CHAT_MAX_NEW_TOKENS, + temperature=0.3, + top_p=0.85, + task_type="chat", + ) + except Exception as hf_err: + logger.error(f"HF chat failed: {hf_err}") + raise HTTPException( + status_code=502, + detail="AI model service is temporarily unavailable. Please try again.", + ) + + # Optional self-consistency verification + if request.verify: + logger.info("Running self-consistency verification for chat response") + verification = await verify_math_response(request.message, messages) + return ChatResponse( + response=verification["response"], + verified=verification["verified"], + confidence=verification["confidence"], + warning=verification.get("warning"), + ) + + return ChatResponse(response=answer) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Chat error: {e}") + raise HTTPException(status_code=500, detail=f"Chat service error: {str(e)}") + + +@app.post("/api/chat/stream") +async def chat_tutor_stream(request: ChatRequest): + """SSE stream endpoint for AI Math Tutor chat responses.""" + try: + boundary_response = get_scope_boundary_response(request.message, request.history) + messages = [{"role": "system", "content": MATH_TUTOR_SYSTEM_PROMPT}] + for msg in request.history[-10:]: + messages.append({"role": msg.role, "content": msg.content}) + messages.append({"role": "user", "content": request.message}) + + def _sse(event: str, data: str) -> str: + lines = str(data).replace("\r\n", "\n").replace("\r", "\n").split("\n") + body = [f"event: {event}"] + [f"data: {line}" for line in lines] + return "\n".join(body) + "\n\n" + + async def event_generator(): + if boundary_response is not None: + payload = json.dumps({"chunk": boundary_response}, ensure_ascii=False) + yield _sse("chunk", payload) + yield _sse("end", "done") + return + + stream_started_at = time.monotonic() + emitted_any_chunk = False + completion_mode = _normalize_stream_completion_mode(request.completionMode) + expected_end_marker = _normalize_expected_end_marker(request.expectedEndMarker) + if not expected_end_marker: + expected_end_marker = _extract_expected_end_marker_from_prompt(request.message) + if completion_mode == "marker" and not expected_end_marker: + completion_mode = "auto" + + expected_range = _extract_expected_prompt_range(request.message) if completion_mode == "auto" else None + continuation_rounds = _resolve_stream_continuation_rounds(request, completion_mode) + max_attempts = 1 + continuation_rounds + assembled_response = "" + + try: + for attempt_index in range(max_attempts): + response_len_before_attempt = len(assembled_response) + if attempt_index == 0: + attempt_messages = messages + else: + continuation_messages = list(messages) + tail_context = assembled_response[-CHAT_STREAM_CONTINUATION_TAIL_CHARS:] + if tail_context: + continuation_messages.append({"role": "assistant", "content": tail_context}) + continuation_messages.append({ + "role": "user", + "content": _build_stream_continuation_prompt(request.message, expected_end_marker), + }) + attempt_messages = continuation_messages + + stream_iterator = call_hf_chat_stream_async( + attempt_messages, + max_tokens=CHAT_MAX_NEW_TOKENS, + temperature=0.3, + top_p=0.85, + task_type="chat", + ) + + attempt_chunks: List[str] = [] + + while True: + elapsed = time.monotonic() - stream_started_at + remaining_total = CHAT_STREAM_TOTAL_TIMEOUT_SEC - elapsed + if remaining_total <= 0: + raise TimeoutError("Chat stream exceeded total timeout") + + token_timeout = min(CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC, remaining_total) + try: + chunk = await asyncio.wait_for(stream_iterator.__anext__(), timeout=token_timeout) + except StopAsyncIteration: + break + + if not chunk: + continue + + chunk_text = str(chunk) + attempt_chunks.append(chunk_text) + + if attempt_index == 0: + assembled_response += chunk_text + emitted_any_chunk = True + payload = json.dumps({"chunk": chunk_text}, ensure_ascii=False) + yield _sse("chunk", payload) + await asyncio.sleep(0) + + attempt_text = "".join(attempt_chunks) + if attempt_index > 0 and attempt_text: + merged_response = _merge_answer_continuation(assembled_response, attempt_text) + if merged_response.startswith(assembled_response): + delta_text = merged_response[len(assembled_response):] + else: + delta_text = attempt_text + merged_response = _merge_answer_continuation(assembled_response, delta_text) + + assembled_response = merged_response + if delta_text: + emitted_any_chunk = True + payload = json.dumps({"chunk": delta_text}, ensure_ascii=False) + yield _sse("chunk", payload) + await asyncio.sleep(0) + + added_chars = len(assembled_response) - response_len_before_attempt + should_continue = _should_continue_stream_response( + prompt=request.message, + accumulated_answer=assembled_response, + completion_mode=completion_mode, + expected_end_marker=expected_end_marker, + expected_range=expected_range, + ) + + if not should_continue: + break + + if attempt_index >= continuation_rounds: + logger.info( + "Reached chat stream continuation limit (rounds=%s mode=%s marker=%s)", + continuation_rounds, + completion_mode, + expected_end_marker or "", + ) + break + + if attempt_index > 0 and added_chars < CHAT_STREAM_CONTINUATION_MIN_NEW_CHARS: + logger.info( + "Stopping chat stream continuation due to low progress (added_chars=%s min_required=%s)", + added_chars, + CHAT_STREAM_CONTINUATION_MIN_NEW_CHARS, + ) + break + + logger.info( + "Continuing chat stream response (attempt=%s mode=%s marker=%s range=%s)", + attempt_index + 2, + completion_mode, + expected_end_marker or "", + expected_range or "", + ) + except (asyncio.TimeoutError, TimeoutError): + logger.error( + "HF chat stream timed out (idle=%ss total=%ss)", + CHAT_STREAM_NO_TOKEN_TIMEOUT_SEC, + CHAT_STREAM_TOTAL_TIMEOUT_SEC, + ) + err_payload = json.dumps({ + "detail": ( + "AI response stream timed out mid-response. Please retry." + if emitted_any_chunk + else "AI response stream timed out before any tokens were received. Please retry." + ), + }) + yield _sse("error", err_payload) + except Exception as hf_err: + logger.error(f"HF chat stream failed: {hf_err}") + err_payload = json.dumps({ + "detail": "AI model service is temporarily unavailable. Please try again.", + }) + yield _sse("error", err_payload) + finally: + yield _sse("end", "done") + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + except Exception as e: + logger.error(f"Chat stream setup error: {e}") + raise HTTPException(status_code=500, detail=f"Chat stream setup error: {str(e)}") + + +# ─── Verification System ────────────────────────────────────── + + +def _extract_final_answer(text: str) -> Optional[str]: + """Extract the final numeric/symbolic answer from a math response.""" + # Try to find explicitly labeled final answers + patterns = [ + r"\*\*Final Answer[:\s]*(.+?)\*\*", + r"Final Answer[:\s]*(.+?)[\n\.]", + r"(?:the answer is|= )\s*(.+?)[\n\.\s]", + r"\\boxed\{(.+?)\}", + ] + for pat in patterns: + match = re.search(pat, text, re.IGNORECASE) + if match: + return match.group(1).strip().rstrip(".") + # Fallback: last line with an equals sign + for line in reversed(text.strip().splitlines()): + if "=" in line: + parts = line.split("=") + return parts[-1].strip().rstrip(".") + return None + + +async def verify_math_response( + problem: str, base_messages: List[Dict[str, Any]] +) -> Dict[str, Any]: + """ + Self-consistency verification: generate multiple responses to the same + math problem and check if the final answers agree. + Returns dict with 'verified' (bool), 'confidence' (str), and 'response'. + """ + responses: List[str] = [] + answers: List[Optional[str]] = [] + + logger.info(f"Generating {VERIFICATION_SAMPLES} responses for self-consistency check") + + for i in range(VERIFICATION_SAMPLES): + try: + text = await call_hf_chat_async( + base_messages, + max_tokens=2048, + temperature=0.7, + top_p=0.9, + task_type="verify_solution", + ) + responses.append(text) + answers.append(_extract_final_answer(text)) + logger.info(f" Sample {i+1} answer: {answers[-1]}") + except Exception as e: + logger.warning(f" Sample {i+1} failed: {e}") + responses.append("") + answers.append(None) + + # Check agreement among non-None answers + valid_answers = [a for a in answers if a is not None] + + if not valid_answers: + return { + "verified": False, + "confidence": "low", + "response": responses[0] if responses else "", + "warning": "Could not extract answers for verification.", + } + + counter = Counter(valid_answers) + most_common_answer, most_common_count = counter.most_common(1)[0] + agreement_ratio = most_common_count / len(valid_answers) + + if agreement_ratio >= 1.0: + confidence = "high" + verified = True + elif agreement_ratio >= 0.6: + confidence = "medium" + verified = True + else: + confidence = "low" + verified = False + + # Pick the response whose answer matches the majority + best_response = responses[0] + for resp, ans in zip(responses, answers): + if ans == most_common_answer and resp: + best_response = resp + break + + result: Dict[str, Any] = { + "verified": verified, + "confidence": confidence, + "response": best_response, + } + + if not verified: + result["warning"] = ( + f"Self-consistency check failed: answers did not converge " + f"({len(set(valid_answers))} distinct answers from {len(valid_answers)} samples). " + f"This answer may be unreliable." + ) + + logger.info(f"Self-consistency result: verified={verified}, confidence={confidence}") + return result + + +async def verify_with_code(problem: str, solution: str) -> Dict[str, Any]: + """ + Ask the model to generate Python verification code for a math solution, + execute it safely, and return the verification result. + """ + + prompt = f"""Given this math problem and its proposed solution, write a short Python script that numerically verifies the answer. + +**Problem:** {problem} + +**Proposed Solution:** {solution} + +Rules: +- Use only the Python standard library and the `math` module. +- The script must print EXACTLY one line: either "VERIFIED" if the solution is correct, or "FAILED: " if it is wrong. +- Do NOT use input(), networking, file I/O, or any system calls. +- Keep the script under 30 lines. + +Respond with ONLY the Python code, no markdown fences, no explanation.""" + + try: + raw_code = await call_hf_chat_async( + messages=[ + { + "role": "system", + "content": "You are a Python code generator. Output only valid Python code, nothing else.", + }, + {"role": "user", "content": prompt}, + ], + max_tokens=800, + temperature=0.1, + ) + # Strip markdown code fences if present + code = re.sub(r"^```(?:python)?\s*\n?", "", raw_code.strip()) + code = re.sub(r"\n?```\s*$", "", code) + code = code.strip() + + if not code: + return {"verified": False, "code": "", "output": "", "error": "No code generated"} + + logger.info("Executing verification code in isolated subprocess sandbox") + + code_blocklist = re.compile( + r"(__\w+__|\bimport\b|exec\s*\(|eval\s*\(|open\s*\(|os\.|sys\.|subprocess|socket|pathlib|shutil|input\s*\(|compile\s*\(|globals\s*\(|locals\s*\(|__builtins__)", + re.IGNORECASE, + ) + if code_blocklist.search(code): + return { + "verified": False, + "code": code, + "output": "", + "error": "Generated code contains disallowed operations", + } + if len(code) > 3000: + return { + "verified": False, + "code": code, + "output": "", + "error": "Generated code exceeded maximum allowed length", + } + + wrapper_script = f""" +import io +import json +import math +import contextlib + +try: + import resource + _max_mem = 256 * 1024 * 1024 + resource.setrlimit(resource.RLIMIT_CPU, (2, 2)) + resource.setrlimit(resource.RLIMIT_AS, (_max_mem, _max_mem)) +except Exception: + pass + +SAFE_BUILTINS = {{ + "print": print, + "range": range, + "len": len, + "abs": abs, + "round": round, + "int": int, + "float": float, + "str": str, + "sum": sum, + "min": min, + "max": max, + "enumerate": enumerate, + "zip": zip, + "map": map, + "list": list, + "tuple": tuple, + "dict": dict, + "set": set, + "sorted": sorted, + "pow": pow, + "isinstance": isinstance, + "True": True, + "False": False, + "None": None, +}} + +payload = {{"ok": False, "output": "", "error": None}} +stdout_capture = io.StringIO() +source = {json.dumps(code)} +try: + compiled = compile(source, "", "exec") + with contextlib.redirect_stdout(stdout_capture): + exec(compiled, {{"__builtins__": SAFE_BUILTINS, "math": math}}, {{}}) + payload["ok"] = True +except Exception as exc: + payload["error"] = str(exc) + +payload["output"] = stdout_capture.getvalue().strip() +print(json.dumps(payload)) +""".strip() + + temp_path = "" + try: + with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as tf: + tf.write(wrapper_script) + temp_path = tf.name + + completed = subprocess.run( + [sys.executable, "-I", "-S", temp_path], + capture_output=True, + text=True, + timeout=3, + ) + except subprocess.TimeoutExpired: + return { + "verified": False, + "code": code, + "output": "", + "error": "Code execution timed out", + } + finally: + if temp_path and os.path.exists(temp_path): + try: + os.remove(temp_path) + except OSError: + pass + + if completed.returncode != 0: + stderr_text = (completed.stderr or "").strip() + return { + "verified": False, + "code": code, + "output": "", + "error": f"Sandbox execution failed: {stderr_text[:300]}", + } + + lines = (completed.stdout or "").strip().splitlines() + if not lines: + return { + "verified": False, + "code": code, + "output": "", + "error": "Sandbox execution produced no output", + } + + try: + payload = json.loads(lines[-1]) + except json.JSONDecodeError: + return { + "verified": False, + "code": code, + "output": "", + "error": "Sandbox returned malformed payload", + } + + output = str(payload.get("output", "")).strip() + sandbox_error = payload.get("error") + if sandbox_error: + return { + "verified": False, + "code": code, + "output": output, + "error": f"Code execution error: {sandbox_error}", + } + + verified = output.upper().startswith("VERIFIED") + logger.info(f"Code verification output: {output}") + + return { + "verified": verified, + "code": code, + "output": output, + "error": None, + } + + except Exception as e: + logger.error(f"Code verification error: {e}") + return {"verified": False, "code": "", "output": "", "error": str(e)} + + +async def llm_judge_verification(problem: str, solution: str) -> Dict[str, Any]: + """ + Use a second LLM call with low temperature to judge whether a math + solution is correct. Checks formula usage, calculations, and logic. + Returns dict with 'correct' (bool), 'issues' (list), 'confidence' (float). + """ + + prompt = f"""You are a meticulous math verification expert. Your job is to verify whether the following solution to a math problem is mathematically correct. + +**Problem:** {problem} + +**Solution to verify:** +{solution} + +Carefully check: +1. Are the correct formulas and theorems applied? +2. Is every arithmetic calculation accurate? +3. Is the logical reasoning valid at each step? +4. Is the final answer correct and in the right units? + +Respond with ONLY a JSON object (no markdown, no explanation outside the JSON): +{{ + "correct": true or false, + "issues": ["list of specific errors or concerns, empty if correct"], + "confidence": 0.0 to 1.0 +}}""" + + try: + raw = await call_hf_chat_async( + messages=[ + { + "role": "system", + "content": "You are a mathematical verification judge. Respond ONLY with valid JSON.", + }, + {"role": "user", "content": prompt}, + ], + task_type="verify_solution", + max_tokens=500, + temperature=0.1, + ) + # Extract JSON from response + json_start = raw.find("{") + json_end = raw.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + parsed = json.loads(raw[json_start:json_end]) + else: + logger.warning(f"LLM judge returned non-JSON: {raw[:200]}") + return {"correct": False, "issues": ["Could not parse judge response"], "confidence": 0.0} + + judge_result = { + "correct": bool(parsed.get("correct", False)), + "issues": list(parsed.get("issues", [])), + "confidence": float(parsed.get("confidence", 0.0)), + } + + logger.info(f"LLM judge result: correct={judge_result['correct']}, confidence={judge_result['confidence']}") + return judge_result + + except Exception as e: + logger.error(f"LLM judge error: {e}\n{traceback.format_exc()}") + return {"correct": False, "issues": [f"Judge error: {str(e)}"], "confidence": 0.0} + + +# ─── Verification Endpoint ──────────────────────────────────── + + +@app.post("/api/verify-solution", response_model=VerifySolutionResponse) +async def verify_solution(request: VerifySolutionRequest, response: Response): + """ + Run all 3 verification methods on a problem+solution pair: + 1. Self-consistency (multiple samples) + 2. Code-based verification + 3. LLM judge review + Returns aggregated confidence and per-method results. + """ + try: + logger.info(f"Running full verification for problem: {request.problem[:80]}...") + cache_key = deterministic_response_cache.build_cache_key( + "verify_solution", + request.model_dump(), + ) + cached_payload = await deterministic_response_cache.get(cache_key) + if isinstance(cached_payload, dict): + _set_cache_response_header(response, hit=True) + return VerifySolutionResponse(**cached_payload) + + _set_cache_response_header(response, hit=False) + warnings: List[str] = [] + + # Build messages for self-consistency check + messages = [ + {"role": "system", "content": MATH_TUTOR_SYSTEM_PROMPT}, + {"role": "user", "content": request.problem}, + ] + + # 1. Self-consistency check + try: + sc_result = await verify_math_response(request.problem, messages) + sc_model = VerificationResult( + verified=sc_result["verified"], + confidence=sc_result["confidence"], + response=sc_result["response"], + warning=sc_result.get("warning"), + ) + if sc_result.get("warning"): + warnings.append(f"Self-consistency: {sc_result['warning']}") + except Exception as e: + logger.error(f"Self-consistency verification failed: {e}") + sc_model = VerificationResult( + verified=False, confidence="low", response="", warning=str(e) + ) + warnings.append(f"Self-consistency check failed: {str(e)}") + + # 2. Code verification + try: + cv_result = await verify_with_code(request.problem, request.solution) + cv_model = CodeVerificationResult( + verified=cv_result["verified"], + code=cv_result.get("code", ""), + output=cv_result.get("output", ""), + error=cv_result.get("error"), + ) + if cv_result.get("error"): + warnings.append(f"Code verification: {cv_result['error']}") + except Exception as e: + logger.error(f"Code verification failed: {e}") + cv_model = CodeVerificationResult( + verified=False, code="", output="", error=str(e) + ) + warnings.append(f"Code verification failed: {str(e)}") + + # 3. LLM judge + try: + lj_result = await llm_judge_verification(request.problem, request.solution) + lj_model = LLMJudgeResult( + correct=lj_result["correct"], + issues=lj_result["issues"], + confidence=lj_result["confidence"], + ) + if lj_result["issues"]: + warnings.append(f"LLM judge issues: {'; '.join(lj_result['issues'])}") + except Exception as e: + logger.error(f"LLM judge verification failed: {e}") + lj_model = LLMJudgeResult(correct=False, issues=[str(e)], confidence=0.0) + warnings.append(f"LLM judge failed: {str(e)}") + + # Aggregate confidence score (0.0 - 1.0) + scores: List[float] = [] + + # Self-consistency score + sc_score_map = {"high": 1.0, "medium": 0.6, "low": 0.2} + scores.append(sc_score_map.get(sc_model.confidence, 0.2)) + + # Code verification score + scores.append(1.0 if cv_model.verified else 0.0) + + # LLM judge score + scores.append(lj_model.confidence if lj_model.correct else (1.0 - lj_model.confidence) * 0.3) + + aggregated = round(sum(scores) / len(scores), 3) if scores else 0.0 + overall_verified = aggregated >= 0.6 + + logger.info( + f"Verification complete: overall_verified={overall_verified}, " + f"aggregated_confidence={aggregated}" + ) + + result = VerifySolutionResponse( + overall_verified=overall_verified, + aggregated_confidence=aggregated, + self_consistency=sc_model, + code_verification=cv_model, + llm_judge=lj_model, + warnings=warnings, + ) + await deterministic_response_cache.set( + cache_key, + result.model_dump(), + VERIFY_SOLUTION_CACHE_TTL_SECONDS, + ) + return result + + except Exception as e: + logger.error(f"Verify solution error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Verification error: {str(e)}") + + +# ─── Student Risk Classification (DeepSeek) ─── + + +RISK_LABELS = [ + "high risk of failing", + "medium academic risk", + "low risk academically stable", +] + +RISK_MAPPING = { + "high risk of failing": "High", + "medium academic risk": "Medium", + "low risk academically stable": "Low", +} + + +def _to_strict_risk_level(level: str) -> str: + normalized = (level or "").strip().lower() + if normalized in {"high", "medium", "low"}: + return normalized + return "medium" + + +def _basic_risk_top_factors(student_data: StudentRiskData) -> List[str]: + factors: List[str] = [] + if student_data.avgQuizScore < 55: + factors.append("Low average quiz performance") + if student_data.assignmentCompletion < 65: + factors.append("Incomplete assignment submission trend") + if student_data.engagementScore < 50: + factors.append("Low class engagement") + if not factors: + factors.append("No major risk indicators detected") + return factors[:3] + + +def _parse_recommendation_lines(text: str, *, max_items: int = 5) -> List[str]: + lines: List[str] = [] + for raw_line in (text or "").splitlines(): + line = raw_line.strip() + if not line: + continue + line = re.sub(r"^[-*\u2022\d\.\)\s]+", "", line).strip() + if not line: + continue + if len(line) < 8: + continue + lines.append(line) + + if not lines and text.strip(): + chunks = [chunk.strip() for chunk in re.split(r"[\n;]", text) if chunk.strip()] + lines = chunks + + deduped: List[str] = [] + seen: Set[str] = set() + for line in lines: + key = line.lower() + if key in seen: + continue + seen.add(key) + deduped.append(line) + if len(deduped) >= max_items: + break + return deduped + + +async def _generate_risk_recommendations_llm(data: EnhancedRiskRequest, result: EnhancedRiskPrediction) -> List[str]: + prompt = ( + "Generate concise teacher interventions for this student risk profile. " + "Return plain text with one recommendation per line. Avoid JSON and markdown.\n\n" + f"risk_level: {result.riskLevel}\n" + f"risk_score: {result.risk_score:.2f}\n" + f"engagementScore: {data.engagementScore:.1f}\n" + f"avgQuizScore: {data.avgQuizScore:.1f}\n" + f"assignmentCompletion: {data.assignmentCompletion:.1f}\n" + f"streak: {int(data.streak or 0)}\n" + f"daysSinceLastActivity: {int(data.daysSinceLastActivity or 0)}\n" + f"top_factors: {', '.join(result.top_factors)}" + ) + + content = await call_hf_chat_async( + messages=[ + { + "role": "system", + "content": ( + "You are a student success specialist. Provide practical, measurable, and age-appropriate " + "interventions for Grade 11-12 students." + ), + }, + {"role": "user", "content": prompt}, + ], + max_tokens=260, + temperature=0.2, + top_p=0.9, + task_type="risk_narrative", + timeout=120, + ) + return _parse_recommendation_lines(content, max_items=5) + + +@app.post("/api/predict-risk", response_model=RiskPrediction) +async def predict_risk(student_data: StudentRiskData, response: Response): + """Student risk prediction using DeepSeek AI classification""" + try: + cache_key = deterministic_response_cache.build_cache_key( + "predict_risk", + student_data.model_dump(), + ) + _set_cache_response_header(response, hit=False) + _ensure_deepseek_available() + + client = get_deepseek_client() + + risk_prompt = ( + f"Student academic performance summary: " + f"Engagement score is {student_data.engagementScore:.0f}%. " + f"Average quiz score is {student_data.avgQuizScore:.0f}%. " + f"Assignment completion rate is {student_data.assignmentCompletion:.0f}%.\n\n" + f"Classify this student into exactly one of these risk levels: {', '.join(RISK_LABELS)}. " + f"Respond with a JSON object containing: risk_label, confidence (0-1 float), reasoning (short sentence)." + ) + + # Retry DeepSeek inference with backoff + last_err: Optional[Exception] = None + for attempt in range(3): + try: + api_response = await _run_hf_blocking( + lambda model=CHAT_MODEL, prompt=risk_prompt: client.chat.completions.create( # type: ignore[arg-type] + model=model, + messages=[ + {"role": "system", "content": "You are a student risk analyst. Respond with valid JSON only."}, + {"role": "user", "content": prompt}, + ], + response_format={"type": "json_object"}, + max_tokens=256, + temperature=0.0, + ) + ) + last_err = None + break + except (APIError, RateLimitError, APITimeoutError, Exception) as api_err: + last_err = api_err + logger.warning(f"DeepSeek risk prediction attempt {attempt + 1} failed: {api_err}") + if attempt < 2: + await asyncio.sleep(2 ** attempt) + + if last_err is not None: + logger.error(f"DeepSeek risk prediction failed after 3 attempts: {last_err}") + raise HTTPException( + status_code=502, + detail="Risk prediction model is temporarily unavailable.", + ) + + content = api_response.choices[0].message.content or "{}" + try: + parsed = json.loads(content) + except json.JSONDecodeError: + parsed = {"risk_label": "medium academic risk", "confidence": 0.5} + + risk_label = str(parsed.get("risk_label", "medium academic risk")) + confidence = float(parsed.get("confidence", 0.5)) + + risk_level = RISK_MAPPING.get(risk_label, "Medium") + strict_risk_level = _to_strict_risk_level(risk_level) + top_factors = _basic_risk_top_factors(student_data) + + result = RiskPrediction( + riskLevel=risk_level, + confidence=round(confidence, 4), + analysis={ + "labels": [risk_label], + "scores": [round(confidence, 4)], + }, + risk_level=strict_risk_level, + risk_score=round(confidence, 4), + top_factors=top_factors, + ) + await deterministic_response_cache.set( + cache_key, + result.model_dump(), + PREDICT_RISK_CACHE_TTL_SECONDS, + ) + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Risk prediction error: {e}") + raise HTTPException(status_code=500, detail=f"Risk prediction error: {str(e)}") + + +@app.post("/api/predict-risk/batch") +async def predict_risk_batch(request: BatchRiskRequest): + """Batch risk prediction for multiple students""" + results = [] + for student in request.students: + try: + result = await predict_risk(student, Response()) + results.append(result) + except Exception: + results.append( + RiskPrediction( + riskLevel="Medium", + confidence=0.0, + analysis={"labels": [], "scores": []}, + risk_level="medium", + risk_score=0.0, + top_factors=["Fallback risk response due to prediction error"], + ) + ) + return results + + +# ─── Learning Path Generation ────────────────────────────────── + + +@app.post("/api/learning-path", response_model=LearningPathResponse) +async def generate_ai_learning_path(request: LearningPathRequest, response: Response): + """Generate AI-powered personalized learning path""" + try: + cache_key = deterministic_response_cache.build_cache_key( + "learning_path", + request.model_dump(), + ) + cached_payload = await deterministic_response_cache.get(cache_key) + if isinstance(cached_payload, dict): + _set_cache_response_header(response, hit=True) + return LearningPathResponse(**cached_payload) + + _set_cache_response_header(response, hit=False) + rag_context_block = "" + if ENABLE_RAG_ANALYSIS_CONTEXT: + try: + subject_for_context = (request.subject or "general_math").strip() or "general_math" + competency_chunks = build_analysis_curriculum_context(request.weaknesses, subject_for_context) + if competency_chunks: + lines = [] + for idx, row in enumerate(competency_chunks[:8], start=1): + lines.append( + f"{idx}. {row.get('content')} (Source: {row.get('source_file')} p.{row.get('page')}, " + f"Q{row.get('quarter')}, {row.get('content_domain')})" + ) + rag_context_block = ( + "RELEVANT DEPED LEARNING COMPETENCIES FOR WEAK TOPICS:\n" + + "\n".join(lines) + + "\n\n" + ) + except Exception as rag_err: + logger.warning(f"RAG analysis context skipped: {rag_err}") + + prompt = f"""Generate a personalized math learning path for a student with these details: +{rag_context_block}STUDENT PERFORMANCE DATA: +- Weak Topics: {', '.join(request.weaknesses)} +- Grade Level: {request.gradeLevel} +- Learning Style: {request.learningStyle or 'visual'} + +Create a structured learning path with 5-7 specific activities. For each activity provide: +1. Activity title +2. Brief description (1-2 sentences) +3. Estimated duration +4. Type (video, practice, quiz, reading, interactive) + +Format as a numbered list. Be specific to the math topics mentioned.""" + + messages = [ + { + "role": "system", + "content": "You are an educational curriculum expert specializing in mathematics. Create clear, actionable learning paths.", + }, + {"role": "user", "content": prompt}, + ] + + try: + content = await call_hf_chat_async( + messages, + max_tokens=1500, + temperature=0.7, + task_type="learning_path", + ) + except Exception as hf_err: + logger.error(f"HF learning-path failed: {hf_err}") + raise HTTPException( + status_code=502, + detail="Learning path generation is temporarily unavailable.", + ) + + result = LearningPathResponse(learningPath=content) + await deterministic_response_cache.set( + cache_key, + result.model_dump(), + LEARNING_PATH_CACHE_TTL_SECONDS, + ) + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Learning path error: {e}") + raise HTTPException(status_code=500, detail=f"Learning path error: {str(e)}") + + +# ─── Daily AI Insights ───────────────────────────────────────── + + +@app.post("/api/analytics/daily-insight", response_model=DailyInsightResponse) +async def daily_insight(request: DailyInsightRequest, response: Response): + """Generate daily AI insights for teacher dashboard""" + try: + cache_key = deterministic_response_cache.build_cache_key( + "daily_insight", + request.model_dump(), + ) + cached_payload = await deterministic_response_cache.get(cache_key) + if isinstance(cached_payload, dict): + _set_cache_response_header(response, hit=True) + return DailyInsightResponse(**cached_payload) + + _set_cache_response_header(response, hit=False) + students = request.students + total = len(students) + if total == 0: + empty_result = DailyInsightResponse(insight="No student data available for analysis.") + await deterministic_response_cache.set( + cache_key, + empty_result.model_dump(), + DAILY_INSIGHT_CACHE_TTL_SECONDS, + ) + return empty_result + + avg_engagement = sum(s.engagementScore for s in students) / total + avg_quiz = sum(s.avgQuizScore for s in students) / total + avg_attendance = sum(s.attendance for s in students) / total + high_risk = sum(1 for s in students if s.riskLevel == "High") + medium_risk = sum(1 for s in students if s.riskLevel == "Medium") + + prompt = f"""Analyze this classroom data and provide actionable insights for a math teacher: + +Classroom Summary: +- Total Students: {total} +- Average Engagement: {avg_engagement:.1f}% +- Average Quiz Score: {avg_quiz:.1f}% +- Average Attendance: {avg_attendance:.1f}% +- High-Risk Students: {high_risk} +- Medium-Risk Students: {medium_risk} +- Low-Risk Students: {total - high_risk - medium_risk} + +Provide: +1. A brief overall assessment (2-3 sentences) +2. 3-4 specific, actionable recommendations for the teacher +3. One positive observation to highlight + +Keep the response under 200 words. Be specific and practical.""" + + messages = [ + { + "role": "system", + "content": "You are an educational data analyst providing insights to math teachers. Be specific, actionable, and encouraging.", + }, + {"role": "user", "content": prompt}, + ] + + try: + content = await call_hf_chat_async( + messages, + max_tokens=800, + temperature=0.7, + task_type="daily_insight", + ) + except Exception as hf_err: + logger.error(f"HF daily-insight failed: {hf_err}") + raise HTTPException( + status_code=502, + detail="AI insight generation is temporarily unavailable.", + ) + + result = DailyInsightResponse(insight=content) + await deterministic_response_cache.set( + cache_key, + result.model_dump(), + DAILY_INSIGHT_CACHE_TTL_SECONDS, + ) + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Daily insight error: {e}") + raise HTTPException(status_code=500, detail=f"Daily insight error: {str(e)}") + + +# ─── Smart Document Upload ──────────────────────────────────── + + +CLASS_RECORD_REQUIRED_FIELDS: Set[str] = { + "name", + "lrn", + "email", + "engagementScore", + "avgQuizScore", + "attendance", + "assignmentCompletion", + "term", + "assessmentName", +} + +# Fields that must be present in mapped columns before import proceeds. +# These are the minimum metrics required for downstream dashboard/risk workflows. +CLASS_RECORD_IMPORT_CORE_FIELDS: Set[str] = { + "name", + "engagementScore", + "avgQuizScore", + "attendance", +} + +CLASS_RECORD_IDENTITY_FIELDS: Set[str] = {"lrn", "email"} + +CLASS_RECORD_SCORING_FIELDS: Set[str] = { + "engagementScore", + "avgQuizScore", + "attendance", + "assignmentCompletion", +} + +SUPPORTED_DATASET_INTENTS: Set[str] = { + "synthetic_student_records", + "general_analytics", + "eval_only", +} + +NON_EDUCATION_DOMAIN_TOKENS: Dict[str, Set[str]] = { + "medical": {"patient", "diagnosis", "disease", "symptom", "clinical", "medication"}, + "finance": {"loan", "mortgage", "revenue", "income", "profit", "stock", "price"}, + "real_estate": {"bedroom", "bathroom", "sqft", "square feet", "zipcode", "property", "house"}, + "genomics": {"gene", "genome", "protein", "rna", "dna", "mutation", "chromosome"}, +} + +CLASS_RECORD_FIELD_ALIASES: Dict[str, List[str]] = { + "name": ["name", "student", "learner", "fullname", "full name"], + "lrn": ["lrn", "student id", "learner id", "reference", "id number"], + "email": ["email", "e-mail", "mail"], + "engagementScore": ["engagement", "participation", "activity", "involvement"], + "avgQuizScore": ["quiz", "score", "grade", "exam", "test", "assessment"], + "attendance": ["attendance", "present", "absence", "attend"], + "assignmentCompletion": ["assignment", "homework", "submission", "completion", "task"], + "term": ["term", "quarter", "semester", "period", "grading"], + "assessmentName": ["assessment", "exam", "quiz name", "test name", "activity name", "title"], +} + + +def _is_empty_cell(value: Any) -> bool: + if value is None: + return True + if isinstance(value, float): + return math.isnan(value) + if isinstance(value, str): + return not value.strip() + return False + + +def _stringify_cell(value: Any) -> str: + if _is_empty_cell(value): + return "" + return str(value).strip() + + +def _normalize_unknown_key(column_name: str) -> str: + base = re.sub(r"[^a-z0-9]+", "_", column_name.lower()).strip("_") + if not base: + base = "field" + return f"unknown_{base[:48]}" + + +def _safe_numeric(value: Any, *, default_value: float = 0.0) -> Tuple[float, Optional[str]]: + if _is_empty_cell(value): + return default_value, "missing numeric value; defaulted to 0" + + raw = str(value).strip().replace("%", "").replace(",", "") + try: + return float(raw), None + except Exception: + return default_value, f"invalid numeric value '{raw}'; defaulted to 0" + + +def _normalize_column_text(value: str) -> str: + return re.sub(r"[^a-z0-9]+", " ", str(value or "").strip().lower()).strip() + + +def _detect_non_education_signals(columns: List[str]) -> Dict[str, List[str]]: + matches: Dict[str, List[str]] = {} + for raw_column in columns: + normalized = _normalize_column_text(raw_column) + if not normalized: + continue + hit_domains: List[str] = [] + for domain, tokens in NON_EDUCATION_DOMAIN_TOKENS.items(): + if any(token in normalized for token in tokens): + hit_domains.append(domain) + if hit_domains: + matches[raw_column] = sorted(set(hit_domains)) + return matches + + +def _build_column_interpretations( + *, + columns: List[str], + mapping: Dict[str, str], + mapping_source: Dict[str, str], + non_education_matches: Dict[str, List[str]], +) -> List[Dict[str, Any]]: + interpretations: List[Dict[str, Any]] = [] + for col in columns: + mapped_field = mapping.get(col) + source = mapping_source.get(col, "unmapped") + domains = non_education_matches.get(col, []) + + if mapped_field in CLASS_RECORD_SCORING_FIELDS: + usage_policy = "scoring" + reason = "Mapped to a core educational metric used in risk and dashboard calculations." + elif mapped_field in CLASS_RECORD_REQUIRED_FIELDS: + usage_policy = "display" + reason = "Mapped to an educational identity/context field used for display and record management." + else: + usage_policy = "storage_only" + reason = "Column is not mapped to a supported educational field and is excluded from scoring." + + if source == "fallback": + confidence_band = "high" + elif source == "ai": + confidence_band = "medium" + else: + confidence_band = "low" + + if domains: + confidence_band = "low" + reason = f"Detected non-education domain signals ({', '.join(domains)}); kept as storage-only metadata." + usage_policy = "storage_only" + + interpretations.append( + { + "columnName": col, + "mappedField": mapped_field, + "mappingSource": source, + "confidenceBand": confidence_band, + "usagePolicy": usage_policy, + "reason": reason, + "domainSignals": domains, + } + ) + + return interpretations + + +def _validate_class_record_mapping(mapping: Dict[str, str]) -> Tuple[List[str], List[str]]: + mapped_fields = set(mapping.values()) + missing_core_fields = sorted(CLASS_RECORD_IMPORT_CORE_FIELDS - mapped_fields) + has_identity = bool(CLASS_RECORD_IDENTITY_FIELDS & mapped_fields) + missing_identity = [] if has_identity else ["lrn_or_email"] + return missing_core_fields, missing_identity + + +def _build_scoring_student_payload(row: Dict[str, Any], class_section_id: Optional[str]) -> Dict[str, Any]: + return { + "studentId": row.get("studentId"), + "name": row.get("name"), + "email": row.get("email"), + "lrn": row.get("lrn"), + "avgQuizScore": row.get("avgQuizScore"), + "attendance": row.get("attendance"), + "engagementScore": row.get("engagementScore"), + "assignmentCompletion": row.get("assignmentCompletion"), + "term": row.get("term"), + "assessmentName": row.get("assessmentName"), + "classSectionId": class_section_id, + } + + +def _fallback_column_mapping(columns: List[str]) -> Dict[str, str]: + mapping: Dict[str, str] = {} + for col in columns: + normalized = re.sub(r"[^a-z0-9]+", " ", col.lower()).strip() + if not normalized: + continue + + for field, aliases in CLASS_RECORD_FIELD_ALIASES.items(): + if any(alias in normalized for alias in aliases): + if col not in mapping: + mapping[col] = field + break + + return mapping + + +def _sanitize_column_mapping(raw_mapping: Any) -> Dict[str, str]: + """Keep only non-empty string->known-field entries from AI mapping output.""" + if not isinstance(raw_mapping, dict): + return {} + + sanitized: Dict[str, str] = {} + for raw_col, raw_field in raw_mapping.items(): + column_name = _stringify_cell(raw_col) + mapped_field = _stringify_cell(raw_field) + if not column_name or not mapped_field: + continue + if mapped_field not in CLASS_RECORD_REQUIRED_FIELDS: + continue + sanitized[column_name] = mapped_field + + return sanitized + + +def _build_record_identity( + student: Dict[str, Any], + unknown_fields: Dict[str, Any], +) -> Tuple[str, str, str, str]: + candidate_student_id = ( + str(student.get("lrn", "")).strip() + or str(student.get("email", "")).strip().lower() + or re.sub(r"\s+", "_", str(student.get("name", "")).strip().lower()) + ) + if not candidate_student_id: + candidate_student_id = "unknown-student" + + term = str(student.get("term", "")).strip() + if not term: + for key, value in unknown_fields.items(): + if any(token in key for token in ("term", "quarter", "semester", "period")): + term = str(value).strip() + break + if not term: + term = "unspecified-term" + + assessment_name = str(student.get("assessmentName", "")).strip() + if not assessment_name: + for key, value in unknown_fields.items(): + if any(token in key for token in ("assessment", "exam", "quiz", "test", "activity", "title")): + assessment_name = str(value).strip() + break + if not assessment_name: + assessment_name = "general-assessment" + + dedup_seed = f"{candidate_student_id.lower()}|{term.lower()}|{assessment_name.lower()}" + dedup_key = hashlib.sha1(dedup_seed.encode("utf-8")).hexdigest()[:28] + return candidate_student_id, term, assessment_name, dedup_key + + +def _persist_class_record_import_artifact( + request: Request, + *, + file_hash: str, + file_name: str, + file_type: str, + column_mapping: Dict[str, str], + normalized_rows: List[Dict[str, Any]], + row_warnings: List[Dict[str, Any]], + unknown_columns: List[str], + parse_warnings: List[str], + dataset_intent: str, + column_interpretations: List[Dict[str, Any]], + interpretation_summary: Dict[str, Any], + class_section_id: Optional[str] = None, + class_name: Optional[str] = None, +) -> Dict[str, Any]: + if not (_firebase_ready and firebase_firestore): + return { + "persisted": False, + "importId": None, + "dedup": {"inserted": 0, "updated": 0}, + "warning": "Firestore unavailable; class records were not persisted.", + } + + user = get_current_user(request) + normalized_class_section_id = (class_section_id or "").strip() or None + normalized_class_name = (class_name or "").strip() or None + import_seed = f"{user.uid}|{normalized_class_section_id or 'global'}|{file_hash}" + import_id = hashlib.sha1(import_seed.encode("utf-8")).hexdigest()[:28] + + import_payload: Dict[str, Any] = { + "importId": import_id, + "teacherId": user.uid, + "teacherEmail": user.email, + "fileName": file_name, + "fileType": file_type, + "fileHash": file_hash, + "rowCount": len(normalized_rows), + "columnMapping": column_mapping, + "unknownColumns": unknown_columns, + "parseWarnings": parse_warnings, + "datasetIntent": dataset_intent, + "columnInterpretations": column_interpretations, + "interpretationSummary": interpretation_summary, + "rowWarnings": row_warnings[:300], + "source": "api_upload_class_records", + "retentionDays": IMPORT_RETENTION_DAYS, + "expiresAtEpoch": _artifact_expiry_epoch(), + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if normalized_class_section_id: + import_payload["classSectionId"] = normalized_class_section_id + if normalized_class_name: + import_payload["className"] = normalized_class_name + + inserted = 0 + updated = 0 + try: + imports_ref = firebase_firestore.client().collection("classRecordImports").document(import_id) + import_doc = cast(Any, imports_ref.get()) + if not _snapshot_exists(import_doc): + import_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + imports_ref.set(import_payload, merge=True) + + client = firebase_firestore.client() + normalized_ref = client.collection("normalizedClassRecords") + batch = client.batch() + batch_count = 0 + + for row in normalized_rows: + dedup_key = str(row.get("dedupKey", "")).strip() + if not dedup_key: + continue + + scoped_key_seed = f"{user.uid}|{normalized_class_section_id or 'global'}|{dedup_key}" + scoped_key = hashlib.sha1(scoped_key_seed.encode("utf-8")).hexdigest()[:36] + row_doc_ref = normalized_ref.document(scoped_key) + existing_doc = cast(Any, row_doc_ref.get()) + + payload = { + **row, + "recordId": scoped_key, + "teacherId": user.uid, + "teacherEmail": user.email, + "importId": import_id, + "sourceFile": file_name, + "retentionDays": IMPORT_RETENTION_DAYS, + "expiresAtEpoch": _artifact_expiry_epoch(), + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if normalized_class_section_id: + payload["classSectionId"] = normalized_class_section_id + if normalized_class_name: + payload["className"] = normalized_class_name + if not _snapshot_exists(existing_doc): + payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + inserted += 1 + else: + updated += 1 + + batch.set(row_doc_ref, payload, merge=True) + batch_count += 1 + if batch_count >= 400: + batch.commit() + batch = client.batch() + batch_count = 0 + + if batch_count > 0: + batch.commit() + except Exception as persistence_err: + if _is_adc_missing_error(cast(Exception, persistence_err)): + logger.warning("Class record persistence skipped because Firestore ADC is not configured.") + return { + "persisted": False, + "importId": None, + "dedup": {"inserted": 0, "updated": 0}, + "warning": ( + "Firestore ADC is not configured; class records were parsed but not persisted. " + "Set FIREBASE_SERVICE_ACCOUNT_JSON or FIREBASE_SERVICE_ACCOUNT_FILE, " + "or set GOOGLE_APPLICATION_CREDENTIALS." + ), + } + raise + + return { + "persisted": True, + "importId": import_id, + "dedup": {"inserted": inserted, "updated": updated}, + "warning": None, + } + + +def _normalize_class_records( + df: Any, + *, + file_name: str, + file_hash: str, + column_mapping: Dict[str, str], +) -> Dict[str, Any]: + normalized_rows: List[Dict[str, Any]] = [] + row_warnings: List[Dict[str, Any]] = [] + rejected_rows: List[Dict[str, Any]] = [] + unknown_columns = sorted([col for col in df.columns if col not in column_mapping]) + inferred_rows = 0 + fallback_inference_rows = 0 + + for idx, row in df.iterrows(): + student: Dict[str, Any] = {} + unknown_fields: Dict[str, Any] = {} + warnings_for_row: List[str] = [] + + for col in df.columns: + raw_value = row[col] + mapped_field = column_mapping.get(col) + if mapped_field in CLASS_RECORD_REQUIRED_FIELDS: + student[mapped_field] = _stringify_cell(raw_value) + else: + text_val = _stringify_cell(raw_value) + if text_val: + unknown_fields[_normalize_unknown_key(col)] = text_val + + student_name = str(student.get("name", "")).strip() + if not student_name: + rejected_rows.append( + { + "row": int(idx) + 2, + "reason": "missing required field: name", + } + ) + continue + + lrn_value = str(student.get("lrn", "")).strip() + email_value = str(student.get("email", "")).strip().lower() + if not lrn_value and not email_value: + rejected_rows.append( + { + "row": int(idx) + 2, + "reason": "missing required identity value: lrn_or_email", + } + ) + continue + + defaulted_metrics: Set[str] = set() + for field in ["engagementScore", "avgQuizScore", "attendance", "assignmentCompletion"]: + numeric_value, parse_warning = _safe_numeric(student.get(field)) + student[field] = numeric_value + if parse_warning: + warnings_for_row.append(f"{field}: {parse_warning}") + defaulted_metrics.add(field) + + student_id, term, assessment_name, dedup_key = _build_record_identity(student, unknown_fields) + student["name"] = student_name + student["email"] = email_value + student["lrn"] = lrn_value + student["term"] = term + student["assessmentName"] = assessment_name + + has_topic_signal = bool( + (assessment_name and assessment_name.lower() != "general-assessment") + or any( + any(token in str(key).lower() for token in ("topic", "unit", "lesson", "skill", "competency")) + for key in unknown_fields.keys() + ) + ) + inference = _infer_student_state( + avg_quiz=float(student.get("avgQuizScore") or 0.0), + attendance=float(student.get("attendance") or 0.0), + engagement=float(student.get("engagementScore") or 0.0), + defaulted_metrics=defaulted_metrics, + has_topic_signal=has_topic_signal, + ) + inferred_rows += 1 + if inference.get("fallbackUsed"): + fallback_inference_rows += 1 + + normalized_row = { + **student, + "unknownFields": unknown_fields, + "sourceMeta": { + "fileName": file_name, + "fileHash": file_hash, + "sourceRow": int(idx) + 2, + }, + "studentId": student_id, + "dedupKey": dedup_key, + "riskLevel": inference["riskLevel"], + "inferredState": { + "state": inference["state"], + "confidence": inference["confidence"], + "signals": inference["signals"], + "explanation": inference["explanation"], + "fallbackUsed": inference["fallbackUsed"], + }, + } + + normalized_rows.append(normalized_row) + if warnings_for_row: + row_warnings.append( + { + "row": int(idx) + 2, + "warning": "; ".join(warnings_for_row), + } + ) + + return { + "rows": normalized_rows, + "rowWarnings": row_warnings, + "rejectedRows": rejected_rows, + "unknownColumns": unknown_columns, + "interpretedRows": len(normalized_rows), + "rejectedRowsCount": len(rejected_rows), + "rejectedReasonCounts": dict(Counter(item["reason"] for item in rejected_rows)), + "inferredRows": inferred_rows, + "fallbackInferenceRows": fallback_inference_rows, + } + + +def _slugify_class_token(value: str) -> str: + token = re.sub(r"[^a-z0-9]+", "_", str(value or "").strip().lower()) + return re.sub(r"_+", "_", token).strip("_") + + +def _normalize_grade_level(raw_grade: Optional[str]) -> str: + text = (raw_grade or "").strip() + if not text: + return "Grade 11" + + number_match = re.search(r"(\d{1,2})", text) + if number_match: + return f"Grade {number_match.group(1)}" + + if text.lower().startswith("grade"): + return re.sub(r"\s+", " ", text).strip().replace("grade", "Grade", 1) + + return text + + +def _infer_classification(grade_level: Optional[str]) -> str: + level = _normalize_grade_level(grade_level) + grade_match = re.search(r"(\d{1,2})", level) + if grade_match: + grade_number = int(grade_match.group(1)) + return "Senior High School" if grade_number >= 11 else "Junior High School" + return "Senior High School" + + +def _infer_strand(*, class_name: Optional[str], section: Optional[str]) -> Optional[str]: + source = f"{class_name or ''} {section or ''}".upper() + if not source.strip(): + return None + + for token in ("STEM", "ABM", "HUMSS", "GAS", "TVL", "ICT"): + if re.search(rf"\b{token}\b", source): + return token + + return None + + +def _build_class_metadata( + *, + class_section_id: Optional[str], + class_name: Optional[str], + grade: Optional[str], + section: Optional[str], + school_year: Optional[str] = None, + owner_teacher_id: Optional[str] = None, + owner_teacher_name: Optional[str] = None, + adviser_teacher_id: Optional[str] = None, + adviser_teacher_name: Optional[str] = None, + manager_id: Optional[str] = None, + manager_name: Optional[str] = None, + classification: Optional[str] = None, + strand: Optional[str] = None, + grade_level: Optional[str] = None, +) -> Dict[str, Any]: + normalized_grade = (grade or "").strip() or None + normalized_section = (section or "").strip() or None + normalized_grade_level = _normalize_grade_level(grade_level or normalized_grade) + normalized_classification = (classification or "").strip() or _infer_classification(normalized_grade_level) + normalized_strand = (strand or "").strip() or _infer_strand(class_name=class_name, section=normalized_section) + + return { + "classSectionId": (class_section_id or "").strip() or None, + "className": (class_name or "").strip() or None, + "grade": normalized_grade, + "section": normalized_section, + "gradeLevel": normalized_grade_level, + "classification": normalized_classification, + "strand": normalized_strand, + "schoolYear": (school_year or "").strip() or str(datetime.now(timezone.utc).year), + "ownerTeacherId": (owner_teacher_id or "").strip() or None, + "ownerTeacherName": (owner_teacher_name or "").strip() or None, + "adviserTeacherId": (adviser_teacher_id or "").strip() or None, + "adviserTeacherName": (adviser_teacher_name or "").strip() or None, + "managerId": (manager_id or "").strip() or None, + "managerName": (manager_name or "").strip() or None, + } + + +def _resolve_import_class_context( + *, + class_section_id: Optional[str], + class_name: Optional[str], +) -> Tuple[str, str, str, str]: + normalized_section_id = (class_section_id or "").strip().lower() + normalized_class_name = (class_name or "").strip() + + grade = "Grade 11" + section = "Section A" + + if normalized_class_name and " - " in normalized_class_name: + parts = [part.strip() for part in normalized_class_name.split(" - ", 1)] + if parts[0]: + grade = parts[0] + if len(parts) > 1 and parts[1]: + section = parts[1] + + if normalized_section_id: + section_tokens = [token for token in normalized_section_id.split("_") if token] + if section_tokens: + if section_tokens[0].startswith("grade"): + suffix = section_tokens[0].replace("grade", "").strip("_") + if suffix: + grade = f"Grade {suffix}" + elif section_tokens[0].isdigit(): + grade = f"Grade {section_tokens[0]}" + elif "grade" in section_tokens[0]: + grade = section_tokens[0].replace("_", " ").title() + if len(section_tokens) > 1: + section = " ".join(token.capitalize() for token in section_tokens[1:]) + elif len(section_tokens) == 1 and section_tokens[0] and section_tokens[0] != "grade": + section = section or "Section A" + + if not normalized_section_id: + normalized_section_id = _slugify_class_token(f"{grade}_{section}") or "grade_11_section_a" + + resolved_name = normalized_class_name or f"{grade} - {section}" + return normalized_section_id, resolved_name, grade, section + + +def _derive_risk_level(avg_quiz: float, attendance: float, engagement: float) -> str: + if avg_quiz < 60 or engagement < 55: + return "High" + if avg_quiz < 75 or engagement < 70: + return "Medium" + return "Low" + + +def _infer_student_state( + *, + avg_quiz: float, + attendance: float, + engagement: float, + defaulted_metrics: Set[str], + has_topic_signal: bool, +) -> Dict[str, Any]: + risk_level = _derive_risk_level(avg_quiz, attendance, engagement) + + if risk_level == "High": + if avg_quiz < 50 or engagement < 45: + state = "urgent_intervention" + else: + state = "at_risk" + elif risk_level == "Medium": + state = "watchlist" + else: + state = "on_track" + + relevant_metrics = {"avgQuizScore", "engagementScore"} + non_default_metrics = len(relevant_metrics - defaulted_metrics) + completeness = max(0.0, min(1.0, float(non_default_metrics) / float(len(relevant_metrics)))) + + threshold_gap = max( + max(0.0, 60.0 - avg_quiz) / 60.0, + max(0.0, 55.0 - engagement) / 55.0, + ) + confidence = 0.45 + (0.45 * completeness) + (0.1 * min(1.0, threshold_gap)) + confidence = max(0.1, min(0.99, confidence)) + + signals: List[str] = [] + if avg_quiz < 60: + signals.append("low_avg_quiz_score") + if engagement < 55: + signals.append("low_engagement") + if defaulted_metrics: + signals.append("fallback_defaulted_metrics") + if not has_topic_signal: + signals.append("fallback_general_topic_context") + if not signals: + signals.append("stable_core_metrics") + + explanation = ( + f"State inferred from avgQuizScore={avg_quiz:.1f}, engagementScore={engagement:.1f}." + ) + + return { + "riskLevel": risk_level, + "state": state, + "confidence": round(confidence, 3), + "signals": signals, + "explanation": explanation, + "fallbackUsed": bool(defaulted_metrics or not has_topic_signal), + } + + +def _pick_weakest_topic(unknown_fields: Dict[str, Any]) -> str: + for key, value in (unknown_fields or {}).items(): + if any(token in key for token in ("weak", "topic", "skill", "competency")): + topic = _canonicalize_topic_label(str(value or "").strip()) + if topic: + return topic + return "Foundational Skills" + + +def _sync_imported_students_to_teacher_dashboard( + request: Request, + *, + normalized_rows: List[Dict[str, Any]], + class_section_id: Optional[str], + class_name: Optional[str], +) -> Dict[str, Any]: + fallback_class_section_id = (class_section_id or "").strip().lower() or None + fallback_class_name = (class_name or "").strip() or None + fallback_grade: Optional[str] = None + fallback_section: Optional[str] = None + if fallback_class_section_id or fallback_class_name: + _, _, fallback_grade, fallback_section = _resolve_import_class_context( + class_section_id=fallback_class_section_id, + class_name=fallback_class_name, + ) + + fallback_class_metadata = _build_class_metadata( + class_section_id=fallback_class_section_id, + class_name=fallback_class_name, + grade=fallback_grade, + section=fallback_section, + ) + + if not (_firebase_ready and firebase_firestore): + return { + "synced": False, + "createdStudents": 0, + "updatedStudents": 0, + "classroomsTouched": 0, + "classroomId": None, + "classSectionId": fallback_class_section_id, + "className": fallback_class_name, + "classMetadata": fallback_class_metadata, + "warning": "Firestore unavailable; dashboard sync skipped.", + } + + try: + user = get_current_user(request) + resolved_section_id, resolved_class_name, grade, section = _resolve_import_class_context( + class_section_id=class_section_id, + class_name=class_name, + ) + resolved_class_metadata = _build_class_metadata( + class_section_id=resolved_section_id, + class_name=resolved_class_name, + grade=grade, + section=section, + owner_teacher_id=user.uid, + owner_teacher_name=user.email, + adviser_teacher_id=user.uid, + adviser_teacher_name=user.email, + manager_id=user.uid, + manager_name=user.email, + ) + + by_identity: Dict[str, Dict[str, Any]] = defaultdict(dict) + for row in normalized_rows: + name = str(row.get("name") or "").strip() + email = str(row.get("email") or "").strip().lower() + lrn = str(row.get("lrn") or "").strip() + identity = lrn or email or re.sub(r"\s+", "_", name.lower()) + if not identity: + continue + + state = by_identity.get(identity) + if not state: + state = { + "name": name, + "email": email, + "lrn": lrn, + "scores": [], + "attendance": [], + "engagement": [], + "completion": [], + "weakestTopic": "", + } + by_identity[identity] = state + + avg_quiz = float(row.get("avgQuizScore") or 0.0) + attendance = float(row.get("attendance") or 0.0) + engagement = float(row.get("engagementScore") or 0.0) + completion = float(row.get("assignmentCompletion") or 0.0) + + state["scores"].append(avg_quiz) + state["attendance"].append(attendance) + state["engagement"].append(engagement) + state["completion"].append(completion) + if not state.get("weakestTopic"): + state["weakestTopic"] = _pick_weakest_topic(row.get("unknownFields") or {}) + + if name and not state.get("name"): + state["name"] = name + if email and not state.get("email"): + state["email"] = email + if lrn and not state.get("lrn"): + state["lrn"] = lrn + + if not by_identity: + return { + "synced": False, + "createdStudents": 0, + "updatedStudents": 0, + "classroomsTouched": 0, + "classroomId": resolved_section_id, + "classSectionId": resolved_section_id, + "className": resolved_class_name, + "classMetadata": resolved_class_metadata, + "warning": "No normalized student records were available for dashboard sync.", + } + + client = firebase_firestore.client() + classrooms_ref = client.collection("classrooms") + students_ref = client.collection("managedStudents") + + classroom_doc_id = resolved_section_id or f"imported_{hashlib.sha1((user.uid + resolved_class_name).encode('utf-8')).hexdigest()[:12]}" + classroom_ref = classrooms_ref.document(classroom_doc_id) + classroom_snapshot = cast(Any, classroom_ref.get()) + existing_classroom = _snapshot_to_dict(classroom_snapshot) if _snapshot_exists(classroom_snapshot) else {} + + created_students = 0 + updated_students = 0 + risk_high_count = 0 + total_score = 0.0 + + for identity, aggregate in by_identity.items(): + scores = aggregate.get("scores") or [0.0] + attendance_values = aggregate.get("attendance") or [0.0] + engagement_values = aggregate.get("engagement") or [0.0] + completion_values = aggregate.get("completion") or [0.0] + + avg_quiz = float(sum(scores) / max(len(scores), 1)) + avg_attendance = float(sum(attendance_values) / max(len(attendance_values), 1)) + avg_engagement = float(sum(engagement_values) / max(len(engagement_values), 1)) + avg_completion = float(sum(completion_values) / max(len(completion_values), 1)) + + has_topic_signal = str(aggregate.get("weakestTopic") or "").strip() not in {"", "Foundational Skills"} + inference = _infer_student_state( + avg_quiz=avg_quiz, + attendance=avg_attendance, + engagement=avg_engagement, + defaulted_metrics=set(), + has_topic_signal=has_topic_signal, + ) + risk_level = str(inference["riskLevel"]) + if risk_level == "High": + risk_high_count += 1 + total_score += avg_quiz + + fallback_name = str(aggregate.get("name") or "Imported Student").strip() or "Imported Student" + avatar_seed = urllib.parse.quote(fallback_name) + student_doc_id = hashlib.sha1(f"{user.uid}|{identity}".encode("utf-8")).hexdigest()[:36] + student_ref = students_ref.document(student_doc_id) + student_snapshot = cast(Any, student_ref.get()) + + payload: Dict[str, Any] = { + "teacherId": user.uid, + "name": fallback_name, + "email": str(aggregate.get("email") or ""), + "lrn": str(aggregate.get("lrn") or ""), + "avatar": f"https://ui-avatars.com/api/?name={avatar_seed}&background=random", + "grade": grade, + "gradeLevel": resolved_class_metadata.get("gradeLevel"), + "classification": resolved_class_metadata.get("classification"), + "strand": resolved_class_metadata.get("strand"), + "section": section, + "classSectionId": resolved_section_id, + "classroomId": classroom_doc_id, + "className": resolved_class_name, + "managerId": resolved_class_metadata.get("managerId"), + "managerName": resolved_class_metadata.get("managerName"), + "classMetadata": resolved_class_metadata, + "riskLevel": risk_level, + "inferredState": { + "state": inference["state"], + "confidence": inference["confidence"], + "signals": inference["signals"], + "explanation": inference["explanation"], + "fallbackUsed": inference["fallbackUsed"], + }, + "stateConfidence": inference["confidence"], + "stateSignals": inference["signals"], + "engagementScore": round(avg_engagement, 1), + "avgQuizScore": round(avg_quiz, 1), + "weakestTopic": str(aggregate.get("weakestTopic") or "Foundational Skills"), + "attendance": round(avg_attendance, 1), + "assignmentCompletion": round(avg_completion, 1), + "struggles": [str(aggregate.get("weakestTopic") or "Foundational Skills")], + "lastActive": FIRESTORE_SERVER_TIMESTAMP, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + + if _snapshot_exists(student_snapshot): + updated_students += 1 + else: + created_students += 1 + payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + + student_ref.set(payload, merge=True) + + student_count = len(by_identity) + class_average = round(total_score / max(student_count, 1), 1) + + classroom_payload: Dict[str, Any] = { + "teacherId": user.uid, + "name": resolved_class_name, + "grade": grade, + "gradeLevel": resolved_class_metadata.get("gradeLevel"), + "classification": resolved_class_metadata.get("classification"), + "strand": resolved_class_metadata.get("strand"), + "section": section, + "classSectionId": resolved_section_id, + "ownerTeacherId": user.uid, + "ownerTeacherName": user.email, + "adviserTeacherId": user.uid, + "adviserTeacherName": user.email, + "managerId": resolved_class_metadata.get("managerId"), + "managerName": resolved_class_metadata.get("managerName"), + "classMetadata": resolved_class_metadata, + "schedule": str(existing_classroom.get("schedule") or "Mon-Fri"), + "studentCount": student_count, + "avgScore": class_average, + "atRiskCount": risk_high_count, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if not _snapshot_exists(classroom_snapshot): + classroom_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + + classroom_ref.set(classroom_payload, merge=True) + + return { + "synced": True, + "createdStudents": created_students, + "updatedStudents": updated_students, + "classroomsTouched": 1, + "classroomId": classroom_doc_id, + "classSectionId": resolved_section_id, + "className": resolved_class_name, + "classMetadata": resolved_class_metadata, + "warning": None, + } + except Exception as sync_err: + if _is_adc_missing_error(cast(Exception, sync_err)): + logger.warning("Dashboard sync skipped because Firestore ADC is not configured.") + warning = ( + "Firestore ADC is not configured; dashboard sync skipped. " + "Set FIREBASE_SERVICE_ACCOUNT_JSON or FIREBASE_SERVICE_ACCOUNT_FILE, " + "or set GOOGLE_APPLICATION_CREDENTIALS." + ) + else: + logger.warning(f"Dashboard sync skipped due Firestore error: {sync_err}") + warning = f"Dashboard sync skipped due Firestore error: {sync_err}" + return { + "synced": False, + "createdStudents": 0, + "updatedStudents": 0, + "classroomsTouched": 0, + "classroomId": None, + "classSectionId": fallback_class_section_id, + "className": fallback_class_name, + "classMetadata": fallback_class_metadata, + "warning": warning, + } + + +def _resolve_uploaded_files( + *, + file: Optional[UploadFile], + files: Optional[List[UploadFile]], + max_files: int = UPLOAD_MAX_FILES_PER_REQUEST, +) -> List[UploadFile]: + resolved: List[UploadFile] = [] + if files: + resolved.extend(files) + if file is not None: + # Keep backward compatibility for clients sending a single `file` field. + resolved.append(file) + + unique_files: List[UploadFile] = [] + seen_keys: Set[Tuple[str, Optional[str]]] = set() + for upload in resolved: + key = ((upload.filename or "").strip(), upload.content_type) + if key in seen_keys: + continue + seen_keys.add(key) + unique_files.append(upload) + + if not unique_files: + raise HTTPException(status_code=400, detail="At least one file is required") + if len(unique_files) > max_files: + raise HTTPException(status_code=400, detail=f"Too many files. Max allowed per request: {max_files}") + return unique_files + + +def _artifact_expiry_epoch() -> int: + return int(time.time()) + (IMPORT_RETENTION_DAYS * 24 * 60 * 60) + + +def _is_artifact_expired(data: Dict[str, Any]) -> bool: + raw = data.get("expiresAtEpoch") + if raw is None: + return False + try: + return int(raw) <= int(time.time()) + except Exception: + return False + + +def _write_access_audit_log( + request: Request, + *, + action: str, + status: str, + class_section_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + if not (_firebase_ready and firebase_firestore): + return + + try: + user = get_current_user(request) + payload: Dict[str, Any] = { + "action": action, + "status": status, + "teacherId": user.uid, + "teacherEmail": user.email, + "role": user.role, + "path": request.url.path, + "method": request.method, + "createdAt": FIRESTORE_SERVER_TIMESTAMP, + "createdAtIso": datetime.now(timezone.utc).isoformat(), + } + normalized_class_section_id = (class_section_id or "").strip() or None + if normalized_class_section_id: + payload["classSectionId"] = normalized_class_section_id + if metadata: + payload["metadata"] = metadata + + firebase_firestore.client().collection("accessAuditLogs").add(payload) + except Exception as audit_err: + if _is_adc_missing_error(cast(Exception, audit_err)): + _warn_firestore_audit_lookup_once( + f"Access audit log skipped because Firestore ADC is not configured ({action})." + ) + else: + logger.warning(f"Access audit log failed ({action}): {audit_err}") + + +def _record_risk_refresh_event( + *, + teacher_id: str, + teacher_email: Optional[str], + class_section_id: Optional[str], + refresh_id: str, + status: str, + students_queued: int, + queued_at_epoch: int, + started_at_epoch: Optional[int] = None, + completed_at_epoch: Optional[int] = None, + duration_ms: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> None: + """Persist lightweight monitoring artifacts for queued risk refresh jobs.""" + if not (_firebase_ready and firebase_firestore): + return + + try: + client = firebase_firestore.client() + now_iso = datetime.now(timezone.utc).isoformat() + normalized_class_section_id = (class_section_id or "").strip() or None + + event_payload: Dict[str, Any] = { + "refreshId": refresh_id, + "status": status, + "teacherId": teacher_id, + "teacherEmail": teacher_email, + "studentsQueued": students_queued, + "queuedAtEpoch": queued_at_epoch, + "startedAtEpoch": started_at_epoch, + "completedAtEpoch": completed_at_epoch, + "durationMs": duration_ms, + "createdAt": FIRESTORE_SERVER_TIMESTAMP, + "createdAtIso": now_iso, + } + if normalized_class_section_id: + event_payload["classSectionId"] = normalized_class_section_id + if metadata: + event_payload["metadata"] = metadata + client.collection("riskRefreshEvents").add(event_payload) + + job_ref = client.collection("riskRefreshJobs").document(refresh_id) + job_update: Dict[str, Any] = { + "refreshId": refresh_id, + "status": status, + "teacherId": teacher_id, + "teacherEmail": teacher_email, + "studentsQueued": students_queued, + "queuedAtEpoch": queued_at_epoch, + "startedAtEpoch": started_at_epoch, + "completedAtEpoch": completed_at_epoch, + "durationMs": duration_ms, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + "updatedAtIso": now_iso, + } + if normalized_class_section_id: + job_update["classSectionId"] = normalized_class_section_id + if metadata: + job_update["metadata"] = metadata + + existing_job = cast(Any, job_ref.get()) + if not _snapshot_exists(existing_job): + job_update["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + job_update["createdAtIso"] = now_iso + job_ref.set(job_update, merge=True) + + stats_ref = client.collection("riskRefreshStats").document(teacher_id) + stats_doc = cast(Any, stats_ref.get()) + stats_data = _snapshot_to_dict(stats_doc) if _snapshot_exists(stats_doc) else {} + queued_count = int(stats_data.get("queuedCount", 0) or 0) + success_count = int(stats_data.get("successCount", 0) or 0) + failed_count = int(stats_data.get("failedCount", 0) or 0) + + if status == "queued": + queued_count += 1 + elif status == "success": + success_count += 1 + elif status == "failed": + failed_count += 1 + + stats_payload: Dict[str, Any] = { + "teacherId": teacher_id, + "teacherEmail": teacher_email, + "queuedCount": queued_count, + "successCount": success_count, + "failedCount": failed_count, + "lastRefreshId": refresh_id, + "lastStatus": status, + "lastStudentsQueued": students_queued, + "lastQueuedAtEpoch": queued_at_epoch, + "lastStartedAtEpoch": started_at_epoch, + "lastCompletedAtEpoch": completed_at_epoch, + "lastDurationMs": duration_ms, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + "updatedAtIso": now_iso, + } + if normalized_class_section_id: + stats_payload["classSectionId"] = normalized_class_section_id + if not _snapshot_exists(stats_doc): + stats_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + stats_payload["createdAtIso"] = now_iso + + stats_ref.set(stats_payload, merge=True) + except Exception as monitor_err: + logger.warning(f"Risk refresh monitor logging failed ({refresh_id}, {status}): {monitor_err}") + + +def _queue_post_import_risk_refresh( + request: Request, + *, + students: List[Dict[str, Any]], + column_mapping: Dict[str, str], + class_section_id: Optional[str] = None, +) -> Dict[str, Any]: + """Queue non-blocking automation refresh after class-record imports.""" + if not students: + return { + "queued": False, + "studentsQueued": 0, + "reason": "No normalized students to process.", + "refreshId": None, + "queuedAtEpoch": None, + } + + # Keep payload compact while preserving key risk-driving fields. + compact_students = [_build_scoring_student_payload(row, class_section_id) for row in students] + + user = get_current_user(request) + normalized_class_section_id = (class_section_id or "").strip() or None + queued_at_epoch = int(time.time()) + refresh_seed = f"{user.uid}|{normalized_class_section_id or 'global'}|{queued_at_epoch}|{len(compact_students)}" + refresh_id = hashlib.sha1(refresh_seed.encode("utf-8")).hexdigest()[:24] + + sanitized_mapping = _sanitize_column_mapping(column_mapping) + payload = DataImportPayload( + teacherId=user.uid, + students=compact_students, + columnMapping=sanitized_mapping, + ) + + _record_risk_refresh_event( + teacher_id=user.uid, + teacher_email=user.email, + class_section_id=normalized_class_section_id, + refresh_id=refresh_id, + status="queued", + students_queued=len(compact_students), + queued_at_epoch=queued_at_epoch, + ) + + async def _run_automation_job() -> None: + started_at_epoch = int(time.time()) + try: + result = await automation_engine.handle_data_import(payload) + completed_at_epoch = int(time.time()) + duration_ms = max(0, int((completed_at_epoch - started_at_epoch) * 1000)) + final_status = "success" if bool(getattr(result, "success", False)) else "failed" + _record_risk_refresh_event( + teacher_id=user.uid, + teacher_email=user.email, + class_section_id=normalized_class_section_id, + refresh_id=refresh_id, + status=final_status, + students_queued=len(compact_students), + queued_at_epoch=queued_at_epoch, + started_at_epoch=started_at_epoch, + completed_at_epoch=completed_at_epoch, + duration_ms=duration_ms, + metadata={ + "automationSuccess": bool(getattr(result, "success", False)), + "message": str(getattr(result, "message", "") or ""), + "actionsCount": len(getattr(result, "actions", []) or []), + }, + ) + logger.info( + "Post-import automation completed for teacher %s (refreshId=%s, queued=%s, success=%s)", + user.uid, + refresh_id, + len(compact_students), + result.success, + ) + except Exception as automation_exc: + completed_at_epoch = int(time.time()) + duration_ms = max(0, int((completed_at_epoch - started_at_epoch) * 1000)) + _record_risk_refresh_event( + teacher_id=user.uid, + teacher_email=user.email, + class_section_id=normalized_class_section_id, + refresh_id=refresh_id, + status="failed", + students_queued=len(compact_students), + queued_at_epoch=queued_at_epoch, + started_at_epoch=started_at_epoch, + completed_at_epoch=completed_at_epoch, + duration_ms=duration_ms, + metadata={ + "error": str(automation_exc), + }, + ) + logger.error( + "Post-import automation failed for teacher %s (refreshId=%s): %s", + user.uid, + refresh_id, + automation_exc, + ) + + asyncio.create_task(_run_automation_job()) + return { + "queued": True, + "studentsQueued": len(compact_students), + "reason": None, + "refreshId": refresh_id, + "queuedAtEpoch": queued_at_epoch, + } + + +def _prune_account_import_previews(now_ts: Optional[float] = None) -> None: + cutoff = (now_ts if now_ts is not None else time.time()) - ACCOUNT_IMPORT_PREVIEW_TTL_SECONDS + expired_tokens = [ + token + for token, payload in _account_import_previews.items() + if float(payload.get("createdAtTs", 0.0)) < cutoff + ] + for token in expired_tokens: + _account_import_previews.pop(token, None) + + +def _auth_user_not_found(error: Exception) -> bool: + message = str(error).lower() + return "not found" in message or "no user record" in message + + +def _generate_temporary_password(length: int = 12) -> str: + alphabet = string.ascii_letters + string.digits + while True: + candidate = "".join(secrets.choice(alphabet) for _ in range(max(10, length))) + if any(ch.islower() for ch in candidate) and any(ch.isupper() for ch in candidate) and any(ch.isdigit() for ch in candidate): + return candidate + + +def _parse_provisioning_dataframe(df: Any) -> Dict[str, Any]: + header_map: Dict[str, str] = {} + for column in df.columns.tolist(): + normalized = re.sub(r"[^a-z0-9]+", "", str(column or "").strip().lower()) + if normalized and normalized not in header_map: + header_map[normalized] = str(column) + + def _pick(row: Any, candidates: Sequence[str]) -> str: + for alias in candidates: + source = header_map.get(alias) + if source is None: + continue + return _stringify_cell(row.get(source)) + return "" + + parsed_rows: List[Dict[str, Any]] = [] + for idx, row in df.iterrows(): + first_name = _pick(row, ["firstname", "first", "givenname", "given"]) + last_name = _pick(row, ["lastname", "last", "surname", "familyname"]) + middle_name = _pick(row, ["middlename", "middle", "middlenameinitial"]) or "" + student_id = _pick(row, ["studentid", "lrn", "learnerid", "learnerreferencenumber", "schoolid"]) + email = _pick(row, ["email", "emailaddress", "studentemail"]).lower() + grade = _pick(row, ["grade", "gradelevel", "yearlevel"]) + section = _pick(row, ["section", "classsection", "homeroom", "sectionname"]) + + parsed_rows.append( + { + "rowNumber": int(idx) + 2, + "firstName": first_name, + "lastName": last_name, + "middleName": middle_name, + "studentId": student_id, + "email": email, + "grade": grade, + "section": section, + } + ) + + return { + "rows": parsed_rows, + "headerMap": header_map, + } + + +def _teacher_can_manage_section(user: AuthenticatedUser, class_section_id: str) -> bool: + if user.role == "admin": + return True + if not (_firebase_ready and firebase_firestore): + return False + + normalized_section = (class_section_id or "").strip().lower() + if not normalized_section: + return False + + try: + client = firebase_firestore.client() + ownership_ref = client.collection("classSectionOwnership").document(normalized_section) + ownership_doc = cast(Any, ownership_ref.get()) + if _snapshot_exists(ownership_doc): + ownership = _snapshot_to_dict(ownership_doc) + if str(ownership.get("ownerTeacherId") or "").strip() == user.uid: + return True + if str(ownership.get("managerId") or "").strip() == user.uid: + return True + + for field in ("teacherId", "ownerTeacherId", "managerId"): + query = ( + client.collection("classrooms") + .where("classSectionId", "==", normalized_section) + .where(field, "==", user.uid) + .limit(1) + ) + if list(query.stream()): + return True + except Exception as err: + logger.warning(f"Section ownership check failed for {class_section_id}: {err}") + + return False + + +class StudentAccountProvisionCommitRequest(BaseModel): + previewToken: str + defaultPassword: Optional[str] = None + forcePasswordChange: bool = True + createAuthUsers: bool = True + + +class AdminCreateUserRequest(BaseModel): + name: str + email: str + password: str + confirmPassword: str + role: str + status: str + grade: str + section: str + lrn: Optional[str] = None + + @field_validator("name", "email", "password", "confirmPassword", "role", "status", "grade", "section") + @classmethod + def _strip_required(cls, value: str) -> str: + return str(value or "").strip() + + @field_validator("lrn") + @classmethod + def _strip_lrn(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return None + cleaned = value.strip() + return cleaned or None + + +class AdminCreateUserResponse(BaseModel): + success: bool + resultCode: str + message: str + userCreated: bool + emailSent: bool + uid: Optional[str] = None + warnings: List[str] = Field(default_factory=list) + emailError: Optional[Dict[str, Any]] = None + + +class AdminDeleteUserResponse(BaseModel): + success: bool + uid: str + authDeleted: bool + profileDeleted: bool + message: str + warnings: List[str] = Field(default_factory=list) + + +class AdminUserListItem(BaseModel): + uid: str + name: str + email: str + role: str + status: str + department: str + grade: Optional[str] = None + section: Optional[str] = None + classSectionId: Optional[str] = None + lrn: Optional[str] = None + photo: Optional[str] = None + lastLogin: Optional[str] = None + createdAt: Optional[str] = None + + +class AdminUserListResponse(BaseModel): + success: bool + page: int + pageSize: int + total: int + totalPages: int + hasNextPage: bool + hasMore: bool + users: List[AdminUserListItem] + filters: Dict[str, Optional[str]] = Field(default_factory=dict) + + +class AdminUpdateUserRequest(BaseModel): + name: Optional[str] = None + role: Optional[str] = None + status: Optional[str] = None + department: Optional[str] = None + grade: Optional[str] = None + section: Optional[str] = None + lrn: Optional[str] = None + + @field_validator("name", "role", "status", "department", "grade", "section", "lrn", mode="before") + @classmethod + def _strip_optional_fields(cls, value: Any) -> Optional[str]: + if value is None: + return None + cleaned = str(value).strip() + return cleaned or None + + +class AdminUpdateUserResponse(BaseModel): + success: bool + uid: str + message: str + updatesApplied: Dict[str, Any] = Field(default_factory=dict) + warnings: List[str] = Field(default_factory=list) + + +class AdminBulkActionFilters(BaseModel): + search: Optional[str] = None + role: Optional[str] = None + status: Optional[str] = None + grade: Optional[str] = None + section: Optional[str] = None + classSectionId: Optional[str] = None + + @field_validator("search", "role", "status", "grade", "section", "classSectionId", mode="before") + @classmethod + def _strip_optional_fields(cls, value: Any) -> Optional[str]: + if value is None: + return None + cleaned = str(value).strip() + return cleaned or None + + +class AdminBulkActionRequest(BaseModel): + action: str + userIds: List[str] = Field(default_factory=list) + excludeUserIds: List[str] = Field(default_factory=list) + filters: Optional[AdminBulkActionFilters] = None + role: Optional[str] = None + status: Optional[str] = None + grade: Optional[str] = None + section: Optional[str] = None + lrn: Optional[str] = None + dryRun: bool = False + exportFormat: str = "csv" + + @field_validator("action", mode="before") + @classmethod + def _strip_action(cls, value: Any) -> str: + cleaned = str(value or "").strip() + if not cleaned: + raise ValueError("action is required") + return cleaned + + @field_validator("role", "status", "grade", "section", "lrn", mode="before") + @classmethod + def _strip_optional_fields(cls, value: Any) -> Optional[str]: + if value is None: + return None + cleaned = str(value).strip() + return cleaned or None + + +class AdminBulkActionResultItem(BaseModel): + uid: str + email: Optional[str] = None + status: str + message: str + + +class AdminBulkActionResponse(BaseModel): + success: bool + action: str + summary: Dict[str, int] + results: List[AdminBulkActionResultItem] + warnings: List[str] = Field(default_factory=list) + export: Optional[Dict[str, Any]] = None + + +ADMIN_BULK_ACTIONS: Set[str] = { + "change_role", + "change_status", + "assign_class_section", + "activate", + "deactivate", + "reset_password_email", + "delete", + "export", +} + + +def _strip_optional_string(value: Any) -> Optional[str]: + if value is None: + return None + cleaned = str(value).strip() + return cleaned or None + + +def _normalize_admin_role_value(role: str) -> str: + normalized = str(role or "").strip().lower() + if normalized not in VALID_ROLES: + raise HTTPException(status_code=400, detail="Role must be Student, Teacher, or Admin.") + return normalized + + +def _normalize_admin_status_value(status: str) -> str: + normalized = str(status or "").strip().lower() + if normalized not in {"active", "inactive"}: + raise HTTPException(status_code=400, detail="Status must be Active or Inactive.") + return "Active" if normalized == "active" else "Inactive" + + +def _coerce_iso_timestamp(value: Any) -> Optional[str]: + if value is None: + return None + + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc).isoformat() + + if isinstance(value, str): + cleaned = value.strip() + return cleaned or None + + to_datetime = getattr(value, "to_datetime", None) + if callable(to_datetime): + try: + converted = to_datetime() + if isinstance(converted, datetime): + if converted.tzinfo is None: + converted = converted.replace(tzinfo=timezone.utc) + return converted.astimezone(timezone.utc).isoformat() + except Exception: + pass + + to_date = getattr(value, "toDate", None) + if callable(to_date): + try: + converted = to_date() + if isinstance(converted, datetime): + if converted.tzinfo is None: + converted = converted.replace(tzinfo=timezone.utc) + return converted.astimezone(timezone.utc).isoformat() + except Exception: + pass + + seconds = getattr(value, "seconds", None) + if isinstance(seconds, (int, float)): + try: + return datetime.fromtimestamp(float(seconds), tz=timezone.utc).isoformat() + except Exception: + pass + + return str(value) + + +def _iso_sort_key(value: Optional[str]) -> float: + if not value: + return 0.0 + try: + return datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() + except Exception: + return 0.0 + + +def _build_admin_user_record(uid: str, data: Dict[str, Any]) -> Dict[str, Any]: + role_lower = str(data.get("role") or "student").strip().lower() + if role_lower not in VALID_ROLES: + role_lower = "student" + + status_value = str(data.get("status") or "Active").strip().lower() + status_display = "Active" if status_value == "active" else "Inactive" + + grade = _strip_optional_string(data.get("grade")) + section = _strip_optional_string(data.get("section")) + class_section_id = _strip_optional_string(data.get("classSectionId")) + if not class_section_id and grade and section: + class_section_id = re.sub(r"\s+", "_", f"{grade}_{section}".strip()).lower() + + department = _strip_optional_string(data.get("department")) or "" + if role_lower == "student" and not department: + department = " - ".join([part for part in [grade, section] if part]) + + return { + "uid": uid, + "name": _strip_optional_string(data.get("name")) or "Unknown", + "email": _strip_optional_string(data.get("email")) or "", + "role": role_lower.capitalize(), + "status": status_display, + "department": department, + "grade": grade, + "section": section, + "classSectionId": class_section_id, + "lrn": _strip_optional_string(data.get("lrn")), + "photo": _strip_optional_string(data.get("photo")) or _strip_optional_string(data.get("photoURL")), + "lastLogin": _coerce_iso_timestamp(data.get("lastLogin")), + "createdAt": _coerce_iso_timestamp(data.get("createdAt")), + } + + +def _filter_admin_user_records( + records: Sequence[Dict[str, Any]], + *, + search: Optional[str] = None, + role: Optional[str] = None, + status: Optional[str] = None, + grade: Optional[str] = None, + section: Optional[str] = None, + class_section_id: Optional[str] = None, +) -> List[Dict[str, Any]]: + search_term = (search or "").strip().lower() + role_filter = (role or "").strip().lower() + status_filter = (status or "").strip().lower() + grade_filter = (grade or "").strip().lower() + section_filter = (section or "").strip().lower() + class_section_filter = (class_section_id or "").strip().lower() + + filtered: List[Dict[str, Any]] = [] + for record in records: + if search_term: + searchable = " ".join( + [ + str(record.get("uid") or ""), + str(record.get("name") or ""), + str(record.get("email") or ""), + ] + ).lower() + if search_term not in searchable: + continue + + if role_filter and role_filter not in {"all", "all roles"}: + if str(record.get("role") or "").strip().lower() != role_filter: + continue + + if status_filter and status_filter not in {"all", "all status"}: + if str(record.get("status") or "").strip().lower() != status_filter: + continue + + if grade_filter and str(record.get("grade") or "").strip().lower() != grade_filter: + continue + + if section_filter and str(record.get("section") or "").strip().lower() != section_filter: + continue + + if class_section_filter and str(record.get("classSectionId") or "").strip().lower() != class_section_filter: + continue + + filtered.append(record) + + return filtered + + +def _load_all_admin_user_records(firestore_client: Any) -> List[Dict[str, Any]]: + docs = list(firestore_client.collection("users").stream()) + records: List[Dict[str, Any]] = [] + for doc in docs: + data = _snapshot_to_dict(doc) + records.append(_build_admin_user_record(str(doc.id), data)) + + records.sort( + key=lambda item: _iso_sort_key(item.get("createdAt") or item.get("lastLogin")), + reverse=True, + ) + return records + + +async def _load_admin_user_records_for_list( + firestore_client: Any, + *, + scan_limit: int, + role: Optional[str], + class_section_id: Optional[str], +) -> List[Dict[str, Any]]: + normalized_role = (role or "").strip().lower() + normalized_class_section_id = (class_section_id or "").strip().lower() + + base_query = firestore_client.collection("users") + query_ref: Any = base_query + applied_filter: Optional[Tuple[str, str]] = None + + if normalized_class_section_id: + query_ref = query_ref.where("classSectionId", "==", normalized_class_section_id) + applied_filter = ("classSectionId", normalized_class_section_id) + elif normalized_role in VALID_ROLES: + query_ref = query_ref.where("role", "==", normalized_role) + applied_filter = ("role", normalized_role) + + query_ref = query_ref.limit(scan_limit) + + try: + docs = await asyncio.wait_for( + asyncio.to_thread(lambda: list(cast(Any, query_ref).stream())), + timeout=ADMIN_USERS_QUERY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + raise HTTPException( + status_code=504, + detail="User list query timed out. Please narrow your filters and try again.", + ) + except Exception as query_error: + query_error_text = str(query_error).lower() + if applied_filter and ("failed-precondition" in query_error_text or "index" in query_error_text): + logger.warning( + "Admin user list fallback scan due to index issue (%s=%s): %s", + applied_filter[0], + applied_filter[1], + query_error, + ) + fallback_query = base_query.limit(scan_limit) + try: + docs = await asyncio.wait_for( + asyncio.to_thread(lambda: list(cast(Any, fallback_query).stream())), + timeout=ADMIN_USERS_QUERY_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + raise HTTPException( + status_code=504, + detail="User list query timed out. Please narrow your filters and try again.", + ) + except Exception as fallback_error: + logger.error("Admin user list fallback query error: %s", fallback_error) + raise HTTPException(status_code=500, detail="Failed to load admin users.") + else: + logger.error("Admin user list query error: %s", query_error) + raise HTTPException(status_code=500, detail="Failed to load admin users.") + + records: List[Dict[str, Any]] = [] + for doc in docs: + data = _snapshot_to_dict(doc) + records.append(_build_admin_user_record(str(doc.id), data)) + + records.sort( + key=lambda item: _iso_sort_key(item.get("createdAt") or item.get("lastLogin")), + reverse=True, + ) + return records + + +def _resolve_admin_bulk_target_records( + firestore_client: Any, + *, + user_ids: Sequence[str], + filters: Optional[AdminBulkActionFilters], + excluded_user_ids: Sequence[str], +) -> List[Dict[str, Any]]: + excluded = {str(uid or "").strip() for uid in excluded_user_ids if str(uid or "").strip()} + normalized_user_ids = [str(uid or "").strip() for uid in user_ids if str(uid or "").strip()] + + records: List[Dict[str, Any]] = [] + if normalized_user_ids: + seen: Set[str] = set() + for uid in normalized_user_ids: + if uid in seen or uid in excluded: + continue + seen.add(uid) + doc = cast(Any, firestore_client.collection("users").document(uid).get()) + if not _snapshot_exists(doc): + continue + records.append(_build_admin_user_record(uid, _snapshot_to_dict(doc))) + return records + + all_records = _load_all_admin_user_records(firestore_client) + filtered = _filter_admin_user_records( + all_records, + search=filters.search if filters else None, + role=filters.role if filters else None, + status=filters.status if filters else None, + grade=filters.grade if filters else None, + section=filters.section if filters else None, + class_section_id=filters.classSectionId if filters else None, + ) + + return [record for record in filtered if str(record.get("uid") or "") not in excluded] + + +def _build_password_reset_email_message(name: str, email: str, reset_link: str) -> EmailMessagePayload: + safe_name = html.escape(name or "Learner") + safe_link = html.escape(reset_link, quote=True) + html_body = ( + "
" + "
" + "

Password Reset Requested

" + f"

Hello {safe_name}, your administrator requested a password reset for your MathPulse AI account.

" + f"

Reset Password

" + "

If you did not request this change, contact your administrator immediately.

" + "
" + ) + text_body = ( + "MathPulse AI Password Reset\n\n" + f"Hello {name or 'Learner'},\n\n" + "An administrator requested a password reset for your account.\n" + f"Reset your password here: {reset_link}\n\n" + "If you did not request this change, contact your administrator immediately.\n" + ) + return EmailMessagePayload( + to_name=name or "Learner", + to_email=email, + subject="MathPulse AI Password Reset", + html_content=html_body, + text_content=text_body, + ) + + +def _prepare_admin_profile_updates(existing: Dict[str, Any], payload: Dict[str, Any]) -> Dict[str, Any]: + updates: Dict[str, Any] = {"updatedAt": FIRESTORE_SERVER_TIMESTAMP} + + if payload.get("name") is not None: + updates["name"] = str(payload.get("name") or "").strip() + + role_value = payload.get("role") + role_lower = str(existing.get("role") or "student").strip().lower() + if role_value is not None: + role_lower = _normalize_admin_role_value(str(role_value)) + updates["role"] = role_lower + + status_value = payload.get("status") + if status_value is not None: + updates["status"] = _normalize_admin_status_value(str(status_value)) + + if payload.get("department") is not None: + updates["department"] = str(payload.get("department") or "").strip() + + if payload.get("grade") is not None: + updates["grade"] = str(payload.get("grade") or "").strip() + + if payload.get("section") is not None: + updates["section"] = str(payload.get("section") or "").strip() + + if payload.get("lrn") is not None: + updates["lrn"] = str(payload.get("lrn") or "").strip() + + if role_lower == "student": + grade_value = str(updates.get("grade") or existing.get("grade") or "").strip() + section_value = str(updates.get("section") or existing.get("section") or "").strip() + lrn_value = str(updates.get("lrn") or existing.get("lrn") or "").strip() + if not lrn_value: + raise HTTPException(status_code=400, detail="LRN is required for student accounts.") + updates["lrn"] = lrn_value + if grade_value: + updates["grade"] = grade_value + if section_value: + updates["section"] = section_value + if grade_value and section_value: + updates["classSectionId"] = re.sub(r"\s+", "_", f"{grade_value}_{section_value}".strip()).lower() + + return updates + + +@app.post("/api/import/student-accounts/preview") +async def preview_student_account_import( + request: Request, + file: UploadFile = File(...), + classSectionId: Optional[str] = Form(default=None), + className: Optional[str] = Form(default=None), + defaultGrade: Optional[str] = Form(default=None), + defaultSection: Optional[str] = Form(default=None), +): + """Parse and validate student account import rows before provisioning.""" + try: + import pandas as pd # type: ignore[import-not-found] + + enforce_rate_limit(request, "import_student_accounts_preview", 8, 60) + user = get_current_user(request) + + filename = file.filename or "" + ext = os.path.splitext(filename)[1].lower() + if ext not in {".csv", ".xlsx", ".xls"}: + raise HTTPException(status_code=400, detail="Unsupported file format. Use .csv, .xlsx, or .xls") + + contents = await file.read(UPLOAD_MAX_BYTES + 1) + if len(contents) > UPLOAD_MAX_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large. Max allowed size is {UPLOAD_MAX_BYTES // (1024 * 1024)} MB.", + ) + + if ext == ".csv": + df = pd.read_csv(io.BytesIO(contents), on_bad_lines="skip") + else: + df = pd.read_excel(io.BytesIO(contents)) + + if df is None or df.empty: + raise HTTPException(status_code=400, detail="No rows found in uploaded file") + if df.shape[0] > UPLOAD_MAX_ROWS: + raise HTTPException(status_code=413, detail=f"Too many rows ({df.shape[0]}). Max allowed: {UPLOAD_MAX_ROWS}") + + parsed = _parse_provisioning_dataframe(df) + parsed_rows = cast(List[Dict[str, Any]], parsed.get("rows") or []) + default_section_id, default_class_name, inferred_grade, inferred_section = _resolve_import_class_context( + class_section_id=(classSectionId or "").strip() or None, + class_name=(className or "").strip() or None, + ) + effective_default_grade = (defaultGrade or "").strip() or inferred_grade + effective_default_section = (defaultSection or "").strip() or inferred_section + + seen_student_ids: Set[str] = set() + seen_emails: Set[str] = set() + preview_rows: List[Dict[str, Any]] = [] + warnings: List[str] = [] + valid_rows: List[Dict[str, Any]] = [] + + firestore_client = firebase_firestore.client() if (_firebase_ready and firebase_firestore) else None + for row in parsed_rows: + first_name = str(row.get("firstName") or "").strip() + last_name = str(row.get("lastName") or "").strip() + middle_name = str(row.get("middleName") or "").strip() + student_id = str(row.get("studentId") or "").strip() + email = str(row.get("email") or "").strip().lower() + grade = str(row.get("grade") or "").strip() or effective_default_grade + section = str(row.get("section") or "").strip() or effective_default_section + full_name = " ".join(part for part in [first_name, middle_name, last_name] if part).strip() + + issues: List[str] = [] + if not first_name: + issues.append("Missing firstName") + if not last_name: + issues.append("Missing lastName") + if not student_id: + issues.append("Missing studentId/lrn") + if not grade: + issues.append("Missing grade") + if not section: + issues.append("Missing section") + + resolved_row_section_id, _, resolved_grade, resolved_section = _resolve_import_class_context( + class_section_id=None, + class_name=f"{grade or effective_default_grade} - {section or effective_default_section}", + ) + + generated_email = email + if not generated_email and student_id: + generated_email = f"{_slugify_class_token(student_id)}@mathpulse.local" + + duplicate_in_file = False + duplicate_in_firestore = False + duplicate_in_auth = False + + student_id_key = student_id.lower() + email_key = generated_email.lower() + if student_id_key and student_id_key in seen_student_ids: + duplicate_in_file = True + issues.append("Duplicate studentId in file") + if email_key and email_key in seen_emails: + duplicate_in_file = True + issues.append("Duplicate email in file") + + if student_id_key: + seen_student_ids.add(student_id_key) + if email_key: + seen_emails.add(email_key) + + if firestore_client: + try: + if student_id: + existing_by_lrn = list( + firestore_client.collection("users").where("lrn", "==", student_id).limit(1).stream() + ) + duplicate_in_firestore = duplicate_in_firestore or len(existing_by_lrn) > 0 + if generated_email: + existing_by_email = list( + firestore_client.collection("users").where("email", "==", generated_email).limit(1).stream() + ) + duplicate_in_firestore = duplicate_in_firestore or len(existing_by_email) > 0 + except Exception as firestore_err: + warnings.append(f"Firestore duplicate check skipped for row {row.get('rowNumber')}: {firestore_err}") + + if generated_email and firebase_auth: + try: + cast(Any, firebase_auth).get_user_by_email(generated_email) + duplicate_in_auth = True + except Exception as auth_err: + if not _auth_user_not_found(cast(Exception, auth_err)): + warnings.append(f"Auth duplicate check warning for {generated_email}: {auth_err}") + + if duplicate_in_firestore: + issues.append("Duplicate with existing Firestore profile") + if duplicate_in_auth: + issues.append("Duplicate with existing Auth account") + + status = "valid" + if issues: + status = "invalid" + if duplicate_in_file or duplicate_in_firestore or duplicate_in_auth: + status = "duplicate" + + preview_row = { + "rowNumber": int(row.get("rowNumber") or 0), + "studentId": student_id, + "firstName": first_name, + "lastName": last_name, + "middleName": middle_name, + "fullName": full_name or "Imported Student", + "email": generated_email, + "grade": resolved_grade, + "section": resolved_section, + "classSectionId": resolved_row_section_id, + "status": status, + "issues": issues, + "duplicateInFile": duplicate_in_file, + "duplicateInFirestore": duplicate_in_firestore, + "duplicateInAuth": duplicate_in_auth, + } + preview_rows.append(preview_row) + if status == "valid": + valid_rows.append(preview_row) + + preview_token = uuid.uuid4().hex + with _account_import_previews_lock: + _prune_account_import_previews() + _account_import_previews[preview_token] = { + "createdAtTs": time.time(), + "ownerUid": user.uid, + "classSectionId": default_section_id, + "className": default_class_name, + "rows": preview_rows, + "validRows": valid_rows, + } + + invalid_rows = sum(1 for row in preview_rows if row.get("status") == "invalid") + duplicate_rows = sum(1 for row in preview_rows if row.get("status") == "duplicate") + return { + "success": True, + "previewToken": preview_token, + "classSectionId": default_section_id, + "className": default_class_name, + "summary": { + "totalRows": len(preview_rows), + "validRows": len(valid_rows), + "invalidRows": invalid_rows, + "duplicateRows": duplicate_rows, + }, + "rows": preview_rows, + "warnings": list(dict.fromkeys(warnings))[:50], + } + except HTTPException: + raise + except Exception as exc: + logger.error(f"Student account preview error: {exc}") + raise HTTPException(status_code=500, detail=f"Student account preview error: {str(exc)}") + + +@app.post("/api/import/student-accounts/commit") +async def commit_student_account_import( + request: Request, + payload: StudentAccountProvisionCommitRequest, +): + """Commit previously validated student-account preview rows into Auth + Firestore.""" + try: + user = get_current_user(request) + preview_token = payload.previewToken.strip() + if not preview_token: + raise HTTPException(status_code=400, detail="previewToken is required") + + with _account_import_previews_lock: + _prune_account_import_previews() + preview = _account_import_previews.get(preview_token) + + if not preview: + raise HTTPException(status_code=404, detail="Preview token not found or expired") + if str(preview.get("ownerUid") or "") != user.uid: + raise HTTPException(status_code=403, detail="Preview token does not belong to the current user") + + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + client = firebase_firestore.client() + preview_rows = cast(List[Dict[str, Any]], preview.get("rows") or []) + candidate_rows = [row for row in preview_rows if str(row.get("status") or "") == "valid"] + + result_rows: List[Dict[str, Any]] = [] + warnings: List[str] = [] + created_rows = 0 + updated_rows = 0 + skipped_rows = 0 + blocked_rows = 0 + failed_rows = 0 + + for row in candidate_rows: + row_number = int(row.get("rowNumber") or 0) + student_id = str(row.get("studentId") or "").strip() + full_name = str(row.get("fullName") or "Imported Student").strip() or "Imported Student" + email = str(row.get("email") or "").strip().lower() + grade = str(row.get("grade") or "Grade 11").strip() or "Grade 11" + section = str(row.get("section") or "Section A").strip() or "Section A" + class_section_id = str(row.get("classSectionId") or "").strip().lower() or _slugify_class_token(f"{grade}_{section}") + + if user.role != "admin" and not _teacher_can_manage_section(user, class_section_id): + blocked_rows += 1 + result_rows.append( + { + "rowNumber": row_number, + "studentId": student_id, + "fullName": full_name, + "email": email, + "uid": None, + "classSectionId": class_section_id, + "status": "blocked", + "message": "Teacher is not allowed to provision this class section.", + "temporaryPassword": None, + } + ) + continue + + existing_profile_doc = None + existing_uid: Optional[str] = None + try: + if student_id: + docs = list(client.collection("users").where("lrn", "==", student_id).limit(1).stream()) + if docs: + existing_profile_doc = docs[0] + existing_uid = docs[0].id + if not existing_uid and email: + docs = list(client.collection("users").where("email", "==", email).limit(1).stream()) + if docs: + existing_profile_doc = docs[0] + existing_uid = docs[0].id + except Exception as duplicate_err: + warnings.append(f"Duplicate re-check warning for row {row_number}: {duplicate_err}") + + auth_uid: Optional[str] = None + temporary_password: Optional[str] = None + auth_existing = False + + if payload.createAuthUsers and firebase_auth and email: + try: + auth_user = cast(Any, firebase_auth).get_user_by_email(email) + auth_uid = str(getattr(auth_user, "uid", "") or "").strip() or None + auth_existing = auth_uid is not None + except Exception as auth_lookup_err: + if not _auth_user_not_found(cast(Exception, auth_lookup_err)): + warnings.append(f"Auth lookup warning for {email}: {auth_lookup_err}") + + if payload.createAuthUsers and firebase_auth and not auth_uid and email: + try: + temporary_password = (payload.defaultPassword or "").strip() or _generate_temporary_password() + created_auth_user = cast(Any, firebase_auth).create_user( + email=email, + password=temporary_password, + display_name=full_name, + ) + auth_uid = str(getattr(created_auth_user, "uid", "") or "").strip() or None + except Exception as create_auth_err: + failed_rows += 1 + result_rows.append( + { + "rowNumber": row_number, + "studentId": student_id, + "fullName": full_name, + "email": email, + "uid": None, + "classSectionId": class_section_id, + "status": "failed", + "message": f"Auth user provisioning failed: {create_auth_err}", + "temporaryPassword": None, + } + ) + continue + + if payload.createAuthUsers and auth_uid: + # Always write profile under auth_uid so lookup via request.auth.uid works at login. + if existing_uid and existing_uid != auth_uid: + warnings.append( + f"Row {row_number}: existing Firestore profile UID ({existing_uid}) differs from " + f"Auth UID ({auth_uid}); profile will be written under Auth UID." + ) + target_uid = auth_uid + else: + target_uid = existing_uid or auth_uid + if not target_uid: + seed = f"{student_id}|{email}|{class_section_id}" + target_uid = hashlib.sha1(seed.encode("utf-8")).hexdigest()[:28] + + class_section_name = f"{grade} - {section}" + class_metadata = _build_class_metadata( + class_section_id=class_section_id, + class_name=class_section_name, + grade=grade, + section=section, + owner_teacher_id=user.uid, + owner_teacher_name=user.email, + adviser_teacher_id=user.uid, + adviser_teacher_name=user.email, + manager_id=user.uid, + manager_name=user.email, + ) + + profile_payload: Dict[str, Any] = { + "name": full_name, + "email": email, + "role": "student", + "lrn": student_id, + "grade": grade, + "gradeLevel": class_metadata.get("gradeLevel"), + "classification": class_metadata.get("classification"), + "strand": class_metadata.get("strand"), + "section": section, + "classSectionId": class_section_id, + "adviserTeacherId": user.uid, + "adviserTeacherName": user.email, + "forcePasswordChange": bool(payload.forcePasswordChange), + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if not existing_profile_doc: + profile_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + + client.collection("users").document(target_uid).set(profile_payload, merge=True) + + managed_student_payload = { + "teacherId": user.uid, + "name": full_name, + "email": email, + "lrn": student_id, + "className": class_section_name, + "grade": grade, + "gradeLevel": class_metadata.get("gradeLevel"), + "classification": class_metadata.get("classification"), + "strand": class_metadata.get("strand"), + "section": section, + "classSectionId": class_section_id, + "classroomId": class_section_id, + "managerId": user.uid, + "managerName": user.email, + "classMetadata": class_metadata, + "riskLevel": "Low", + "engagementScore": 0, + "avgQuizScore": 0, + "weakestTopic": "Foundational Skills", + "attendance": 0, + "assignmentCompletion": 0, + "struggles": ["Foundational Skills"], + "lastActive": FIRESTORE_SERVER_TIMESTAMP, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + client.collection("managedStudents").document(target_uid).set(managed_student_payload, merge=True) + + ownership_ref = client.collection("classSectionOwnership").document(class_section_id) + ownership_doc = cast(Any, ownership_ref.get()) + ownership_data = _snapshot_to_dict(ownership_doc) if _snapshot_exists(ownership_doc) else {} + existing_student_uids = cast(List[str], ownership_data.get("studentUids") or []) + merged_student_uids = sorted(set([*existing_student_uids, target_uid])) + ownership_payload = { + "classSectionId": class_section_id, + "className": class_section_name, + "grade": grade, + "gradeLevel": class_metadata.get("gradeLevel"), + "classification": class_metadata.get("classification"), + "strand": class_metadata.get("strand"), + "section": section, + "schoolYear": str(datetime.now(timezone.utc).year), + "ownerTeacherId": str(ownership_data.get("ownerTeacherId") or user.uid), + "ownerTeacherName": str(ownership_data.get("ownerTeacherName") or user.email), + "managerId": str(ownership_data.get("managerId") or user.uid), + "managerName": str(ownership_data.get("managerName") or user.email), + "studentUids": merged_student_uids, + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if not _snapshot_exists(ownership_doc): + ownership_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + ownership_ref.set(ownership_payload, merge=True) + + row_status = "updated" if (existing_profile_doc or auth_existing) else "created" + if row_status == "created": + created_rows += 1 + else: + updated_rows += 1 + + result_rows.append( + { + "rowNumber": row_number, + "studentId": student_id, + "fullName": full_name, + "email": email, + "uid": target_uid, + "classSectionId": class_section_id, + "status": row_status, + "message": "Provisioned successfully." if row_status == "created" else "Existing profile updated.", + "temporaryPassword": temporary_password, + } + ) + + skipped_rows = max(0, len(preview_rows) - len(candidate_rows)) + + with _account_import_previews_lock: + _account_import_previews.pop(preview_token, None) + + _write_access_audit_log( + request, + action="student_account_import_commit", + status="success", + class_section_id=str(preview.get("classSectionId") or "") or None, + metadata={ + "previewToken": preview_token, + "totalRows": len(preview_rows), + "candidateRows": len(candidate_rows), + "createdRows": created_rows, + "updatedRows": updated_rows, + "skippedRows": skipped_rows, + "blockedRows": blocked_rows, + "failedRows": failed_rows, + }, + ) + + return { + "success": failed_rows == 0, + "previewToken": preview_token, + "summary": { + "totalRows": len(preview_rows), + "createdRows": created_rows, + "updatedRows": updated_rows, + "skippedRows": skipped_rows, + "blockedRows": blocked_rows, + "failedRows": failed_rows, + }, + "rows": result_rows, + "warnings": list(dict.fromkeys(warnings))[:100], + } + except HTTPException: + raise + except Exception as exc: + logger.error(f"Student account import commit error: {exc}") + raise HTTPException(status_code=500, detail=f"Student account import commit error: {str(exc)}") + + +@app.get("/api/admin/users", response_model=AdminUserListResponse) +async def list_admin_users( + request: Request, + page: int = Query(default=1, ge=1), + pageSize: int = Query(default=25, ge=1, le=200), + search: Optional[str] = Query(default=None), + role: Optional[str] = Query(default=None), + status: Optional[str] = Query(default=None), + grade: Optional[str] = Query(default=None), + section: Optional[str] = Query(default=None), + classSectionId: Optional[str] = Query(default=None), +): + user = get_current_user(request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + normalized_role_filter = str(role or "").strip().lower() + if normalized_role_filter and normalized_role_filter not in VALID_ROLES: + raise HTTPException(status_code=400, detail="role must be one of student, teacher, or admin") + + if not _firebase_ready or firebase_firestore is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + start_index = (page - 1) * pageSize + if start_index >= ADMIN_USERS_MAX_SCAN_DOCS: + raise HTTPException( + status_code=400, + detail="Requested page is too deep. Narrow filters or use a smaller page number.", + ) + + required_window = page * pageSize + scan_limit = min(max(required_window * 4, pageSize * 8), ADMIN_USERS_MAX_SCAN_DOCS) + + firestore_client = cast(Any, firebase_firestore).client() + records = await _load_admin_user_records_for_list( + firestore_client, + scan_limit=scan_limit, + role=role, + class_section_id=classSectionId, + ) + filtered = _filter_admin_user_records( + records, + search=search, + role=role, + status=status, + grade=grade, + section=section, + class_section_id=classSectionId, + ) + + total = len(filtered) + total_pages = int(math.ceil(total / pageSize)) if total > 0 else 0 + page_index = page + if total_pages > 0: + page_index = min(page, total_pages) + + start = (page_index - 1) * pageSize + end = start + pageSize + paged_records = filtered[start:end] + + return AdminUserListResponse( + success=True, + page=page_index, + pageSize=pageSize, + total=total, + totalPages=total_pages, + hasNextPage=end < total, + hasMore=end < total, + users=[AdminUserListItem(**entry) for entry in paged_records], + filters={ + "search": _strip_optional_string(search), + "role": _strip_optional_string(role), + "status": _strip_optional_string(status), + "grade": _strip_optional_string(grade), + "section": _strip_optional_string(section), + "classSectionId": _strip_optional_string(classSectionId), + }, + ) + + +@app.patch("/api/admin/users", response_model=AdminUpdateUserResponse) +async def update_admin_user_account( + request: Request, + payload: AdminUpdateUserRequest, + uid: str = Query(..., min_length=1), +): + user = get_current_user(request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + if not _firebase_ready or firebase_firestore is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + normalized_uid = str(uid or "").strip() + if not normalized_uid: + raise HTTPException(status_code=400, detail="User uid is required") + + update_payload = { + "name": payload.name, + "role": payload.role, + "status": payload.status, + "department": payload.department, + "grade": payload.grade, + "section": payload.section, + "lrn": payload.lrn, + } + if not any(value is not None for value in update_payload.values()): + raise HTTPException(status_code=400, detail="At least one field is required for update") + + if normalized_uid == user.uid: + if payload.role and _normalize_admin_role_value(payload.role) != "admin": + raise HTTPException(status_code=400, detail="Admin users cannot remove their own admin role") + if payload.status and _normalize_admin_status_value(payload.status) == "Inactive": + raise HTTPException(status_code=400, detail="Admin users cannot deactivate their own account") + + firestore_client = cast(Any, firebase_firestore).client() + profile_ref = firestore_client.collection("users").document(normalized_uid) + profile_snapshot = cast(Any, profile_ref.get()) + if not _snapshot_exists(profile_snapshot): + raise HTTPException(status_code=404, detail="User profile not found") + + existing_profile = _snapshot_to_dict(profile_snapshot) + prepared_updates = _prepare_admin_profile_updates(existing_profile, update_payload) + profile_ref.set(prepared_updates, merge=True) + + warnings: List[str] = [] + if "status" in prepared_updates and firebase_auth is not None: + try: + cast(Any, firebase_auth).update_user( + normalized_uid, + disabled=(prepared_updates.get("status") == "Inactive"), + ) + except Exception as auth_update_error: + if _is_auth_user_not_found_error(auth_update_error): + warnings.append("Authentication account was already missing while syncing status.") + else: + warnings.append("Failed to sync status with authentication account.") + + _write_access_audit_log( + request, + action="admin_user_update", + status="success", + metadata={ + "uid": normalized_uid, + "fields": sorted([key for key in prepared_updates.keys() if key != "updatedAt"]), + "warnings": list(dict.fromkeys(warnings)), + }, + ) + + updates_applied = { + key: value + for key, value in prepared_updates.items() + if key != "updatedAt" + } + + return AdminUpdateUserResponse( + success=True, + uid=normalized_uid, + message="User profile updated successfully.", + updatesApplied=updates_applied, + warnings=list(dict.fromkeys(warnings)), + ) + + +@app.post("/api/admin/users/bulk-action", response_model=AdminBulkActionResponse) +async def apply_admin_user_bulk_action( + request: Request, + payload: AdminBulkActionRequest, +): + user = get_current_user(request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + if not _firebase_ready or firebase_firestore is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + action = payload.action.strip().lower() + if action not in ADMIN_BULK_ACTIONS: + raise HTTPException(status_code=400, detail="Unsupported bulk action") + + if not payload.userIds and payload.filters is None: + raise HTTPException(status_code=400, detail="Provide userIds or filters to target users") + + if action == "change_role" and not payload.role: + raise HTTPException(status_code=400, detail="role is required for change_role action") + if action == "change_status" and not payload.status: + raise HTTPException(status_code=400, detail="status is required for change_status action") + if action == "assign_class_section" and (not payload.grade or not payload.section): + raise HTTPException(status_code=400, detail="grade and section are required for assign_class_section action") + + firestore_client = cast(Any, firebase_firestore).client() + target_records = _resolve_admin_bulk_target_records( + firestore_client, + user_ids=payload.userIds, + filters=payload.filters, + excluded_user_ids=payload.excludeUserIds, + ) + + if not target_records: + return AdminBulkActionResponse( + success=True, + action=action, + summary={"targeted": 0, "succeeded": 0, "failed": 0, "skipped": 0, "exported": 0}, + results=[], + warnings=[], + export={"format": payload.exportFormat.lower(), "rows": []} if action == "export" else None, + ) + + warnings: List[str] = [] + results: List[AdminBulkActionResultItem] = [] + succeeded = 0 + failed = 0 + skipped = 0 + exported_rows: List[Dict[str, Any]] = [] + + email_service = create_email_service_from_env() if action == "reset_password_email" else None + + for record in target_records: + uid_value = str(record.get("uid") or "").strip() + email_value = _strip_optional_string(record.get("email")) + + if not uid_value: + failed += 1 + results.append( + AdminBulkActionResultItem( + uid="", + email=email_value, + status="failed", + message="Target user is missing uid.", + ) + ) + continue + + if uid_value == user.uid and action in {"delete", "deactivate"}: + skipped += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="skipped", + message="Skipped self-targeted destructive action.", + ) + ) + continue + + if payload.dryRun: + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="Dry run validation succeeded.", + ) + ) + continue + + profile_ref = firestore_client.collection("users").document(uid_value) + profile_snapshot = cast(Any, profile_ref.get()) + if not _snapshot_exists(profile_snapshot): + failed += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="failed", + message="User profile not found.", + ) + ) + continue + + existing_profile = _snapshot_to_dict(profile_snapshot) + + try: + if action == "export": + exported_rows.append( + { + "uid": uid_value, + "name": record.get("name") or "", + "email": record.get("email") or "", + "role": record.get("role") or "", + "status": record.get("status") or "", + "grade": record.get("grade") or "", + "section": record.get("section") or "", + "classSectionId": record.get("classSectionId") or "", + "department": record.get("department") or "", + "lrn": record.get("lrn") or "", + } + ) + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="User exported.", + ) + ) + continue + + if action in {"activate", "deactivate", "change_status"}: + target_status = payload.status if action == "change_status" else ("Active" if action == "activate" else "Inactive") + normalized_status = _normalize_admin_status_value(str(target_status or "")) + profile_ref.set({"status": normalized_status, "updatedAt": FIRESTORE_SERVER_TIMESTAMP}, merge=True) + + if firebase_auth is not None: + try: + cast(Any, firebase_auth).update_user(uid_value, disabled=(normalized_status == "Inactive")) + except Exception as auth_update_error: + if _is_auth_user_not_found_error(auth_update_error): + warnings.append(f"Auth account missing for {uid_value} while syncing status.") + else: + warnings.append(f"Failed to sync auth status for {uid_value}.") + + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message=f"Status set to {normalized_status}.", + ) + ) + continue + + if action == "change_role": + update_payload = _prepare_admin_profile_updates( + existing_profile, + { + "role": payload.role, + "grade": payload.grade, + "section": payload.section, + "lrn": payload.lrn, + }, + ) + if uid_value == user.uid and update_payload.get("role") != "admin": + skipped += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="skipped", + message="Skipped self-admin role removal.", + ) + ) + continue + + profile_ref.set(update_payload, merge=True) + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="Role updated.", + ) + ) + continue + + if action == "assign_class_section": + role_lower = str(existing_profile.get("role") or "").strip().lower() + if role_lower != "student": + skipped += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="skipped", + message="Only student profiles can be assigned to class sections.", + ) + ) + continue + + update_payload = _prepare_admin_profile_updates( + existing_profile, + { + "grade": payload.grade, + "section": payload.section, + "lrn": payload.lrn, + }, + ) + profile_ref.set(update_payload, merge=True) + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="Class and section updated.", + ) + ) + continue + + if action == "reset_password_email": + if firebase_auth is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + if not email_value: + skipped += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="skipped", + message="Skipped user without email address.", + ) + ) + continue + + reset_link = cast(Any, firebase_auth).generate_password_reset_link(email_value) + email_payload = _build_password_reset_email_message( + name=str(existing_profile.get("name") or record.get("name") or "Learner"), + email=email_value, + reset_link=str(reset_link), + ) + + send_result = email_service.send_transactional_email(email_payload) if email_service else None + if send_result is None or not send_result.success: + error_message = send_result.error_message if send_result else "Email service unavailable" + failed += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="failed", + message=f"Failed to send reset email: {error_message}", + ) + ) + continue + + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="Password reset email sent.", + ) + ) + continue + + if action == "delete": + if firebase_auth is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + try: + cast(Any, firebase_auth).delete_user(uid_value) + except Exception as auth_delete_error: + if _is_auth_user_not_found_error(auth_delete_error): + warnings.append(f"Authentication account already missing for {uid_value}.") + else: + raise HTTPException(status_code=500, detail=f"Failed to delete auth account for {uid_value}") + + profile_ref.delete() + succeeded += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="succeeded", + message="User account deleted.", + ) + ) + continue + + failed += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="failed", + message="Unsupported action.", + ) + ) + except HTTPException: + raise + except Exception as action_error: + failed += 1 + results.append( + AdminBulkActionResultItem( + uid=uid_value, + email=email_value, + status="failed", + message=f"Action failed: {action_error}", + ) + ) + + summary = { + "targeted": len(target_records), + "succeeded": succeeded, + "failed": failed, + "skipped": skipped, + "exported": len(exported_rows), + } + + _write_access_audit_log( + request, + action="admin_user_bulk_action", + status="success" if failed == 0 else "partial_success", + metadata={ + "action": action, + "summary": summary, + "dryRun": payload.dryRun, + "filtersApplied": payload.filters.model_dump() if payload.filters else None, + "userIdsCount": len(payload.userIds), + }, + ) + + export_payload: Optional[Dict[str, Any]] = None + if action == "export": + export_payload = { + "format": (payload.exportFormat or "csv").strip().lower(), + "rows": exported_rows, + } + + return AdminBulkActionResponse( + success=failed == 0, + action=action, + summary=summary, + results=results, + warnings=list(dict.fromkeys(warnings)), + export=export_payload, + ) + + +@app.post("/api/admin/users", response_model=AdminCreateUserResponse) +async def create_admin_user_and_notify( + request: Request, + payload: AdminCreateUserRequest, +): + """Create a single user account (Auth + Firestore) and send welcome credentials email.""" + user = get_current_user(request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + if not _firebase_ready: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + try: + provisioning_service = UserProvisioningService( + firebase_auth_module=firebase_auth, + firestore_module=firebase_firestore, + firestore_server_timestamp=FIRESTORE_SERVER_TIMESTAMP, + email_service=create_email_service_from_env(), + ) + + result: CreateUserAndNotifyResult = provisioning_service.create_user_and_notify( + AdminCreateUserInput( + name=payload.name, + email=payload.email, + password=payload.password, + confirm_password=payload.confirmPassword, + role=payload.role, + status=payload.status, + grade=payload.grade, + section=payload.section, + lrn=payload.lrn, + ) + ) + + email_error: Optional[Dict[str, Any]] = None + if result.email_result and not result.email_result.success: + email_error = { + "provider": result.email_result.provider, + "code": result.email_result.error_code, + "message": result.email_result.error_message, + "retryable": result.email_result.retryable, + } + + _write_access_audit_log( + request, + action="admin_user_create", + status="success" if result.email_sent else "partial_success", + metadata={ + "uid": result.uid, + "email": payload.email, + "role": payload.role, + "resultCode": result.result_code, + "emailSent": result.email_sent, + }, + ) + + return AdminCreateUserResponse( + success=True, + resultCode=result.result_code, + message=result.message, + userCreated=result.user_created, + emailSent=result.email_sent, + uid=result.uid, + warnings=result.warnings, + emailError=email_error, + ) + except UserProvisioningError as provisioning_error: + _write_access_audit_log( + request, + action="admin_user_create", + status="failed", + metadata={ + "email": payload.email, + "role": payload.role, + "errorCode": provisioning_error.code, + }, + ) + raise HTTPException(status_code=provisioning_error.status_code, detail=provisioning_error.message) + except HTTPException: + raise + except Exception as exc: + logger.error(f"Admin user creation failed unexpectedly: {exc}") + _write_access_audit_log( + request, + action="admin_user_create", + status="failed", + metadata={ + "email": payload.email, + "role": payload.role, + "errorCode": "unexpected_error", + }, + ) + raise HTTPException(status_code=500, detail="Failed to create user account") + + +@app.delete("/api/admin/users", response_model=AdminDeleteUserResponse) +async def delete_admin_user_account( + request: Request, + uid: str = Query(..., min_length=1), +): + """Delete a user account from Firebase Auth and Firestore profile records.""" + user = get_current_user(request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + if not _firebase_ready or firebase_auth is None or firebase_firestore is None: + raise HTTPException(status_code=503, detail="Authentication service unavailable") + + normalized_uid = str(uid or "").strip() + if not normalized_uid: + raise HTTPException(status_code=400, detail="User uid is required") + + if normalized_uid == user.uid: + raise HTTPException(status_code=400, detail="Admin users cannot delete their own account") + + auth_deleted = False + profile_deleted = False + warnings: List[str] = [] + + try: + try: + cast(Any, firebase_auth).delete_user(normalized_uid) + auth_deleted = True + except Exception as auth_delete_error: + if _is_auth_user_not_found_error(auth_delete_error): + warnings.append("Authentication account was already missing.") + else: + logger.error("Admin user delete failed in Auth for uid=%s: %s", normalized_uid, auth_delete_error) + raise HTTPException(status_code=500, detail="Failed to delete authentication account") + + try: + firestore_client = cast(Any, firebase_firestore).client() + profile_ref = firestore_client.collection("users").document(normalized_uid) + profile_snapshot = cast(Any, profile_ref.get()) + if _snapshot_exists(profile_snapshot): + profile_ref.delete() + profile_deleted = True + else: + warnings.append("User profile was already missing.") + except HTTPException: + raise + except Exception as profile_delete_error: + logger.error("Admin user delete failed in Firestore for uid=%s: %s", normalized_uid, profile_delete_error) + raise HTTPException(status_code=500, detail="Failed to delete user profile") + + status_label = "success" if (auth_deleted or profile_deleted) else "noop" + _write_access_audit_log( + request, + action="admin_user_delete", + status=status_label, + metadata={ + "uid": normalized_uid, + "authDeleted": auth_deleted, + "profileDeleted": profile_deleted, + "warnings": list(dict.fromkeys(warnings)), + }, + ) + + if auth_deleted and profile_deleted: + message = "User account deleted from authentication and profile records." + elif auth_deleted: + message = "Authentication account deleted. Profile record was already missing." + elif profile_deleted: + message = "Profile record deleted. Authentication account was already missing." + else: + message = "No matching user records were found to delete." + + return AdminDeleteUserResponse( + success=True, + uid=normalized_uid, + authDeleted=auth_deleted, + profileDeleted=profile_deleted, + message=message, + warnings=list(dict.fromkeys(warnings)), + ) + except HTTPException as http_exc: + _write_access_audit_log( + request, + action="admin_user_delete", + status="failed", + metadata={ + "uid": normalized_uid, + "authDeleted": auth_deleted, + "profileDeleted": profile_deleted, + "errorStatus": http_exc.status_code, + }, + ) + raise + except Exception as exc: + logger.error(f"Admin user deletion failed unexpectedly for {normalized_uid}: {exc}") + _write_access_audit_log( + request, + action="admin_user_delete", + status="failed", + metadata={ + "uid": normalized_uid, + "authDeleted": auth_deleted, + "profileDeleted": profile_deleted, + "errorCode": "unexpected_error", + }, + ) + raise HTTPException(status_code=500, detail="Failed to delete user account") + + +@app.get("/api/upload/class-records/risk-refresh/recent") +async def get_recent_risk_refresh_status( + request: Request, + classSectionId: Optional[str] = Query(default=None), + limit: int = Query(default=10, ge=1, le=50), +): + """Return lightweight monitoring view for recent post-import risk refresh jobs.""" + try: + user = get_current_user(request) + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + normalized_class_section_id = (classSectionId or "").strip() or None + query = ( + firebase_firestore.client() + .collection("riskRefreshJobs") + .where("teacherId", "==", user.uid) + ) + if normalized_class_section_id: + query = query.where("classSectionId", "==", normalized_class_section_id) + + warnings: List[str] = [] + try: + docs = ( + query + .order_by("updatedAt", direction=FIRESTORE_QUERY_DESCENDING) + .limit(limit) + .stream() + ) + except Exception: + warnings.append("Risk refresh monitor used fallback query path without ordering.") + docs = query.limit(limit).stream() + + jobs: List[Dict[str, Any]] = [] + for doc in docs: + data = doc.to_dict() or {} + jobs.append( + { + "refreshId": str(data.get("refreshId") or doc.id), + "status": str(data.get("status") or "unknown"), + "studentsQueued": int(data.get("studentsQueued") or 0), + "classSectionId": data.get("classSectionId"), + "queuedAtEpoch": data.get("queuedAtEpoch"), + "startedAtEpoch": data.get("startedAtEpoch"), + "completedAtEpoch": data.get("completedAtEpoch"), + "durationMs": data.get("durationMs"), + "updatedAtIso": data.get("updatedAtIso"), + "metadata": data.get("metadata") or {}, + } + ) + + stats_doc = cast(Any, ( + firebase_firestore.client() + .collection("riskRefreshStats") + .document(user.uid) + .get() + )) + stats_data = _snapshot_to_dict(stats_doc) if _snapshot_exists(stats_doc) else {} + stats = { + "queuedCount": int(stats_data.get("queuedCount", 0) or 0), + "successCount": int(stats_data.get("successCount", 0) or 0), + "failedCount": int(stats_data.get("failedCount", 0) or 0), + "lastRefreshId": stats_data.get("lastRefreshId"), + "lastStatus": stats_data.get("lastStatus"), + "lastStudentsQueued": int(stats_data.get("lastStudentsQueued", 0) or 0), + "lastQueuedAtEpoch": stats_data.get("lastQueuedAtEpoch"), + "lastStartedAtEpoch": stats_data.get("lastStartedAtEpoch"), + "lastCompletedAtEpoch": stats_data.get("lastCompletedAtEpoch"), + "lastDurationMs": stats_data.get("lastDurationMs"), + "updatedAtIso": stats_data.get("updatedAtIso"), + } + + response_payload = { + "success": True, + "classSectionId": normalized_class_section_id, + "stats": stats, + "jobs": jobs, + "warnings": warnings, + } + + _write_access_audit_log( + request, + action="risk_refresh_monitor_read", + status="success", + class_section_id=normalized_class_section_id, + metadata={ + "requestedLimit": limit, + "returnedJobs": len(jobs), + "warningsCount": len(warnings), + }, + ) + + return response_payload + except HTTPException: + raise + except Exception as e: + logger.error(f"Risk refresh monitor lookup error: {e}") + raise HTTPException(status_code=500, detail=f"Risk refresh monitor lookup error: {str(e)}") + + +@app.post("/api/upload/class-records") +async def upload_class_records( + request: Request, + file: Optional[UploadFile] = File(default=None), + files: Optional[List[UploadFile]] = File(default=None), + classSectionId: Optional[str] = Form(default=None), + className: Optional[str] = Form(default=None), + datasetIntent: str = Form(default="synthetic_student_records"), +): + """Upload and parse class records (CSV, Excel, PDF) with AI column detection""" + try: + import pandas as pd # type: ignore[import-not-found] + + enforce_rate_limit(request, "upload_class_records", UPLOAD_RATE_LIMIT_PER_MIN, 60) + + uploads = _resolve_uploaded_files(file=file, files=files) + normalized_dataset_intent = (datasetIntent or "").strip() or "synthetic_student_records" + if normalized_dataset_intent not in SUPPORTED_DATASET_INTENTS: + raise HTTPException( + status_code=400, + detail=( + "Unsupported datasetIntent. Allowed values: " + + ", ".join(sorted(SUPPORTED_DATASET_INTENTS)) + ), + ) + + user = get_current_user(request) + ( + resolved_upload_class_section_id, + resolved_upload_class_name, + resolved_upload_grade, + resolved_upload_section, + ) = _resolve_import_class_context( + class_section_id=(classSectionId or "").strip() or None, + class_name=(className or "").strip() or None, + ) + upload_class_metadata = _build_class_metadata( + class_section_id=resolved_upload_class_section_id, + class_name=resolved_upload_class_name, + grade=resolved_upload_grade, + section=resolved_upload_section, + owner_teacher_id=user.uid, + owner_teacher_name=user.email, + adviser_teacher_id=user.uid, + adviser_teacher_name=user.email, + manager_id=user.uid, + manager_name=user.email, + ) + + all_students: List[Dict[str, Any]] = [] + all_unknown_columns: Set[str] = set() + all_warnings: List[str] = [] + all_row_warnings: List[Dict[str, Any]] = [] + all_rejected_rows: List[Dict[str, Any]] = [] + all_column_interpretations: List[Dict[str, Any]] = [] + aggregate_interpretation = { + "scoringColumns": 0, + "displayColumns": 0, + "storageOnlyColumns": 0, + "lowConfidenceColumns": 0, + "domainMismatchWarnings": 0, + } + aggregate_dedup = {"inserted": 0, "updated": 0} + interpreted_rows_total = 0 + rejected_rows_total = 0 + inferred_rows_total = 0 + fallback_inference_rows_total = 0 + per_file_results: List[Dict[str, Any]] = [] + + for upload in uploads: + filename = upload.filename or "" + ext = os.path.splitext(filename)[1].lower() + file_warnings: List[str] = [] + file_row_warnings: List[Dict[str, Any]] = [] + file_rejected_rows: List[Dict[str, Any]] = [] + file_students: List[Dict[str, Any]] = [] + file_unknown_columns: List[str] = [] + file_column_mapping: Dict[str, str] = {} + file_column_mapping_source: Dict[str, str] = {} + file_column_interpretations: List[Dict[str, Any]] = [] + file_interpretation_summary: Dict[str, Any] = { + "scoringColumns": 0, + "displayColumns": 0, + "storageOnlyColumns": 0, + "lowConfidenceColumns": 0, + "domainMismatchWarnings": 0, + } + file_dedup = {"inserted": 0, "updated": 0} + file_import_id: Optional[str] = None + file_interpreted_rows = 0 + file_rejected_rows_count = 0 + file_inferred_rows = 0 + file_fallback_inference_rows = 0 + + try: + if ext not in ALLOWED_UPLOAD_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Unsupported file format: {filename}. Use .csv, .xlsx, .xls, or .pdf", + ) + + if (upload.content_type or "").lower() not in ALLOWED_UPLOAD_MIME_TYPES: + raise HTTPException( + status_code=400, + detail=f"Unsupported content type: {upload.content_type}", + ) + + contents = await upload.read(UPLOAD_MAX_BYTES + 1) + if len(contents) > UPLOAD_MAX_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large. Max allowed size is {UPLOAD_MAX_BYTES // (1024 * 1024)} MB.", + ) + + df = None + + if ext == ".csv": + df = pd.read_csv(io.BytesIO(contents), on_bad_lines="skip") + elif ext in {".xlsx", ".xls"}: + df = pd.read_excel(io.BytesIO(contents)) + elif ext == ".pdf": + import pdfplumber + with pdfplumber.open(io.BytesIO(contents)) as pdf: + if len(pdf.pages) > UPLOAD_MAX_PDF_PAGES: + raise HTTPException( + status_code=413, + detail=f"PDF has too many pages. Max allowed pages: {UPLOAD_MAX_PDF_PAGES}", + ) + tables = [] + for page in pdf.pages: + page_tables = page.extract_tables() + if page_tables: + tables.extend(page_tables) + if tables and len(tables[0]) > 1: + df = pd.DataFrame(tables[0][1:], columns=tables[0][0]) + else: + raise HTTPException(status_code=400, detail="No tables found in PDF") + else: + raise HTTPException( + status_code=400, + detail=f"Unsupported file format: {filename}. Use .csv, .xlsx, or .pdf", + ) + + if df is None or df.empty: + raise HTTPException(status_code=400, detail="No data found in uploaded file") + + if df.shape[0] > UPLOAD_MAX_ROWS: + raise HTTPException( + status_code=413, + detail=f"Too many rows ({df.shape[0]}). Max allowed: {UPLOAD_MAX_ROWS}", + ) + + if df.shape[1] > UPLOAD_MAX_COLS: + raise HTTPException( + status_code=413, + detail=f"Too many columns ({df.shape[1]}). Max allowed: {UPLOAD_MAX_COLS}", + ) + + file_hash = hashlib.sha256(contents).hexdigest() + + # AI-powered column mapping + columns_text = ", ".join(df.columns.tolist()) + + prompt = f"""I have a spreadsheet with these columns: {columns_text} + +Map each column to one of these standard fields (respond as JSON only): +- name (student full name) +- lrn (learner reference number) +- email (email address) +- engagementScore (engagement percentage) +- avgQuizScore (average quiz/test score) +- attendance (attendance percentage) + +If a column doesn't match any field, skip it. Respond ONLY with a JSON object mapping original column names to field names. Example: {{\"Student Name\": \"name\", \"LRN\": \"lrn\"}}""" + + mapping_text = "" + try: + mapping_text = await call_hf_chat_async( + messages=[{"role": "user", "content": prompt}], + max_tokens=300, + temperature=0.1, + ) + json_start = mapping_text.find("{") + json_end = mapping_text.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + ai_mapping = _sanitize_column_mapping(json.loads(mapping_text[json_start:json_end])) + file_column_mapping = dict(ai_mapping) + file_column_mapping_source = {col: "ai" for col in ai_mapping.keys()} + else: + file_column_mapping = {} + file_warnings.append("AI mapper returned no JSON; fallback mapper was used.") + except Exception: + file_column_mapping = {} + file_warnings.append("AI mapper failed; fallback mapper was used.") + + fallback_mapping = _fallback_column_mapping(df.columns.tolist()) + for col, field in fallback_mapping.items(): + if col not in file_column_mapping: + file_column_mapping[col] = field + file_column_mapping_source[col] = "fallback" + + file_column_mapping = _sanitize_column_mapping(file_column_mapping) + + missing_core_fields, missing_identity = _validate_class_record_mapping(file_column_mapping) + if missing_core_fields: + file_warnings.append( + "Missing preferred educational columns after mapping: " + + ", ".join(missing_core_fields) + + ". Import will continue with metric defaults where needed." + ) + if missing_identity: + raise HTTPException( + status_code=400, + detail=( + "Import requires at least one student identity column (lrn or email) to avoid record collisions." + ), + ) + + non_education_matches = _detect_non_education_signals(df.columns.tolist()) + if non_education_matches: + file_warnings.append( + "Potential non-education columns detected; these columns will be stored but excluded from scoring: " + + ", ".join(sorted(non_education_matches.keys())[:12]) + ) + + file_column_interpretations = _build_column_interpretations( + columns=df.columns.tolist(), + mapping=file_column_mapping, + mapping_source=file_column_mapping_source, + non_education_matches=non_education_matches, + ) + file_interpretation_summary = { + "scoringColumns": sum(1 for item in file_column_interpretations if item.get("usagePolicy") == "scoring"), + "displayColumns": sum(1 for item in file_column_interpretations if item.get("usagePolicy") == "display"), + "storageOnlyColumns": sum(1 for item in file_column_interpretations if item.get("usagePolicy") == "storage_only"), + "lowConfidenceColumns": sum(1 for item in file_column_interpretations if item.get("confidenceBand") == "low"), + "domainMismatchWarnings": len(non_education_matches), + } + + normalized_result = _normalize_class_records( + df, + file_name=filename, + file_hash=file_hash, + column_mapping=file_column_mapping, + ) + file_students = normalized_result["rows"] + file_row_warnings = normalized_result["rowWarnings"] + file_rejected_rows = normalized_result.get("rejectedRows") or [] + file_unknown_columns = normalized_result["unknownColumns"] + file_interpreted_rows = int(normalized_result.get("interpretedRows") or len(file_students)) + file_rejected_rows_count = int(normalized_result.get("rejectedRowsCount") or len(file_rejected_rows)) + file_inferred_rows = int(normalized_result.get("inferredRows") or 0) + file_fallback_inference_rows = int(normalized_result.get("fallbackInferenceRows") or 0) + + persistence_result = _persist_class_record_import_artifact( + request, + file_hash=file_hash, + file_name=filename, + file_type=ext.replace(".", ""), + column_mapping=file_column_mapping, + normalized_rows=file_students, + row_warnings=file_row_warnings, + unknown_columns=file_unknown_columns, + parse_warnings=file_warnings, + dataset_intent=normalized_dataset_intent, + column_interpretations=file_column_interpretations, + interpretation_summary=file_interpretation_summary, + class_section_id=classSectionId, + class_name=className, + ) + if persistence_result.get("warning"): + file_warnings.append(str(persistence_result["warning"])) + + file_status = "success" + if file_row_warnings or file_warnings: + file_status = "partial_success" + if not file_students: + file_status = "failed" + + file_import_id = persistence_result.get("importId") + file_dedup = persistence_result.get("dedup") or {"inserted": 0, "updated": 0} + except HTTPException as file_exc: + file_status = "failed" + file_warnings.append(str(file_exc.detail)) + except Exception as file_exc: + logger.error(f"Class records processing failed for {filename}: {file_exc}") + file_status = "failed" + file_warnings.append(f"Unexpected processing error: {str(file_exc)}") + + per_file_result = { + "fileName": filename, + "fileType": ext.replace(".", ""), + "status": file_status, + "students": file_students, + "totalRows": len(file_students), + "columnMapping": file_column_mapping, + "unknownColumns": file_unknown_columns, + "warnings": file_warnings, + "rowWarnings": file_row_warnings, + "rejectedRows": file_rejected_rows, + "datasetIntent": normalized_dataset_intent, + "columnInterpretations": file_column_interpretations, + "interpretationSummary": file_interpretation_summary, + "classSectionId": resolved_upload_class_section_id, + "className": resolved_upload_class_name, + "classMetadata": upload_class_metadata, + "importId": file_import_id, + "persisted": bool(file_import_id), + "dedup": file_dedup, + "interpretedRows": file_interpreted_rows, + "rejectedRowsCount": file_rejected_rows_count, + "inferredRows": file_inferred_rows, + "fallbackInferenceRows": file_fallback_inference_rows, + } + per_file_results.append(per_file_result) + + all_students.extend(file_students) + all_unknown_columns.update(file_unknown_columns) + all_column_interpretations.extend(file_column_interpretations) + all_rejected_rows.extend( + [ + { + "row": item.get("row"), + "reason": f"{filename}: {item.get('reason', '')}", + } + for item in file_rejected_rows + ] + ) + aggregate_interpretation["scoringColumns"] += int(file_interpretation_summary.get("scoringColumns", 0) or 0) + aggregate_interpretation["displayColumns"] += int(file_interpretation_summary.get("displayColumns", 0) or 0) + aggregate_interpretation["storageOnlyColumns"] += int(file_interpretation_summary.get("storageOnlyColumns", 0) or 0) + aggregate_interpretation["lowConfidenceColumns"] += int(file_interpretation_summary.get("lowConfidenceColumns", 0) or 0) + aggregate_interpretation["domainMismatchWarnings"] += int(file_interpretation_summary.get("domainMismatchWarnings", 0) or 0) + aggregate_dedup["inserted"] += int(file_dedup.get("inserted", 0) or 0) + aggregate_dedup["updated"] += int(file_dedup.get("updated", 0) or 0) + interpreted_rows_total += file_interpreted_rows + rejected_rows_total += file_rejected_rows_count + inferred_rows_total += file_inferred_rows + fallback_inference_rows_total += file_fallback_inference_rows + all_warnings.extend([f"{filename}: {warning}" for warning in file_warnings]) + all_row_warnings.extend( + [ + { + "row": warning.get("row"), + "warning": f"{filename}: {warning.get('warning', '')}", + } + for warning in file_row_warnings + ] + ) + + first_file_with_mapping = next( + (f for f in per_file_results if f.get("columnMapping")), + None, + ) + first_file_with_import = next( + (f for f in per_file_results if f.get("importId")), + None, + ) + successful_files = sum(1 for f in per_file_results if f.get("status") in {"success", "partial_success"}) + failed_files = len(per_file_results) - successful_files + overall_success = successful_files > 0 + risk_refresh = _queue_post_import_risk_refresh( + request, + students=all_students, + column_mapping=(first_file_with_mapping or {}).get("columnMapping") or {}, + class_section_id=(classSectionId or "").strip() or None, + ) + dashboard_sync = _sync_imported_students_to_teacher_dashboard( + request, + normalized_rows=all_students, + class_section_id=(classSectionId or "").strip() or None, + class_name=(className or "").strip() or None, + ) + if dashboard_sync.get("warning"): + all_warnings.append(str(dashboard_sync["warning"])) + + persisted_rows = int(aggregate_dedup.get("inserted", 0) or 0) + int(aggregate_dedup.get("updated", 0) or 0) + rejected_reason_counts = dict(Counter(item.get("reason", "unknown") for item in all_rejected_rows)) + inferred_coverage_pct = round((float(inferred_rows_total) / float(max(interpreted_rows_total, 1))) * 100.0, 1) + partial_success_files = sum(1 for f in per_file_results if f.get("status") == "partial_success") + + response_payload = { + "success": overall_success, + "classMetadata": upload_class_metadata, + "students": all_students, + "columnMapping": (first_file_with_mapping or {}).get("columnMapping") or {}, + "datasetIntent": normalized_dataset_intent, + "totalRows": len(all_students), + "interpretedRows": interpreted_rows_total, + "rejectedRows": rejected_rows_total, + "rejectedRowDetails": all_rejected_rows, + "rejectedReasons": rejected_reason_counts, + "persistedRows": persisted_rows, + "inferredStateCoverage": { + "inferredRows": inferred_rows_total, + "interpretedRows": interpreted_rows_total, + "fallbackRows": fallback_inference_rows_total, + "coveragePct": inferred_coverage_pct, + }, + "unknownColumns": sorted(all_unknown_columns), + "columnInterpretations": all_column_interpretations, + "interpretationSummary": aggregate_interpretation, + "warnings": all_warnings, + "rowWarnings": all_row_warnings, + "importId": (first_file_with_import or {}).get("importId"), + "persisted": bool(first_file_with_import and first_file_with_import.get("importId")), + "dedup": aggregate_dedup, + "files": per_file_results, + "summary": { + "totalFiles": len(per_file_results), + "successfulFiles": successful_files, + "partialSuccessFiles": partial_success_files, + "failedFiles": failed_files, + }, + "riskRefresh": risk_refresh, + "dashboardSync": dashboard_sync, + } + + _write_access_audit_log( + request, + action="class_records_upload", + status="success" if overall_success else "partial_failure", + class_section_id=(classSectionId or "").strip() or None, + metadata={ + "totalFiles": len(per_file_results), + "successfulFiles": successful_files, + "failedFiles": failed_files, + "persisted": bool(first_file_with_import and first_file_with_import.get("importId")), + "students": len(all_students), + "interpretedRows": interpreted_rows_total, + "rejectedRows": rejected_rows_total, + "persistedRows": persisted_rows, + "inferredRows": inferred_rows_total, + "fallbackInferenceRows": fallback_inference_rows_total, + "datasetIntent": normalized_dataset_intent, + "storageOnlyColumns": int(aggregate_interpretation.get("storageOnlyColumns", 0) or 0), + "domainMismatchWarnings": int(aggregate_interpretation.get("domainMismatchWarnings", 0) or 0), + "dashboardSync": bool(dashboard_sync.get("synced")), + "dashboardCreatedStudents": int(dashboard_sync.get("createdStudents") or 0), + "dashboardUpdatedStudents": int(dashboard_sync.get("updatedStudents") or 0), + }, + ) + + return response_payload + + except HTTPException: + raise + except Exception as e: + logger.error(f"Upload error: {e}") + raise HTTPException(status_code=500, detail=f"File upload error: {str(e)}") + + +def _split_material_sections(text: str, max_sections: int = 20) -> List[Dict[str, str]]: + blocks = [block.strip() for block in re.split(r"\n\s*\n", text) if block.strip()] + sections: List[Dict[str, str]] = [] + for idx, block in enumerate(blocks[:max_sections]): + lines = [line.strip() for line in block.splitlines() if line.strip()] + if not lines: + continue + title_candidate = lines[0][:80] + if len(lines) > 1 and len(title_candidate.split()) <= 12: + title = title_candidate + body = " ".join(lines[1:]) + else: + title = f"Section {idx + 1}" + body = " ".join(lines) + preview = re.sub(r"\s+", " ", body).strip()[:220] + sections.append( + { + "sectionId": f"section_{idx + 1}", + "title": title, + "preview": preview, + } + ) + return sections + + +def _fallback_topic_extraction(text: str, max_topics: int = 8) -> List[Dict[str, Any]]: + stop_words = { + "about", "after", "again", "algebra", "also", "because", "before", "being", "between", + "could", "course", "each", "from", "have", "into", "lesson", "math", "module", "other", + "should", "their", "there", "these", "they", "this", "those", "topic", "topics", "using", + "will", "with", "your", + } + words = re.findall(r"\b[a-zA-Z][a-zA-Z\-]{3,}\b", text.lower()) + filtered = [w for w in words if w not in stop_words] + if not filtered: + return [] + + counts = Counter(filtered) + topics: List[Dict[str, Any]] = [] + for idx, (word, _) in enumerate(counts.most_common(max_topics)): + title = word.replace("-", " ").title() + topics.append( + { + "topicId": f"topic_{idx + 1}", + "title": title, + "description": f"Coverage area inferred from uploaded material around '{title}'.", + "prerequisiteTopics": [], + } + ) + return topics + + +def _compute_material_source_legitimacy( + *, + file_type: str, + file_hash: str, + extracted_text: str, + topics: List[Dict[str, Any]], + warnings: List[str], +) -> Dict[str, Any]: + issues: List[str] = [] + evidence_checked = [ + "file_type", + "file_hash", + "text_length", + "topic_count", + "extraction_warnings", + ] + + score = 1.0 + normalized_type = (file_type or "").strip().lower() + if normalized_type not in {"pdf", "docx", "txt"}: + issues.append(f"Unsupported source type '{normalized_type}'.") + score -= 0.7 + + if not (file_hash or "").strip(): + issues.append("Source file hash is missing.") + score -= 0.6 + + text_length = len(extracted_text or "") + if text_length < LESSON_SOURCE_MIN_TEXT_LENGTH: + issues.append( + f"Extracted text is too short ({text_length} chars); minimum is {LESSON_SOURCE_MIN_TEXT_LENGTH}." + ) + score -= 0.5 + + topic_count = len(topics or []) + if topic_count < LESSON_SOURCE_MIN_TOPICS: + issues.append( + f"Insufficient extracted topics ({topic_count}); minimum is {LESSON_SOURCE_MIN_TOPICS}." + ) + score -= 0.5 + + warning_hits = [w for w in (warnings or []) if "fallback" in str(w).lower() or "failed" in str(w).lower()] + if warning_hits: + issues.append("Source extraction had fallback/failure warnings that require review.") + score -= min(0.4, 0.15 * len(warning_hits)) + + score = max(0.0, min(1.0, score)) + if score >= 0.75: + status = "verified" + elif score >= 0.45: + status = "review_required" + else: + status = "rejected" + + return { + "status": status, + "score": round(score, 3), + "issues": issues, + "evidenceChecked": evidence_checked, + "checkedAtIso": datetime.now(timezone.utc).isoformat(), + } + + +def _evaluate_lesson_source_legitimacy( + imported_topics_payload: Dict[str, Any], + *, + allow_review_sources: bool, +) -> Dict[str, Any]: + materials = imported_topics_payload.get("materials") or [] + verified_materials = 0 + review_materials = 0 + rejected_materials = 0 + issues: List[str] = [] + evidence_checked = ["artifact_legitimacy", "material_metadata", "topic_provenance"] + scores: List[float] = [] + + for material in materials: + legitimacy = material.get("sourceLegitimacy") or {} + status = str(legitimacy.get("status") or "review_required").strip().lower() + score = float(legitimacy.get("score") or 0.0) + scores.append(max(0.0, min(1.0, score))) + + if status == "verified": + verified_materials += 1 + elif status == "review_required": + review_materials += 1 + else: + rejected_materials += 1 + issues.extend([str(x) for x in (legitimacy.get("issues") or []) if str(x).strip()]) + + average_score = round(sum(scores) / len(scores), 3) if scores else 0.0 + + if rejected_materials > 0: + status = "rejected" + elif review_materials > 0: + status = "review_required" + else: + status = "verified" if verified_materials > 0 else "review_required" + + if status == "review_required" and not allow_review_sources: + issues.append("Source legitimacy requires review. Enable allowReviewSources to proceed.") + if status == "rejected": + issues.append("One or more imported sources failed legitimacy checks.") + + return { + "status": status, + "score": average_score, + "verifiedMaterials": verified_materials, + "reviewMaterials": review_materials, + "rejectedMaterials": rejected_materials, + "evidenceChecked": evidence_checked, + "issues": sorted(list({issue for issue in issues if issue.strip()})), + } + + +def _persist_course_material_artifact( + request: Request, + *, + file_hash: str, + file_name: str, + file_type: str, + extracted_text: str, + sections: List[Dict[str, Any]], + topics: List[Dict[str, Any]], + warnings: List[str], + class_section_id: Optional[str] = None, + class_name: Optional[str] = None, +) -> Dict[str, Any]: + if not (_firebase_ready and firebase_firestore): + return { + "persisted": False, + "materialId": None, + "warning": "Firestore unavailable; material was not persisted.", + "sourceLegitimacy": { + "status": "review_required", + "score": 0.0, + "issues": ["Firestore unavailable; source legitimacy metadata not persisted."], + "evidenceChecked": [], + }, + } + + user = get_current_user(request) + normalized_class_section_id = (class_section_id or "").strip() or None + normalized_class_name = (class_name or "").strip() or None + dedup_seed = f"{user.uid}|{normalized_class_section_id or 'global'}|{file_hash}" + material_id = hashlib.sha1(dedup_seed.encode("utf-8")).hexdigest()[:28] + source_legitimacy = _compute_material_source_legitimacy( + file_type=file_type, + file_hash=file_hash, + extracted_text=extracted_text, + topics=topics, + warnings=warnings, + ) + + doc_payload: Dict[str, Any] = { + "materialId": material_id, + "teacherId": user.uid, + "teacherEmail": user.email, + "fileName": file_name, + "fileType": file_type, + "fileHash": file_hash, + "extractedTextLength": len(extracted_text), + "extractedTextPreview": extracted_text[:3000], + "sections": sections, + "topics": topics, + "warnings": warnings, + "sourceLegitimacy": source_legitimacy, + "source": "api_upload_course_materials", + "retentionDays": IMPORT_RETENTION_DAYS, + "expiresAtEpoch": _artifact_expiry_epoch(), + "updatedAt": FIRESTORE_SERVER_TIMESTAMP, + } + if normalized_class_section_id: + doc_payload["classSectionId"] = normalized_class_section_id + if normalized_class_name: + doc_payload["className"] = normalized_class_name + + materials_ref = firebase_firestore.client().collection("courseMaterials").document(material_id) + existing = cast(Any, materials_ref.get()) + if not _snapshot_exists(existing): + doc_payload["createdAt"] = FIRESTORE_SERVER_TIMESTAMP + + materials_ref.set(doc_payload, merge=True) + return { + "persisted": True, + "materialId": material_id, + "warning": None, + "sourceLegitimacy": source_legitimacy, + } + + +def _load_persisted_course_material_topics( + request: Request, + *, + class_section_id: Optional[str] = None, + material_id: Optional[str] = None, + limit_materials: int = 20, +) -> Dict[str, Any]: + if not (_firebase_ready and firebase_firestore): + return { + "topics": [], + "materials": [], + "warnings": ["Firestore unavailable; imported topic lookup skipped."], + } + + user = get_current_user(request) + normalized_class_section_id = (class_section_id or "").strip() or None + normalized_material_id = (material_id or "").strip() or None + + try: + query = ( + firebase_firestore.client() + .collection("courseMaterials") + .where("teacherId", "==", user.uid) + ) + if normalized_class_section_id: + query = query.where("classSectionId", "==", normalized_class_section_id) + if normalized_material_id: + query = query.where("materialId", "==", normalized_material_id) + except Exception as e: + warning = ( + "Firestore ADC is not configured; imported topic lookup skipped." + if _is_adc_missing_error(e) + else "Firestore lookup unavailable; imported topic lookup skipped." + ) + _warn_firestore_topics_lookup_once(f"{warning} Error: {e}") + return { + "topics": [], + "materials": [], + "warnings": [warning], + } + + warnings: List[str] = [] + try: + docs = ( + query + .order_by("updatedAt", direction=FIRESTORE_QUERY_DESCENDING) + .limit(limit_materials) + .stream() + ) + except Exception as stream_error: + # Fallback for index limitations on combined where+order queries. + warnings.append("Topic lookup used fallback query path without ordering.") + try: + docs = query.limit(limit_materials).stream() + except Exception as fallback_error: + chosen_error = fallback_error or stream_error + warning = ( + "Firestore ADC is not configured; imported topic lookup skipped." + if _is_adc_missing_error(cast(Exception, chosen_error)) + else "Firestore lookup unavailable; imported topic lookup skipped." + ) + _warn_firestore_topics_lookup_once(f"{warning} Error: {chosen_error}") + return { + "topics": [], + "materials": [], + "warnings": [warning], + } + + materials: List[Dict[str, Any]] = [] + deduped_topics: Dict[str, Dict[str, Any]] = {} + expired_count = 0 + for doc in docs: + data = doc.to_dict() or {} + if _is_artifact_expired(data): + expired_count += 1 + continue + + doc_material_id = str(data.get("materialId") or doc.id) + doc_file_name = str(data.get("fileName") or "") + doc_class_section_id = data.get("classSectionId") + doc_class_name = data.get("className") + topics = data.get("topics") or [] + + materials.append( + { + "materialId": doc_material_id, + "fileName": doc_file_name, + "fileType": str(data.get("fileType") or ""), + "fileHash": str(data.get("fileHash") or ""), + "extractedTextLength": int(data.get("extractedTextLength") or 0), + "classSectionId": doc_class_section_id, + "className": doc_class_name, + "topicsCount": len(topics), + "sourceLegitimacy": data.get("sourceLegitimacy") or { + "status": "review_required", + "score": 0.0, + "issues": ["Missing source legitimacy metadata."], + "evidenceChecked": [], + }, + } + ) + + for idx, topic in enumerate(topics): + title = str(topic.get("title") or "").strip() + if not title: + continue + + topic_id = str(topic.get("topicId") or f"topic_{idx + 1}") + description = str(topic.get("description") or "").strip() + prerequisite_topics = [ + str(item).strip() + for item in (topic.get("prerequisiteTopics") or []) + if str(item).strip() + ] + source_files = [ + str(item).strip() + for item in (topic.get("sourceFiles") or [doc_file_name]) + if str(item).strip() + ] + + dedup_key = re.sub(r"\s+", " ", title.lower()).strip() + if dedup_key not in deduped_topics: + deduped_topics[dedup_key] = { + "topicId": topic_id, + "title": title, + "description": description, + "prerequisiteTopics": prerequisite_topics, + "sourceFiles": source_files, + "materialId": doc_material_id, + "sourceFile": source_files[0] if source_files else doc_file_name, + "sectionId": None, + "classSectionId": doc_class_section_id, + "className": doc_class_name, + } + + if expired_count > 0: + warnings.append(f"{expired_count} expired course-material artifact(s) were excluded by retention policy.") + + return { + "topics": list(deduped_topics.values()), + "materials": materials, + "warnings": warnings, + } + + +@app.post("/api/upload/course-materials") +async def upload_course_materials( + request: Request, + file: Optional[UploadFile] = File(default=None), + files: Optional[List[UploadFile]] = File(default=None), + classSectionId: Optional[str] = Form(default=None), + className: Optional[str] = Form(default=None), +): + """Upload and extract curriculum topics from course materials (PDF, DOCX, TXT).""" + try: + enforce_rate_limit(request, "upload_course_materials", UPLOAD_RATE_LIMIT_PER_MIN, 60) + + uploads = _resolve_uploaded_files(file=file, files=files) + normalized_class_section_id = (classSectionId or "").strip() or None + normalized_class_name = (className or "").strip() or None + + all_sections: List[Dict[str, Any]] = [] + all_topics: List[Dict[str, Any]] = [] + all_warnings: List[str] = [] + per_file_results: List[Dict[str, Any]] = [] + + for upload in uploads: + filename = upload.filename or "" + ext = os.path.splitext(filename)[1].lower() + file_warnings: List[str] = [] + file_sections: List[Dict[str, Any]] = [] + file_topics: List[Dict[str, Any]] = [] + file_hash: Optional[str] = None + material_id: Optional[str] = None + persisted = False + source_legitimacy: Dict[str, Any] = { + "status": "review_required", + "score": 0.0, + "issues": [], + "evidenceChecked": [], + } + extracted_text_length = 0 + + try: + if ext not in ALLOWED_COURSE_MATERIAL_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Unsupported file format: {filename}. Use .pdf, .docx, or .txt", + ) + + if (upload.content_type or "").lower() not in ALLOWED_COURSE_MATERIAL_MIME_TYPES: + raise HTTPException( + status_code=400, + detail=f"Unsupported content type: {upload.content_type}", + ) + + contents = await upload.read(UPLOAD_MAX_BYTES + 1) + if len(contents) > UPLOAD_MAX_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large. Max allowed size is {UPLOAD_MAX_BYTES // (1024 * 1024)} MB.", + ) + + extracted_text = "" + file_hash = hashlib.sha256(contents).hexdigest() + + if ext == ".pdf": + import pdfplumber + + with pdfplumber.open(io.BytesIO(contents)) as pdf: + if len(pdf.pages) > UPLOAD_MAX_PDF_PAGES: + raise HTTPException( + status_code=413, + detail=f"PDF has too many pages. Max allowed pages: {UPLOAD_MAX_PDF_PAGES}", + ) + + page_texts: List[str] = [] + for page in pdf.pages: + text = page.extract_text() or "" + if text.strip(): + page_texts.append(text) + extracted_text = "\n\n".join(page_texts) + elif ext == ".docx": + import importlib + + docx_module = importlib.import_module("docx") + doc = docx_module.Document(io.BytesIO(contents)) + paragraphs = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()] + extracted_text = "\n\n".join(paragraphs) + elif ext == ".txt": + extracted_text = contents.decode("utf-8", errors="ignore") + + extracted_text = re.sub(r"\s+", " ", extracted_text).strip() + if not extracted_text: + raise HTTPException( + status_code=400, + detail="No readable text found in uploaded course material", + ) + + extracted_text_length = len(extracted_text) + sections = _split_material_sections(extracted_text) + prompt_excerpt = extracted_text[:7000] + topic_prompt = f"""Extract classroom math curriculum topics from this course material text. + +Return JSON only in this exact shape: +{{ + "topics": [ + {{ + "title": "...", + "description": "...", + "prerequisiteTopics": ["..."] + }} + ] +}} + +Rules: +- Keep topics concise and teacher-friendly. +- Include at most 10 topics. +- Use empty prerequisiteTopics when unknown. + +TEXT: +{prompt_excerpt} +""" + + extracted_topics: List[Dict[str, Any]] = [] + try: + topic_text = await call_hf_chat_async( + messages=[{"role": "user", "content": topic_prompt}], + max_tokens=700, + temperature=0.1, + ) + json_start = topic_text.find("{") + json_end = topic_text.rfind("}") + 1 + topic_payload: Dict[str, Any] = {} + if json_start >= 0 and json_end > json_start: + topic_payload = json.loads(topic_text[json_start:json_end]) + + for idx, topic in enumerate((topic_payload.get("topics") or [])[:10]): + title = str(topic.get("title", "")).strip() + if not title: + continue + desc = str(topic.get("description", "")).strip() or f"Curriculum content related to {title}." + prereq_raw = topic.get("prerequisiteTopics") or [] + prereq = [str(p).strip() for p in prereq_raw if str(p).strip()] + extracted_topics.append( + { + "topicId": f"topic_{idx + 1}", + "title": title, + "description": desc, + "prerequisiteTopics": prereq, + } + ) + except Exception as topic_err: + logger.warning(f"Topic extraction via AI failed: {topic_err}") + + if not extracted_topics: + file_warnings.append("AI topic extraction fallback was used.") + extracted_topics = _fallback_topic_extraction(extracted_text) + + file_topics = [ + { + **topic, + "sourceFiles": [filename], + } + for topic in extracted_topics + ] + + file_sections = [ + { + **section, + "sourceFile": filename, + } + for section in sections + ] + + persistence_result = _persist_course_material_artifact( + request, + file_hash=file_hash, + file_name=filename, + file_type=ext.replace(".", ""), + extracted_text=extracted_text, + sections=file_sections, + topics=file_topics, + warnings=file_warnings, + class_section_id=classSectionId, + class_name=className, + ) + if persistence_result.get("warning"): + file_warnings.append(str(persistence_result["warning"])) + material_id = persistence_result.get("materialId") + persisted = bool(persistence_result.get("persisted")) + source_legitimacy = cast(Dict[str, Any], persistence_result.get("sourceLegitimacy") or source_legitimacy) + + file_status = "success" if not file_warnings else "partial_success" + except HTTPException as file_exc: + file_status = "failed" + file_warnings.append(str(file_exc.detail)) + except Exception as file_exc: + logger.error(f"Course material processing failed for {filename}: {file_exc}") + file_status = "failed" + file_warnings.append(f"Unexpected processing error: {str(file_exc)}") + + file_result = { + "fileName": filename, + "fileType": ext.replace(".", ""), + "status": file_status, + "fileHash": file_hash, + "materialId": material_id, + "persisted": persisted, + "sourceLegitimacy": source_legitimacy, + "classSectionId": normalized_class_section_id, + "className": normalized_class_name, + "extractedTextLength": extracted_text_length, + "sections": file_sections, + "topics": file_topics, + "warnings": file_warnings, + } + per_file_results.append(file_result) + all_sections.extend(file_sections) + all_topics.extend(file_topics) + all_warnings.extend([f"{filename}: {warning}" for warning in file_warnings]) + + first_successful = next( + (f for f in per_file_results if f.get("status") in {"success", "partial_success"}), + None, + ) + successful_files = sum(1 for f in per_file_results if f.get("status") in {"success", "partial_success"}) + failed_files = len(per_file_results) - successful_files + total_extracted_text_length = sum(int(f.get("extractedTextLength", 0) or 0) for f in per_file_results) + + response_payload = { + "success": successful_files > 0, + "fileName": (first_successful or {}).get("fileName", ""), + "fileType": (first_successful or {}).get("fileType", ""), + "fileHash": (first_successful or {}).get("fileHash"), + "materialId": (first_successful or {}).get("materialId"), + "persisted": bool(first_successful and first_successful.get("persisted")), + "classSectionId": normalized_class_section_id, + "className": normalized_class_name, + "extractedTextLength": total_extracted_text_length, + "sections": all_sections, + "topics": all_topics, + "warnings": all_warnings, + "files": per_file_results, + "summary": { + "totalFiles": len(per_file_results), + "successfulFiles": successful_files, + "failedFiles": failed_files, + }, + } + + _write_access_audit_log( + request, + action="course_material_upload", + status="success" if successful_files > 0 else "failure", + class_section_id=normalized_class_section_id, + metadata={ + "totalFiles": len(per_file_results), + "successfulFiles": successful_files, + "failedFiles": failed_files, + "totalTopics": len(all_topics), + "persisted": bool(first_successful and first_successful.get("persisted")), + }, + ) + + return response_payload + + except HTTPException: + raise + except Exception as e: + logger.error(f"Course material upload error: {e}") + raise HTTPException(status_code=500, detail=f"Course material upload error: {str(e)}") + + +@app.get("/api/upload/course-materials/recent") +async def get_recent_course_materials( + request: Request, + classSectionId: Optional[str] = Query(default=None), + limit: int = Query(default=10, ge=1, le=50), +): + """List recent uploaded course materials for the authenticated teacher/admin.""" + try: + user = get_current_user(request) + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + normalized_class_section_id = (classSectionId or "").strip() or None + + query = ( + firebase_firestore.client() + .collection("courseMaterials") + .where("teacherId", "==", user.uid) + ) + if normalized_class_section_id: + query = query.where("classSectionId", "==", normalized_class_section_id) + + warnings: List[str] = [] + try: + docs = ( + query + .order_by("updatedAt", direction=FIRESTORE_QUERY_DESCENDING) + .limit(limit) + .stream() + ) + except Exception: + warnings.append("Course-material lookup used fallback query path without ordering.") + docs = query.limit(limit).stream() + + materials: List[Dict[str, Any]] = [] + expired_count = 0 + for doc in docs: + data = doc.to_dict() or {} + if _is_artifact_expired(data): + expired_count += 1 + continue + + topics = data.get("topics") or [] + created_at = data.get("createdAt") + updated_at = data.get("updatedAt") + created_at_iso = created_at.isoformat() if created_at is not None and hasattr(created_at, "isoformat") else None + updated_at_iso = updated_at.isoformat() if updated_at is not None and hasattr(updated_at, "isoformat") else None + + materials.append( + { + "materialId": data.get("materialId") or doc.id, + "fileName": data.get("fileName", ""), + "fileType": data.get("fileType", ""), + "classSectionId": data.get("classSectionId"), + "className": data.get("className"), + "topicsCount": len(topics), + "topicTitles": [str(t.get("title", "")).strip() for t in topics[:5] if str(t.get("title", "")).strip()], + "extractedTextLength": int(data.get("extractedTextLength", 0) or 0), + "retentionDays": int(data.get("retentionDays", IMPORT_RETENTION_DAYS) or IMPORT_RETENTION_DAYS), + "expiresAtEpoch": data.get("expiresAtEpoch"), + "createdAt": created_at_iso, + "updatedAt": updated_at_iso, + } + ) + + if expired_count > 0: + warnings.append(f"{expired_count} expired course-material artifact(s) were excluded by retention policy.") + + response_payload = { + "success": True, + "classSectionId": normalized_class_section_id, + "materials": materials, + "warnings": warnings, + } + + _write_access_audit_log( + request, + action="course_material_recent_read", + status="success", + class_section_id=normalized_class_section_id, + metadata={ + "requestedClassSectionId": normalized_class_section_id, + "requestedLimit": limit, + "returnedMaterials": len(materials), + "expiredExcluded": expired_count, + "warningsCount": len(warnings), + }, + ) + + return response_payload + + except HTTPException: + raise + except Exception as e: + logger.error(f"Recent course materials lookup error: {e}") + raise HTTPException(status_code=500, detail=f"Recent materials lookup error: {str(e)}") + + +@app.get("/api/course-materials/topics") +async def get_course_material_topics( + request: Request, + classSectionId: Optional[str] = Query(default=None), + materialId: Optional[str] = Query(default=None), + limit: int = Query(default=20, ge=1, le=50), +): + """Return persisted course-material topic map for the authenticated teacher/admin.""" + try: + normalized_class_section_id = (classSectionId or "").strip() or None + payload = _load_persisted_course_material_topics( + request, + class_section_id=normalized_class_section_id, + material_id=materialId, + limit_materials=limit, + ) + response_payload = { + "success": True, + "classSectionId": normalized_class_section_id, + "materialId": (materialId or "").strip() or None, + "topics": payload.get("topics", []), + "materials": payload.get("materials", []), + "warnings": payload.get("warnings", []), + } + + _write_access_audit_log( + request, + action="course_material_topics_read", + status="success", + class_section_id=normalized_class_section_id, + metadata={ + "limit": limit, + "materialId": (materialId or "").strip() or None, + "topicsReturned": len(payload.get("topics", [])), + "materialsReturned": len(payload.get("materials", [])), + "warningsCount": len(payload.get("warnings", [])), + }, + ) + + return response_payload + except HTTPException: + raise + except Exception as e: + logger.error(f"Course material topics lookup error: {e}") + raise HTTPException(status_code=500, detail=f"Course materials topics lookup error: {str(e)}") + + +# ─── Quiz Maker Models ──────────────────────────────────────── + +VALID_QUESTION_TYPES = [ + "identification", + "enumeration", + "multiple_choice", + "word_problem", + "equation_based", +] + +VALID_BLOOM_LEVELS = ["remember", "understand", "apply", "analyze"] + +VALID_DIFFICULTY_LEVELS = ["easy", "medium", "hard"] + +# ── Quiz generation hard limits ──────────────────────────────── +# Moderate classroom profile: supports longer quizzes while keeping +# generation latency and payload size manageable across providers. +MAX_QUESTIONS_LIMIT = 30 +MAX_TOPICS_LIMIT = 12 + + +class QuizGenerationRequest(BaseModel): + topics: List[str] = Field(..., min_length=1, description="Specific math topics to cover") + gradeLevel: str = Field(..., description="Student grade level (e.g., 'Grade 7', 'Grade 10', 'College')") + numQuestions: int = Field(default=10, ge=1, le=MAX_QUESTIONS_LIMIT, description="Number of questions to generate (max 30)") + questionTypes: List[str] = Field( + default=["multiple_choice", "identification", "word_problem"], + description="Types of questions to include", + ) + includeGraphs: bool = Field(default=False, description="Include graph-based identification questions") + difficultyDistribution: Dict[str, int] = Field( + default={"easy": 30, "medium": 50, "hard": 20}, + description="Percentage distribution per difficulty level", + ) + bloomLevels: List[str] = Field( + default=["remember", "understand", "apply", "analyze"], + description="Bloom's Taxonomy cognitive levels", + ) + excludeTopics: List[str] = Field( + default_factory=list, + description="Topics the class is already competent in — these will be excluded", + ) + classSectionId: Optional[str] = Field(default=None, description="Optional class section context for imported topics") + className: Optional[str] = Field(default=None, description="Optional class name context for metadata") + materialId: Optional[str] = Field(default=None, description="Optional specific course-material artifact ID") + preferImportedTopics: bool = Field( + default=True, + description="When true, prioritise persisted imported topics for generation when available", + ) + + @field_validator("questionTypes") + @classmethod + def validate_question_types(_cls, values: List[str]) -> List[str]: + for value in values: + if value not in VALID_QUESTION_TYPES: + raise ValueError(f"Invalid question type '{value}'. Must be one of: {VALID_QUESTION_TYPES}") + return values + + @field_validator("bloomLevels") + @classmethod + def validate_bloom_levels(_cls, values: List[str]) -> List[str]: + for value in values: + if value not in VALID_BLOOM_LEVELS: + raise ValueError(f"Invalid Bloom level '{value}'. Must be one of: {VALID_BLOOM_LEVELS}") + return values + + @field_validator("difficultyDistribution") + @classmethod + def validate_difficulty_distribution(_cls, v: Dict[str, int]) -> Dict[str, int]: + for key in v: + if key not in VALID_DIFFICULTY_LEVELS: + raise ValueError(f"Invalid difficulty key '{key}'. Must be one of: {VALID_DIFFICULTY_LEVELS}") + total = sum(v.values()) + if total != 100: + raise ValueError(f"Difficulty distribution percentages must sum to 100, got {total}") + return v + + +class QuizQuestion(BaseModel): + questionType: str + question: str + correctAnswer: str + options: Optional[List[str]] = None + bloomLevel: str + difficulty: str + topic: str + points: int + explanation: str + provenance: Optional[Dict[str, Optional[str]]] = None + + +class QuizResponse(BaseModel): + questions: List[QuizQuestion] + totalPoints: int + metadata: Dict[str, Any] + + +class StudentCompetencyRequest(BaseModel): + studentId: str = Field(..., description="Firebase user ID of the student") + quizHistory: Optional[List[Dict[str, Any]]] = Field( + default_factory=list, + description="Student quiz history — list of {topic, score, total, timeTaken}", + ) + + +class TopicCompetency(BaseModel): + topic: str + efficiencyScore: float = Field(..., ge=0, le=100) + competencyLevel: str + perspective: str + + +class StudentCompetencyResponse(BaseModel): + studentId: str + competencies: List[TopicCompetency] + recommendedTopics: List[str] + excludeTopics: List[str] + + +class CalculatorRequest(BaseModel): + expression: str = Field(..., min_length=1, max_length=500, description="Mathematical expression to evaluate") + + +class CalculatorResponse(BaseModel): + expression: str + result: str + steps: List[str] + simplified: Optional[str] = None + latex: Optional[str] = None + + +class LessonGenerationRequest(BaseModel): + gradeLevel: str = Field(..., description="Grade level context for lesson generation") + subject: Optional[str] = Field(default="general_math", description="Curriculum subject identifier") + quarter: Optional[int] = Field(default=1, ge=1, le=4, description="Curriculum quarter") + moduleUnit: Optional[str] = Field(default=None, description="Optional module or unit context") + lessonTitle: Optional[str] = Field(default=None, description="Optional lesson title") + learningCompetency: Optional[str] = Field(default=None, description="Optional learning competency text") + learnerLevel: Optional[str] = Field(default=None, description="Optional learner level context") + classSectionId: Optional[str] = Field(default=None, description="Optional class section context") + className: Optional[str] = Field(default=None, description="Optional class display name") + materialId: Optional[str] = Field(default=None, description="Optional specific course-material artifact ID") + focusTopics: List[str] = Field(default_factory=list, description="Optional explicit topic overrides") + topicCount: int = Field(default=5, ge=1, le=10, description="Maximum number of focus topics") + preferImportedTopics: bool = Field(default=True, description="Prefer persisted imported topics when available") + allowReviewSources: bool = Field(default=False, description="Allow generation from review_required sources") + allowUnverifiedLesson: bool = Field(default=False, description="Allow returning lessons that fail self-validation") + curriculumContext: Optional[str] = Field(default=None, description="Optional RAG curriculum context block") + + +class CurriculumEvidenceSource(BaseModel): + subject: str + quarter: int + content: str + sourceFile: Optional[str] = None + page: Optional[int] = None + score: float = 0.0 + contentDomain: Optional[str] = None + chunkType: Optional[str] = None + + +class CurriculumGroundingSummary(BaseModel): + query: str + confidence: float + confidenceBand: str + retrievedChunks: int + needsReview: bool + issues: List[str] = Field(default_factory=list) + + +class GroundedWorkedExample(BaseModel): + problem: str + solution: str + + +class LessonPlanBlock(BaseModel): + blockId: str + title: str + objective: str + strategy: str + estimatedMinutes: int + activities: List[str] + checksForUnderstanding: List[str] + remediationTips: List[str] + provenance: Optional[Dict[str, Optional[str]]] = None + + +class SourceLegitimacyReport(BaseModel): + status: str + score: float + verifiedMaterials: int + reviewMaterials: int + rejectedMaterials: int + evidenceChecked: List[str] + issues: List[str] + + +class LessonSelfValidationReport(BaseModel): + passed: bool + score: float + issues: List[str] + checks: Dict[str, Any] + + +class LessonPlanResponse(BaseModel): + success: bool + lessonTitle: str + curriculumCompetency: Optional[str] = None + lessonObjective: Optional[str] = None + realWorldHook: Optional[str] = None + explanation: Optional[str] = None + workedExample: Optional[GroundedWorkedExample] = None + guidedPractice: List[str] = Field(default_factory=list) + independentPractice: List[str] = Field(default_factory=list) + quickAssessment: List[str] = Field(default_factory=list) + reflectionPrompt: Optional[str] = None + sourceCitations: List[str] = Field(default_factory=list) + retrievedEvidence: List[CurriculumEvidenceSource] = Field(default_factory=list) + curriculumGrounding: CurriculumGroundingSummary + gradeLevel: str + classSectionId: Optional[str] = None + className: Optional[str] = None + subject: Optional[str] = None + quarter: Optional[int] = None + moduleUnit: Optional[str] = None + learnerLevel: Optional[str] = None + usedImportedTopics: bool + importedTopicCount: int + weakSignals: Dict[str, float] + focusTopics: List[str] + blocks: List[LessonPlanBlock] + provenanceSummary: List[Dict[str, Optional[str]]] + sourceLegitimacy: SourceLegitimacyReport + selfValidation: LessonSelfValidationReport + publishReady: bool + needsReview: bool = False + reviewReason: Optional[str] = None + warnings: List[str] + + +class AsyncTaskSubmitResponse(BaseModel): + success: bool + taskId: str + status: str + taskKind: str + createdAt: str + + +class AsyncTaskStatusResponse(BaseModel): + success: bool + taskId: str + taskKind: str + status: str + createdAt: str + startedAt: Optional[str] = None + completedAt: Optional[str] = None + progressPercent: float = 0.0 + progressStage: str = "queued" + progressMessage: Optional[str] = None + result: Optional[Dict[str, Any]] = None + error: Optional[Any] = None + + +class AsyncTaskListResponse(BaseModel): + success: bool + count: int + tasks: List[AsyncTaskStatusResponse] + + +class AsyncTaskCancelResponse(BaseModel): + success: bool + taskId: str + status: str + message: str + + +class InferenceMetricsResponse(BaseModel): + success: bool + metrics: Dict[str, Any] + + +class HFMonitoringDataResponse(BaseModel): + success: bool + data: Dict[str, Any] + + +class ImportGroundedFeedbackRequest(BaseModel): + flow: str = Field(..., description="Flow identifier: quiz or lesson") + status: str = Field(..., description="Event status: success, failed, or skipped") + classSectionId: Optional[str] = Field(default=None, description="Optional class section context") + className: Optional[str] = Field(default=None, description="Optional class display name") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Optional event metadata") + + @field_validator("flow") + @classmethod + def validate_flow(_cls, v: str) -> str: + value = (v or "").strip().lower() + if value not in {"quiz", "lesson"}: + raise ValueError("flow must be one of: quiz, lesson") + return value + + @field_validator("status") + @classmethod + def validate_status(_cls, v: str) -> str: + value = (v or "").strip().lower() + if value not in {"success", "failed", "skipped"}: + raise ValueError("status must be one of: success, failed, skipped") + return value + + +class ImportGroundedFeedbackResponse(BaseModel): + success: bool + stored: bool + warnings: List[str] + + +class ImportGroundedHourlyVolumeItem(BaseModel): + hourBucket: str + flow: str + status: str + eventCount: int + + +class ImportGroundedClassRateItem(BaseModel): + classSectionId: str + total24h: int + failed24h: int + skipped24h: int + failureRate24h: float + skippedRate24h: float + total7d: int + failed7d: int + skipped7d: int + failureRate7d: float + skippedRate7d: float + + +class ImportGroundedFlowUsageItem(BaseModel): + flow: str + totalEvents: int + eligibleEvents: int + groundedEvents: int + groundedUsageRatio: float + + +class ImportGroundedErrorReasonItem(BaseModel): + normalizedErrorReason: str + occurrences: int + + +class ImportGroundedTelemetryThresholds(BaseModel): + go: bool + reasons: List[str] + + +class ImportGroundedTelemetrySummaryResponse(BaseModel): + success: bool + classSectionId: Optional[str] = None + lookbackDays: int + totalEvents: int + hourlyVolume: List[ImportGroundedHourlyVolumeItem] + classRates: List[ImportGroundedClassRateItem] + flowUsage: List[ImportGroundedFlowUsageItem] + topErrors: List[ImportGroundedErrorReasonItem] + thresholds: ImportGroundedTelemetryThresholds + warnings: List[str] + + +class ImportGroundedAccessAuditItem(BaseModel): + auditId: str + action: str + status: str + path: str + method: str + classSectionId: Optional[str] = None + createdAtIso: Optional[str] = None + metadata: Dict[str, Any] + + +class ImportGroundedAccessAuditSummary(BaseModel): + totalEvents: int + byAction: Dict[str, int] + byStatus: Dict[str, int] + + +class ImportGroundedAccessAuditResponse(BaseModel): + success: bool + classSectionId: Optional[str] = None + lookbackDays: int + entries: List[ImportGroundedAccessAuditItem] + summary: ImportGroundedAccessAuditSummary + warnings: List[str] + + +# ─── Diagnostic Test Models ──────────────────────────────────── + +class DiagnosticGenerateRequest(BaseModel): + strand: str = Field(..., description="Student strand: ABM, STEM, HUMSS, GAS, TVL") + gradeLevel: str = Field(..., description="Grade level: Grade 11 or Grade 12") + numQuestions: int = Field(default=15, ge=5, le=30, description="Number of questions to generate") + + +class DiagnosticQuestion(BaseModel): + question_id: str + competency_code: str + domain: str + topic: str + difficulty: str + bloom_level: str + question_text: str + options: Dict[str, str] + correct_answer: str + solution_hint: str + curriculum_reference: str + + +class DiagnosticGenerateResponse(BaseModel): + questions: List[DiagnosticQuestion] + test_id: str + metadata: Dict[str, Any] + + +class DiagnosticSubmitRequest(BaseModel): + user_id: str + test_id: str + strand: str + grade_level: str + responses: List[Dict[str, Any]] + + +class DiagnosticResult(BaseModel): + user_id: str + test_id: str + taken_at: datetime + strand: str + grade_level: str + total_items: int + total_score: int + percentage_score: float + responses: List[Dict[str, Any]] + domain_scores: Dict[str, Dict[str, Any]] + risk_profile: Dict[str, Any] + + +class DiagnosticSubmitResponse(BaseModel): + success: bool + result: DiagnosticResult + risk_profile: Dict[str, Any] + domain_scores: Dict[str, Dict[str, Any]] + redirect_to: str + + +class DiagnosticResultsResponse(BaseModel): + success: bool + results: List[DiagnosticResult] + + +# ─── DepEd Curriculum Competency Domains ──────────────────────────── + +DEPD_ED_COMPETENCY_DOMAINS: Dict[str, Dict[str, List[str]]] = { + "ABM": { + "Grade 11": [ + "Business Mathematics - Fractions, Decimals, Percent", + "Business Mathematics - Proportion", + "Business Mathematics - Markup and Margin", + "Business Mathematics - Trade Discounts and VAT", + "Business Mathematics - Commissions", + "Business Mathematics - Salaries and Wages", + "Business Mathematics - Mandatory Deductions", + "Business Mathematics - Employee Benefits", + "Business Mathematics - Overtime Pay", + "Business Mathematics - Simple Interest", + "Business Mathematics - Compound Interest", + "Business Mathematics - Loans and Credit", + "Business Mathematics - Data Presentation", + ], + "Grade 12": [ + "Business Mathematics - Business Reports", + "Business Mathematics - Financial Analysis", + "Business Mathematics - Investment Decisions", + "Business Mathematics - Taxation", + "Business Mathematics - Asset Depreciation", + ], + }, + "STEM": { + "Grade 11": [ + "General Mathematics - Patterns and Sequences", + "General Mathematics - Functions", + "General Mathematics - Function Operations", + "General Mathematics - Inverse Functions", + "General Mathematics - Unit Conversions", + "General Mathematics - Geometry", + "General Mathematics - Trigonometry", + "Statistics - Data Organization", + "Statistics - Measures of Central Tendency", + "Statistics - Measures of Variability", + "Statistics - Random Variables", + "Statistics - Probability Distributions", + "Statistics - Normal Distribution", + "Statistics - Sampling", + "Statistics - Hypothesis Testing", + ], + "Grade 12": [ + "General Mathematics - Financial Math", + "General Mathematics - Compound Interest", + "General Mathematics - Annuities", + "General Mathematics - Amortization", + "General Mathematics - Logical Propositions", + "Statistics - Confidence Intervals", + "Statistics - Correlation", + "Statistics - Regression", + ], + }, + "HUMSS": { + "Grade 11": [ + "General Mathematics - Patterns and Sequences", + "General Mathematics - Functions", + "General Mathematics - Statistics Basics", + "General Mathematics - Data Analysis", + "General Mathematics - Probability", + ], + "Grade 12": [ + "General Mathematics - Financial Math", + "General Mathematics - Logical Reasoning", + "Statistics - Statistical Inference", + ], + }, + "GAS": { + "Grade 11": [ + "General Mathematics - Patterns and Sequences", + "General Mathematics - Functions", + "General Mathematics - Statistics Basics", + ], + "Grade 12": [ + "General Mathematics - Financial Math", + "General Mathematics - Logical Reasoning", + ], + }, + "TVL": { + "Grade 11": [ + "Applied Mathematics - Number Sense", + "Applied Mathematics - Measurement", + "Applied Mathematics - Data Interpretation", + "Applied Mathematics - Problem Solving", + ], + "Grade 12": [ + "Applied Mathematics - Business Math", + "Applied Mathematics - Consumer Math", + "Applied Mathematics - Technical Math", + ], + }, +} + + +def _coerce_event_timestamp_utc(event: Dict[str, Any]) -> Optional[datetime]: + created_at = event.get("createdAt") + if isinstance(created_at, datetime): + return created_at if created_at.tzinfo else created_at.replace(tzinfo=timezone.utc) + + created_at_iso = str(event.get("createdAtIso") or "").strip() + if not created_at_iso: + return None + + try: + parsed = datetime.fromisoformat(created_at_iso.replace("Z", "+00:00")) + return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc) + except Exception: + return None + + +def _to_compact_json(value: Any) -> str: + try: + return json.dumps(value, separators=(",", ":"), ensure_ascii=True) + except Exception: + return "{}" + + +def _csv_escape(value: Any) -> str: + text = str(value if value is not None else "") + return '"' + text.replace('"', '""') + '"' + + +# ─── Quiz Topics Database (SHS Grade 11-12 Only) ───────────── + +MATH_TOPICS_BY_GRADE: Dict[str, Dict[str, List[str]]] = { + "Grade 11": { + "General Mathematics - Patterns, Relations, and Functions": [ + "Patterns and Real-Life Relationships", "Functions as Mathematical Models", + "Function Notation and Evaluation", "Domain and Range of Functions", + "Operations on Functions", "Composite Functions", "Inverse Functions", + "Graphs of Rational Functions", "Graphs of Exponential Functions", + "Graphs of Logarithmic Functions", + ], + "General Mathematics - Financial Mathematics": [ + "Simple and Compound Interest", "Simple and General Annuities", + "Present and Future Value", "Loans, Amortization, and Sinking Funds", + "Stocks, Bonds, and Market Indices", + "Business Decision-Making with Mathematical Models", + ], + "General Mathematics - Logic and Mathematical Reasoning": [ + "Propositions and Logical Connectives", "Truth Values and Truth Tables", + "Logical Equivalence and Implication", "Quantifiers and Negation", + "Validity of Arguments", + ], + "Statistics and Probability - Random Variables": [ + "Random Variables", "Discrete Probability Distributions", + "Mean and Variance of Discrete RV", + ], + "Statistics and Probability - Normal Distribution": [ + "Normal Distribution", "Standard Normal Distribution and Z-scores", + "Areas Under the Normal Curve", + ], + "Statistics and Probability - Sampling and Estimation": [ + "Sampling Distributions", "Central Limit Theorem", + "Point Estimation", "Confidence Intervals", + ], + "Statistics and Probability - Hypothesis Testing": [ + "Hypothesis Testing Concepts", "T-test", "Z-test", + "Correlation and Regression", + ], + }, + "Grade 12": { + "Pre-Calculus - Analytic Geometry": [ + "Conic Sections - Parabola", "Conic Sections - Ellipse", + "Conic Sections - Hyperbola", "Conic Sections - Circle", + "Systems of Nonlinear Equations", + ], + "Pre-Calculus - Series and Induction": [ + "Sequences and Series", "Arithmetic Sequences", "Geometric Sequences", + "Mathematical Induction", "Binomial Theorem", + ], + "Pre-Calculus - Trigonometry": [ + "Angles and Unit Circle", "Trigonometric Functions", + "Trigonometric Identities", "Sum and Difference Formulas", + "Inverse Trigonometric Functions", "Polar Coordinates", + ], + "Basic Calculus - Limits": [ + "Limits of Functions", "Limit Theorems", "One-Sided Limits", + "Infinite Limits and Limits at Infinity", "Continuity of Functions", + ], + "Basic Calculus - Derivatives": [ + "Definition of the Derivative", "Differentiation Rules", "Chain Rule", + "Implicit Differentiation", "Higher-Order Derivatives", "Related Rates", + "Extrema and the First Derivative Test", + "Concavity and the Second Derivative Test", "Optimization Problems", + ], + "Basic Calculus - Integration": [ + "Antiderivatives and Indefinite Integrals", + "Definite Integrals and the FTC", + "Integration by Substitution", "Area Under a Curve", + ], + }, +} + + +def _normalize_topic_key(value: str) -> str: + key = re.sub(r"[^a-z0-9\s]+", " ", (value or "").lower()) + key = re.sub(r"\s+", " ", key).strip() + return key + + +TOPIC_LABEL_ALIASES: Dict[str, str] = { + # Legacy General Mathematics aliases mapped to strengthened SHS canonical labels. + _normalize_topic_key("Functions and Relations"): "Functions as Mathematical Models", + _normalize_topic_key("Evaluating Functions"): "Function Notation and Evaluation", + _normalize_topic_key("Rational Functions"): "Graphs of Rational Functions", + _normalize_topic_key("Exponential Functions"): "Graphs of Exponential Functions", + _normalize_topic_key("Logarithmic Functions"): "Graphs of Logarithmic Functions", + _normalize_topic_key("Simple Interest"): "Simple and Compound Interest", + _normalize_topic_key("Compound Interest"): "Simple and Compound Interest", + _normalize_topic_key("Annuities"): "Simple and General Annuities", + _normalize_topic_key("Loans and Amortization"): "Loans, Amortization, and Sinking Funds", + _normalize_topic_key("Stocks and Bonds"): "Stocks, Bonds, and Market Indices", + _normalize_topic_key("Propositions and Connectives"): "Propositions and Logical Connectives", + _normalize_topic_key("Truth Tables"): "Truth Values and Truth Tables", + _normalize_topic_key("Logical Equivalence"): "Logical Equivalence and Implication", + _normalize_topic_key("Valid Arguments and Fallacies"): "Validity of Arguments", +} + + +def _canonicalize_topic_label(value: str) -> str: + clean_value = str(value or "").strip() + if not clean_value: + return "" + return TOPIC_LABEL_ALIASES.get(_normalize_topic_key(clean_value), clean_value) + + +def _canonicalize_topic_list(values: List[str]) -> List[str]: + canonical: List[str] = [] + for value in values: + normalized = _canonicalize_topic_label(value) + if normalized and normalized not in canonical: + canonical.append(normalized) + return canonical + + +def _resolve_grade_level_key(grade_level: Optional[str]) -> Optional[str]: + raw = str(grade_level or "").strip() + if not raw: + return None + + normalized = raw.lower() + if normalized in {"11", "grade11", "grade 11", "g11"}: + return "Grade 11" + if normalized in {"12", "grade12", "grade 12", "g12"}: + return "Grade 12" + + for key in MATH_TOPICS_BY_GRADE.keys(): + if key.lower() == normalized: + return key + + return None + + +def _fallback_topics_for_grade(grade_level: str, topic_count: int) -> List[str]: + fallback_topics: List[str] = [] + grade_key = _resolve_grade_level_key(grade_level) + if grade_key: + for _, topics in MATH_TOPICS_BY_GRADE[grade_key].items(): + for topic in topics: + if topic not in fallback_topics: + fallback_topics.append(topic) + if len(fallback_topics) >= topic_count: + return fallback_topics + + # Keep grade-level separation strict; if grade is unknown, default to Grade 11. + default_grade = "Grade 11" + for _, topics in MATH_TOPICS_BY_GRADE[default_grade].items(): + for topic in topics: + if topic not in fallback_topics: + fallback_topics.append(topic) + if len(fallback_topics) >= topic_count: + return fallback_topics + return fallback_topics + + +def _load_class_performance_artifacts( + request: Request, + *, + class_section_id: Optional[str] = None, + max_records: int = 500, +) -> Dict[str, float]: + if not (_firebase_ready and firebase_firestore): + return { + "recordsCount": 0, + "averageQuizScore": 0.0, + "averageAttendance": 0.0, + "averageEngagement": 0.0, + "averageAssignmentCompletion": 0.0, + "atRiskRate": 0.0, + } + + user = get_current_user(request) + normalized_class_section_id = (class_section_id or "").strip() or None + + query = ( + firebase_firestore.client() + .collection("normalizedClassRecords") + .where("teacherId", "==", user.uid) + ) + if normalized_class_section_id: + query = query.where("classSectionId", "==", normalized_class_section_id) + + docs = query.limit(max_records).stream() + scores: List[float] = [] + attendance_rates: List[float] = [] + engagement_rates: List[float] = [] + completion_rates: List[float] = [] + at_risk_count = 0 + + for doc in docs: + row = doc.to_dict() or {} + score = float(row.get("avgQuizScore") or 0.0) + attendance = float(row.get("attendance") or 0.0) + engagement = float(row.get("engagementScore") or 0.0) + completion = float(row.get("assignmentCompletion") or 0.0) + + scores.append(score) + attendance_rates.append(attendance) + engagement_rates.append(engagement) + completion_rates.append(completion) + + if score < 60 or engagement < 55: + at_risk_count += 1 + + count = len(scores) + if count == 0: + return { + "recordsCount": 0, + "averageQuizScore": 0.0, + "averageAttendance": 0.0, + "averageEngagement": 0.0, + "averageAssignmentCompletion": 0.0, + "atRiskRate": 0.0, + } + + return { + "recordsCount": float(count), + "averageQuizScore": sum(scores) / count, + "averageAttendance": sum(attendance_rates) / count, + "averageEngagement": sum(engagement_rates) / count, + "averageAssignmentCompletion": sum(completion_rates) / count, + "atRiskRate": at_risk_count / count, + } + + +def _parse_lesson_plan_json(raw: str) -> Dict[str, Any]: + cleaned = (raw or "").strip() + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + cleaned = cleaned.strip() + + start = cleaned.find("{") + end = cleaned.rfind("}") + 1 + if start >= 0 and end > start: + try: + parsed = json.loads(cleaned[start:end]) + if isinstance(parsed, dict): + return parsed + except Exception: + pass + return {} + + +def _deterministic_lesson_checks( + *, + selected_topics: List[str], + blocks: List[LessonPlanBlock], +) -> Dict[str, Any]: + issues: List[str] = [] + selected_topic_keys = {_normalize_topic_key(topic) for topic in selected_topics if topic.strip()} + covered_topic_keys: Set[str] = set() + + if len(blocks) < 3: + issues.append("Lesson must include at least 3 instructional blocks.") + + for block in blocks: + block_text = " ".join([ + block.title, + block.objective, + block.strategy, + " ".join(block.activities), + " ".join(block.checksForUnderstanding), + " ".join(block.remediationTips), + ]).lower() + + matched = False + for topic in selected_topics: + topic_key = _normalize_topic_key(topic) + topic_words = [w for w in topic_key.split(" ") if w] + if topic_words and all(word in block_text for word in topic_words[: min(2, len(topic_words))]): + covered_topic_keys.add(topic_key) + matched = True + break + if not matched: + issues.append(f"Block '{block.title}' is weakly grounded to selected topics.") + + if not block.activities: + issues.append(f"Block '{block.title}' is missing classroom activities.") + if not block.checksForUnderstanding: + issues.append(f"Block '{block.title}' is missing checks for understanding.") + + topic_coverage_ratio = 1.0 + if selected_topic_keys: + topic_coverage_ratio = len(covered_topic_keys) / max(1, len(selected_topic_keys)) + if topic_coverage_ratio < 0.6: + issues.append( + f"Topic coverage too low ({topic_coverage_ratio:.2f}); expected at least 0.60 across selected topics." + ) + + structure_ok = len(blocks) >= 3 and all(block.estimatedMinutes >= 5 for block in blocks) + grounding_ok = topic_coverage_ratio >= 0.6 + score = max(0.0, min(1.0, 1.0 - (0.12 * len(issues)))) + + return { + "score": round(score, 3), + "issues": issues, + "checks": { + "structure": structure_ok, + "topicGrounding": grounding_ok, + "topicCoverageRatio": round(topic_coverage_ratio, 3), + "blockCount": len(blocks), + }, + } + + +async def _ai_validate_lesson_plan( + *, + lesson_title: str, + selected_topics: List[str], + blocks: List[LessonPlanBlock], +) -> Dict[str, Any]: + compact_blocks = [ + { + "title": block.title, + "objective": block.objective, + "strategy": block.strategy, + "estimatedMinutes": block.estimatedMinutes, + "activities": block.activities, + "checksForUnderstanding": block.checksForUnderstanding, + "remediationTips": block.remediationTips, + } + for block in blocks + ] + + validation_prompt = ( + "Validate this generated math lesson plan for instructional quality and grounding. " + "Return JSON only in this schema: " + '{"passed":true|false,"score":0.0-1.0,"issues":["..."],"checks":{"mathSoundness":true|false,"topicGrounding":true|false,"classroomUsability":true|false}}. ' + "Fail the lesson if topics are hallucinated, objectives are vague, or classroom activities are not actionable.\n\n" + f"Lesson title: {lesson_title}\n" + f"Selected topics: {json.dumps(selected_topics)}\n" + f"Blocks: {json.dumps(compact_blocks)}" + ) + + try: + raw = await call_hf_chat_async( + messages=[ + { + "role": "system", + "content": "You are a strict lesson-quality verifier. Return valid JSON only.", + }, + {"role": "user", "content": validation_prompt}, + ], + task_type="lesson_generation", + max_tokens=420, + temperature=0.1, + top_p=0.9, + timeout=90, + ) + parsed = _parse_lesson_plan_json(raw) + if not parsed: + return { + "passed": False, + "score": 0.0, + "issues": ["AI validator returned invalid JSON."], + "checks": { + "mathSoundness": False, + "topicGrounding": False, + "classroomUsability": False, + }, + } + + score = float(parsed.get("score") or 0.0) + score = max(0.0, min(1.0, score)) + checks_raw: Dict[str, Any] = {} + if isinstance(parsed.get("checks"), dict): + checks_raw = cast(Dict[str, Any], parsed.get("checks")) + checks = { + "mathSoundness": bool(checks_raw.get("mathSoundness")), + "topicGrounding": bool(checks_raw.get("topicGrounding")), + "classroomUsability": bool(checks_raw.get("classroomUsability")), + } + issues = [str(item).strip() for item in (parsed.get("issues") or []) if str(item).strip()] + passed = bool(parsed.get("passed")) and score >= LESSON_VALIDATION_MIN_SCORE and all(checks.values()) + return { + "passed": passed, + "score": round(score, 3), + "issues": issues, + "checks": checks, + } + except Exception as validation_exc: + logger.warning(f"Lesson AI self-validation failed: {validation_exc}") + return { + "passed": False, + "score": 0.0, + "issues": ["AI self-validation failed due to runtime error."], + "checks": { + "mathSoundness": False, + "topicGrounding": False, + "classroomUsability": False, + }, + } + + +async def _validate_generated_lesson_plan( + *, + lesson_title: str, + selected_topics: List[str], + blocks: List[LessonPlanBlock], +) -> Dict[str, Any]: + deterministic = _deterministic_lesson_checks(selected_topics=selected_topics, blocks=blocks) + ai_validation = await _ai_validate_lesson_plan( + lesson_title=lesson_title, + selected_topics=selected_topics, + blocks=blocks, + ) + + issues = deterministic.get("issues", []) + ai_validation.get("issues", []) + checks = { + **deterministic.get("checks", {}), + **ai_validation.get("checks", {}), + } + + combined_score = round( + (0.4 * float(deterministic.get("score", 0.0))) + + (0.6 * float(ai_validation.get("score", 0.0))), + 3, + ) + passed = bool(ai_validation.get("passed")) and combined_score >= LESSON_VALIDATION_MIN_SCORE + + return { + "passed": passed, + "score": combined_score, + "issues": sorted(list({str(issue).strip() for issue in issues if str(issue).strip()})), + "checks": checks, + } + + +def _coerce_nonempty_str_list(raw_value: Any) -> List[str]: + if not isinstance(raw_value, list): + return [] + return [str(item).strip() for item in raw_value if str(item).strip()] + + +def _build_lesson_generation_content( + *, + lesson_payload: Dict[str, Any], + request: LessonGenerationRequest, + curriculum_competency: str, + lesson_title_hint: str, + retrieval_band: str, + retrieval_confidence: float, + source_legitimacy_report: Dict[str, Any], + curriculum_chunks: List[Dict[str, Any]], +) -> Dict[str, Any]: + lesson_title = str(lesson_payload.get("lessonTitle") or lesson_title_hint or "Intervention-Grounded Math Lesson Plan").strip() + curriculum_competency = str( + lesson_payload.get("curriculumCompetency") + or request.learningCompetency + or curriculum_competency + or lesson_title + ).strip() + lesson_objective = str( + lesson_payload.get("lessonObjective") + or f"Demonstrate {curriculum_competency} using DepEd curriculum evidence and Philippine real-life contexts." + ).strip() + real_world_hook = str( + lesson_payload.get("realWorldHook") + or "Connect the competency to practical decisions in work, business, finance, or daily life." + ).strip() + explanation = str( + lesson_payload.get("explanation") + or f"Use the retrieved curriculum evidence to explain {curriculum_competency} clearly and step by step." + ).strip() + reflection_prompt = str( + lesson_payload.get("reflectionPrompt") + or "How does this competency help you solve a real problem in school, work, or daily life?" + ).strip() + + raw_worked_example = lesson_payload.get("workedExample") if isinstance(lesson_payload, dict) else None + if isinstance(raw_worked_example, dict): + worked_example = GroundedWorkedExample( + problem=str(raw_worked_example.get("problem") or raw_worked_example.get("question") or "").strip() + or f"Worked example for {curriculum_competency}", + solution=str(raw_worked_example.get("solution") or raw_worked_example.get("answer") or "").strip() + or "Step-by-step solution grounded in the retrieved curriculum context.", + ) + else: + worked_example = GroundedWorkedExample( + problem=f"Worked example for {curriculum_competency}", + solution="Step-by-step solution grounded in the retrieved curriculum context.", + ) + + guided_practice = _coerce_nonempty_str_list(lesson_payload.get("guidedPractice") if isinstance(lesson_payload, dict) else None) + if not guided_practice: + guided_practice = [ + f"Solve a guided item on {curriculum_competency} using one cue from the retrieved curriculum evidence.", + "Compare your answer with a partner and justify each step.", + ] + + independent_practice = _coerce_nonempty_str_list(lesson_payload.get("independentPractice") if isinstance(lesson_payload, dict) else None) + if not independent_practice: + independent_practice = [ + f"Complete an independent task that applies {curriculum_competency} to a Philippine context.", + "Write a short justification of your answer using the curriculum language.", + ] + + quick_assessment = _coerce_nonempty_str_list(lesson_payload.get("quickAssessment") if isinstance(lesson_payload, dict) else None) + if not quick_assessment: + quick_assessment = [ + "One exit-ticket item that checks procedural accuracy.", + "One reflection question that checks real-world transfer.", + ] + + raw_source_citations = lesson_payload.get("sourceCitations") if isinstance(lesson_payload, dict) else [] + source_citations = _coerce_nonempty_str_list(raw_source_citations) + if not source_citations: + source_citations = [ + f"{chunk.get('source_file')} p.{chunk.get('page')} ({chunk.get('content_domain')}/{chunk.get('chunk_type')})" + for chunk in curriculum_chunks[:5] + if chunk.get("source_file") + ] + + needs_review = bool(lesson_payload.get("needsReview")) if isinstance(lesson_payload, dict) else False + needs_review = needs_review or retrieval_band == "low" + review_reason = str(lesson_payload.get("reviewReason") or "").strip() if isinstance(lesson_payload, dict) else "" + if not review_reason and retrieval_band == "low": + review_reason = "Curriculum retrieval confidence was low, so the lesson should be reviewed before classroom use." + if source_legitimacy_report.get("status") != "verified": + needs_review = True + if not review_reason: + review_reason = "One or more source checks require review." + + if needs_review: + review_reason = review_reason or "Lesson marked as needs review." + + retrieved_evidence: List[CurriculumEvidenceSource] = [] + for chunk in curriculum_chunks[:5]: + retrieved_evidence.append( + CurriculumEvidenceSource( + subject=str(chunk.get("subject") or request.subject or "general_math"), + quarter=int(chunk.get("quarter") or request.quarter or 1), + content=str(chunk.get("content") or "").strip(), + sourceFile=str(chunk.get("source_file") or "").strip() or None, + page=int(chunk.get("page") or 0) or None, + score=float(chunk.get("score") or 0.0), + contentDomain=str(chunk.get("content_domain") or "").strip() or None, + chunkType=str(chunk.get("chunk_type") or "").strip() or None, + ) + ) + + return { + "lessonTitle": lesson_title, + "curriculumCompetency": curriculum_competency, + "lessonObjective": lesson_objective, + "realWorldHook": real_world_hook, + "explanation": explanation, + "reflectionPrompt": reflection_prompt, + "workedExample": worked_example, + "guidedPractice": guided_practice, + "independentPractice": independent_practice, + "quickAssessment": quick_assessment, + "sourceCitations": source_citations, + "needsReview": needs_review, + "reviewReason": review_reason, + "retrievedEvidence": retrieved_evidence, + } + + +def _build_lesson_generation_blocks( + *, + lesson_title: str, + lesson_objective: str, + real_world_hook: str, + explanation: str, + worked_example: GroundedWorkedExample, + guided_practice: List[str], + independent_practice: List[str], + quick_assessment: List[str], + reflection_prompt: str, + retrieved_evidence: List[CurriculumEvidenceSource], + selected_topics: List[str], +) -> Dict[str, Any]: + provenance_summary: List[Dict[str, Optional[str]]] = [] + for evidence in retrieved_evidence: + summary_row = { + "topicId": None, + "title": evidence.content[:120] if evidence.content else lesson_title, + "materialId": evidence.sourceFile, + "sourceFile": evidence.sourceFile, + "sectionId": f"p.{evidence.page}" if evidence.page else None, + } + if summary_row not in provenance_summary: + provenance_summary.append(summary_row) + + if not provenance_summary: + provenance_summary = [ + { + "topicId": None, + "title": topic, + "materialId": None, + "sourceFile": None, + "sectionId": None, + } + for topic in selected_topics[:3] + ] + + first_provenance = provenance_summary[0] if provenance_summary else None + blocks: List[LessonPlanBlock] = [ + LessonPlanBlock( + blockId="block_1", + title="Real-World Hook", + objective=lesson_objective, + strategy="Context-setting discussion with retrieved curriculum evidence.", + estimatedMinutes=8, + activities=[real_world_hook], + checksForUnderstanding=["Ask students to identify the competency in the example."], + remediationTips=["Restate the context using simpler language if needed."], + provenance=first_provenance, + ), + LessonPlanBlock( + blockId="block_2", + title="Concept Explanation", + objective=lesson_objective, + strategy="Teacher-led explanation grounded in retrieved excerpts.", + estimatedMinutes=15, + activities=[explanation], + checksForUnderstanding=["Students paraphrase the key idea in their own words."], + remediationTips=["Break the explanation into smaller steps and repeat the terminology."], + provenance=first_provenance, + ), + LessonPlanBlock( + blockId="block_3", + title="Worked Example", + objective=worked_example.problem, + strategy="Model the solution path and annotate each step.", + estimatedMinutes=15, + activities=[f"Problem: {worked_example.problem}", f"Solution: {worked_example.solution}"], + checksForUnderstanding=["Students identify which step uses the retrieved competency."], + remediationTips=["Provide a partially completed solution scaffold if needed."], + provenance=first_provenance, + ), + LessonPlanBlock( + blockId="block_4", + title="Guided Practice", + objective="Apply the same method with scaffolded support.", + strategy="Teacher circulates while students complete guided items.", + estimatedMinutes=15, + activities=guided_practice, + checksForUnderstanding=["Check one correct answer and one justification."], + remediationTips=["Offer hints tied directly to the retrieved evidence."], + provenance=first_provenance, + ), + LessonPlanBlock( + blockId="block_5", + title="Independent Practice and Quick Check", + objective="Transfer the competency to a new but related context.", + strategy="Independent task followed by a short formative check.", + estimatedMinutes=15, + activities=independent_practice + quick_assessment, + checksForUnderstanding=["Collect an exit ticket and review misconceptions."], + remediationTips=["Revisit the retrieved source excerpt if students struggle."], + provenance=first_provenance, + ), + LessonPlanBlock( + blockId="block_6", + title="Reflection", + objective=reflection_prompt, + strategy="Metacognitive reflection and transfer discussion.", + estimatedMinutes=7, + activities=[reflection_prompt], + checksForUnderstanding=["Ask students to connect the skill to a real decision."], + remediationTips=["Prompt with a local example if reflection is shallow."], + provenance=first_provenance, + ), + ] + + return { + "provenanceSummary": provenance_summary, + "blocks": blocks, + } + + +@app.post("/api/lesson/generate", response_model=LessonPlanResponse) +async def generate_lesson_plan(http_request: Request, request: LessonGenerationRequest): + """ + Generate a class lesson plan grounded on imported course-material topics and + class performance artifacts. Falls back to built-in curriculum topics when + imported topics are unavailable. + """ + try: + enforce_rate_limit(http_request, "generate_lesson_plan", 20, 60) + + imported_topics_payload: Dict[str, Any] = {"topics": [], "materials": [], "warnings": []} + imported_topic_titles: List[str] = [] + warnings: List[str] = [] + import_grounding_enabled = ENABLE_IMPORT_GROUNDED_LESSON + + if not import_grounding_enabled and request.preferImportedTopics: + warnings.append( + "Import-grounded lesson generation is disabled by rollout flag; using focus topics and fallback curriculum." + ) + + if import_grounding_enabled and (request.preferImportedTopics or not request.focusTopics): + imported_topics_payload = _load_persisted_course_material_topics( + http_request, + class_section_id=request.classSectionId, + material_id=request.materialId, + limit_materials=20, + ) + imported_topic_titles = [ + str(topic.get("title") or "").strip() + for topic in (imported_topics_payload.get("topics") or []) + if str(topic.get("title") or "").strip() + ] + warnings.extend(imported_topics_payload.get("warnings") or []) + + selected_topics: List[str] = [] + for topic in request.focusTopics: + clean_topic = _canonicalize_topic_label(str(topic).strip()) + if clean_topic and clean_topic not in selected_topics: + selected_topics.append(clean_topic) + + if imported_topic_titles: + for topic in imported_topic_titles: + if topic not in selected_topics: + selected_topics.append(topic) + + if not selected_topics: + selected_topics = _fallback_topics_for_grade(request.gradeLevel, request.topicCount) + warnings.append("Using fallback curriculum topics because no imported topics were found.") + + selected_topics = selected_topics[: request.topicCount] + + requested_subject = (str(request.subject or "general_math").strip() or "general_math") + requested_quarter = int(request.quarter or 1) + competency_hint = str(request.learningCompetency or "").strip() or (selected_topics[0] if selected_topics else "") + lesson_title_hint = str(request.lessonTitle or "").strip() or competency_hint or "Grounded Math Lesson" + module_unit_hint = str(request.moduleUnit or "").strip() or None + learner_level_hint = str(request.learnerLevel or "").strip() or None + + retrieval_query = build_lesson_query( + competency_hint or lesson_title_hint, + requested_subject, + requested_quarter, + lesson_title=lesson_title_hint, + competency=competency_hint, + module_unit=module_unit_hint, + learner_level=learner_level_hint, + ) + + curriculum_chunks = retrieve_curriculum_context( + query=retrieval_query, + subject=requested_subject, + quarter=requested_quarter, + top_k=5, + ) + retrieval_summary = summarize_retrieval_confidence(curriculum_chunks) + retrieval_confidence = float(retrieval_summary.get("confidence") or 0.0) + retrieval_band = str(retrieval_summary.get("band") or "low") + retrieval_issues: List[str] = [] + if not curriculum_chunks: + retrieval_issues.append("No curriculum evidence was retrieved for the selected competency.") + elif retrieval_band == "low": + retrieval_issues.append("Retrieved evidence is weakly aligned; manual review recommended.") + + source_legitimacy_report = _evaluate_lesson_source_legitimacy( + imported_topics_payload, + allow_review_sources=request.allowReviewSources, + ) + using_imported_sources = bool(imported_topic_titles) + if not using_imported_sources: + source_legitimacy_report = { + "status": "verified", + "score": 1.0, + "verifiedMaterials": 0, + "reviewMaterials": 0, + "rejectedMaterials": 0, + "evidenceChecked": ["builtin_curriculum_fallback"], + "issues": [], + } + + if retrieval_issues: + source_legitimacy_report["status"] = "review_required" + source_legitimacy_report["score"] = min(float(source_legitimacy_report.get("score") or 0.0), retrieval_confidence or 0.5) + source_legitimacy_report.setdefault("issues", []) + source_legitimacy_report["issues"] = list({*map(str, source_legitimacy_report.get("issues") or []), *retrieval_issues}) + source_legitimacy_report.setdefault("evidenceChecked", []) + evidence_checked = list(source_legitimacy_report.get("evidenceChecked") or []) + evidence_checked.extend([ + f"{chunk.get('source_file')} p.{chunk.get('page')}" for chunk in curriculum_chunks if chunk.get("source_file") + ]) + source_legitimacy_report["evidenceChecked"] = sorted({str(item) for item in evidence_checked if str(item).strip()}) + + if ENFORCE_LEGIT_SOURCES_FOR_LESSONS and using_imported_sources: + source_status = str(source_legitimacy_report.get("status") or "review_required") + if source_status == "rejected": + raise HTTPException( + status_code=422, + detail={ + "message": "Imported source legitimacy checks failed. Lesson generation blocked.", + "sourceLegitimacy": source_legitimacy_report, + }, + ) + if source_status == "review_required" and not request.allowReviewSources: + raise HTTPException( + status_code=422, + detail={ + "message": "Imported sources require review. Set allowReviewSources=true to continue.", + "sourceLegitimacy": source_legitimacy_report, + }, + ) + + class_signals = _load_class_performance_artifacts( + http_request, + class_section_id=request.classSectionId, + max_records=500, + ) + + curriculum_context_block = "" + if request.curriculumContext and str(request.curriculumContext).strip(): + curriculum_context_block = f"{str(request.curriculumContext).strip()}\n\n" + + prompt = build_lesson_prompt( + lesson_title=lesson_title_hint, + competency=competency_hint or lesson_title_hint, + grade_level=request.gradeLevel, + subject=requested_subject, + quarter=requested_quarter, + learner_level=learner_level_hint, + module_unit=module_unit_hint, + curriculum_chunks=curriculum_chunks, + ) + + if curriculum_context_block: + prompt = f"{curriculum_context_block}{prompt}" + prompt = ( + f"{prompt}\n\n" + f"Class section: {request.classSectionId or 'n/a'}\n" + f"Class name: {request.className or 'n/a'}\n" + f"Performance signals: {json.dumps(class_signals)}\n" + f"Focus topics: {json.dumps(selected_topics)}\n" + "Use the retrieved curriculum evidence to ground the lesson and to generate explicit real-world applications." + ) + + lesson_payload: Dict[str, Any] = {} + try: + raw = await call_hf_chat_async( + messages=[ + { + "role": "system", + "content": "You are an expert math instructional designer. Output strict JSON only.", + }, + {"role": "user", "content": prompt}, + ], + max_tokens=1200, + temperature=0.25, + task_type="lesson_generation", + ) + lesson_payload = _parse_lesson_plan_json(raw) + except Exception as lesson_exc: + logger.warning(f"Lesson generation AI fallback engaged: {lesson_exc}") + warnings.append("AI lesson synthesis failed; generated deterministic scaffolded content from retrieved evidence.") + + lesson_core = _build_lesson_generation_content( + lesson_payload=lesson_payload, + request=request, + curriculum_competency=competency_hint, + lesson_title_hint=lesson_title_hint, + retrieval_band=retrieval_band, + retrieval_confidence=retrieval_confidence, + source_legitimacy_report=source_legitimacy_report, + curriculum_chunks=curriculum_chunks, + ) + lesson_title = str(lesson_core["lessonTitle"]) + curriculum_competency = str(lesson_core["curriculumCompetency"]) + lesson_objective = str(lesson_core["lessonObjective"]) + real_world_hook = str(lesson_core["realWorldHook"]) + explanation = str(lesson_core["explanation"]) + reflection_prompt = str(lesson_core["reflectionPrompt"]) + worked_example = cast(GroundedWorkedExample, lesson_core["workedExample"]) + guided_practice = cast(List[str], lesson_core["guidedPractice"]) + independent_practice = cast(List[str], lesson_core["independentPractice"]) + quick_assessment = cast(List[str], lesson_core["quickAssessment"]) + source_citations = cast(List[str], lesson_core["sourceCitations"]) + needs_review = bool(lesson_core["needsReview"]) + review_reason = str(lesson_core["reviewReason"]) + retrieved_evidence = cast(List[CurriculumEvidenceSource], lesson_core["retrievedEvidence"]) + + if needs_review: + warnings.append(review_reason) + + block_data = _build_lesson_generation_blocks( + lesson_title=lesson_title, + lesson_objective=lesson_objective, + real_world_hook=real_world_hook, + explanation=explanation, + worked_example=worked_example, + guided_practice=guided_practice, + independent_practice=independent_practice, + quick_assessment=quick_assessment, + reflection_prompt=reflection_prompt, + retrieved_evidence=retrieved_evidence, + selected_topics=selected_topics, + ) + provenance_summary = cast(List[Dict[str, Optional[str]]], block_data["provenanceSummary"]) + blocks = cast(List[LessonPlanBlock], block_data["blocks"]) + + self_validation_report = await _validate_generated_lesson_plan( + lesson_title=lesson_title, + selected_topics=selected_topics, + blocks=blocks, + ) + if not self_validation_report.get("passed"): + warnings.append("Generated lesson failed self-validation checks.") + if not request.allowUnverifiedLesson: + raise HTTPException( + status_code=422, + detail={ + "message": "Lesson self-validation failed. Fix generation inputs or enable allowUnverifiedLesson.", + "selfValidation": self_validation_report, + }, + ) + + publish_ready = bool( + self_validation_report.get("passed") + and source_legitimacy_report.get("status") == "verified" + and retrieval_band != "low" + and not needs_review + ) + return LessonPlanResponse( + success=True, + lessonTitle=lesson_title, + curriculumCompetency=curriculum_competency, + lessonObjective=lesson_objective, + realWorldHook=real_world_hook, + explanation=explanation, + workedExample=worked_example, + guidedPractice=guided_practice, + independentPractice=independent_practice, + quickAssessment=quick_assessment, + reflectionPrompt=reflection_prompt, + sourceCitations=source_citations, + retrievedEvidence=retrieved_evidence, + curriculumGrounding=CurriculumGroundingSummary( + query=retrieval_query, + confidence=retrieval_confidence, + confidenceBand=retrieval_band, + retrievedChunks=len(curriculum_chunks), + needsReview=needs_review, + issues=sorted({*(retrieval_issues or []), *(source_legitimacy_report.get("issues") or [])}), + ), + gradeLevel=request.gradeLevel, + classSectionId=request.classSectionId, + className=request.className, + subject=requested_subject, + quarter=requested_quarter, + moduleUnit=request.moduleUnit, + learnerLevel=request.learnerLevel, + usedImportedTopics=bool(imported_topic_titles), + importedTopicCount=len(imported_topics_payload.get("topics") or []), + weakSignals=class_signals, + focusTopics=selected_topics, + blocks=blocks, + provenanceSummary=provenance_summary, + sourceLegitimacy=SourceLegitimacyReport(**source_legitimacy_report), + selfValidation=LessonSelfValidationReport(**self_validation_report), + publishReady=publish_ready, + needsReview=needs_review, + reviewReason=review_reason or None, + warnings=warnings, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Lesson generation error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Lesson generation error: {str(e)}") + + +@app.post("/api/lesson/generate-async", response_model=AsyncTaskSubmitResponse) +async def generate_lesson_plan_async(http_request: Request, request: LessonGenerationRequest): + if not ENABLE_ASYNC_GENERATION: + raise HTTPException(status_code=404, detail="Async generation is disabled") + + enforce_rate_limit(http_request, "generate_lesson_plan_async", 20, 60) + user = get_current_user(http_request) + if user.role not in TEACHER_OR_ADMIN: + raise HTTPException(status_code=403, detail="Forbidden for this role") + + task_id = _create_async_task( + owner_uid=user.uid, + task_kind="lesson_generation", + payload=request.model_dump(), + ) + + _start_async_task_in_thread(task_id, lambda: generate_lesson_plan(http_request, request)) + + with _async_tasks_lock: + task = dict(_async_tasks.get(task_id, {})) + + return AsyncTaskSubmitResponse( + success=True, + taskId=task_id, + status=str(task.get("status") or "queued"), + taskKind="lesson_generation", + createdAt=str(task.get("createdAt") or _utc_now_iso()), + ) + + +@app.post("/api/feedback/import-grounded", response_model=ImportGroundedFeedbackResponse) +async def record_import_grounded_feedback(http_request: Request, request: ImportGroundedFeedbackRequest): + """Capture lightweight pilot feedback telemetry for import-grounded quiz and lesson flows.""" + warnings: List[str] = [] + stored = False + try: + enforce_rate_limit(http_request, "import_grounded_feedback", 60, 60) + user = get_current_user(http_request) + + if ENABLE_IMPORT_GROUNDED_FEEDBACK_EVENTS and _firebase_ready and firebase_firestore: + payload: Dict[str, Any] = { + "flow": request.flow, + "status": request.status, + "teacherId": user.uid, + "teacherEmail": user.email, + "role": user.role, + "classSectionId": (request.classSectionId or "").strip() or None, + "className": (request.className or "").strip() or None, + "metadata": request.metadata, + "createdAt": FIRESTORE_SERVER_TIMESTAMP, + "createdAtIso": datetime.now(timezone.utc).isoformat(), + } + firebase_firestore.client().collection("importGroundedFeedbackEvents").add(payload) + stored = True + else: + warnings.append("Import-grounded feedback storage is disabled or unavailable.") + + _write_access_audit_log( + http_request, + action="import_grounded_feedback", + status="success" if stored else "accepted", + class_section_id=request.classSectionId, + metadata={ + "flow": request.flow, + "eventStatus": request.status, + "stored": stored, + }, + ) + + return ImportGroundedFeedbackResponse(success=True, stored=stored, warnings=warnings) + except HTTPException: + raise + except Exception as e: + logger.warning(f"Import-grounded feedback logging failed: {e}") + return ImportGroundedFeedbackResponse(success=True, stored=False, warnings=["Feedback logging failed"]) + + +@app.get("/api/feedback/import-grounded/summary", response_model=ImportGroundedTelemetrySummaryResponse) +async def get_import_grounded_feedback_summary( + request: Request, + classSectionId: Optional[str] = Query(default=None), + days: int = Query(default=7, ge=1, le=30), + limit: int = Query(default=5000, ge=100, le=20000), +): + """Aggregate import-grounded pilot telemetry (Query A-D equivalent) from Firestore events.""" + try: + user = get_current_user(request) + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + normalized_class_section_id = (classSectionId or "").strip() or None + now_utc = datetime.now(timezone.utc) + lookback_start = now_utc - timedelta(days=days) + start_24h = now_utc - timedelta(hours=24) + warnings: List[str] = [] + + docs = ( + firebase_firestore.client() + .collection("importGroundedFeedbackEvents") + .where("teacherId", "==", user.uid) + .limit(limit) + .stream() + ) + + hourly_counter: Counter[Tuple[str, str, str]] = Counter() + class_counter: Dict[str, Dict[str, int]] = {} + flow_counter: Dict[str, Dict[str, int]] = {} + error_counter: Counter[str] = Counter() + total_events = 0 + + for doc in docs: + payload = doc.to_dict() or {} + event_class_section_id = str(payload.get("classSectionId") or "").strip() or None + if normalized_class_section_id and event_class_section_id != normalized_class_section_id: + continue + + event_ts = _coerce_event_timestamp_utc(payload) + if event_ts is None: + warnings.append("Some events were excluded because timestamps were missing or invalid.") + continue + if event_ts < lookback_start: + continue + + flow = str(payload.get("flow") or "unknown").strip().lower() or "unknown" + status = str(payload.get("status") or "unknown").strip().lower() or "unknown" + raw_metadata = payload.get("metadata") + metadata: Dict[str, Any] = raw_metadata if isinstance(raw_metadata, dict) else {} + + total_events += 1 + + hour_bucket = event_ts.replace(minute=0, second=0, microsecond=0).isoformat() + hourly_counter[(hour_bucket, flow, status)] += 1 + + class_key = event_class_section_id or "unscoped" + class_stats = class_counter.setdefault( + class_key, + { + "total24h": 0, + "failed24h": 0, + "skipped24h": 0, + "total7d": 0, + "failed7d": 0, + "skipped7d": 0, + }, + ) + class_stats["total7d"] += 1 + if status == "failed": + class_stats["failed7d"] += 1 + if status == "skipped": + class_stats["skipped7d"] += 1 + if event_ts >= start_24h: + class_stats["total24h"] += 1 + if status == "failed": + class_stats["failed24h"] += 1 + if status == "skipped": + class_stats["skipped24h"] += 1 + + flow_stats = flow_counter.setdefault( + flow, + { + "totalEvents": 0, + "eligibleEvents": 0, + "groundedEvents": 0, + }, + ) + flow_stats["totalEvents"] += 1 + import_grounding_enabled = bool(metadata.get("importGroundingEnabled", True)) + if import_grounding_enabled: + flow_stats["eligibleEvents"] += 1 + if bool(metadata.get("usedImportedTopics", False)): + flow_stats["groundedEvents"] += 1 + + if status == "failed": + normalized_error = str(metadata.get("error") or "unknown_error").strip().lower() or "unknown_error" + error_counter[normalized_error] += 1 + + deduped_warnings = sorted(set(warnings)) + + hourly_volume = [ + ImportGroundedHourlyVolumeItem( + hourBucket=hour, + flow=flow, + status=status, + eventCount=count, + ) + for (hour, flow, status), count in sorted(hourly_counter.items(), key=lambda item: item[0], reverse=True) + ] + + class_rates: List[ImportGroundedClassRateItem] = [] + aggregate_total_24h = 0 + aggregate_failed_24h = 0 + aggregate_skipped_24h = 0 + aggregate_total_7d = 0 + aggregate_failed_7d = 0 + aggregate_skipped_7d = 0 + + for class_key, stats in sorted(class_counter.items()): + total24h = int(stats["total24h"]) + failed24h = int(stats["failed24h"]) + skipped24h = int(stats["skipped24h"]) + total7d = int(stats["total7d"]) + failed7d = int(stats["failed7d"]) + skipped7d = int(stats["skipped7d"]) + + aggregate_total_24h += total24h + aggregate_failed_24h += failed24h + aggregate_skipped_24h += skipped24h + aggregate_total_7d += total7d + aggregate_failed_7d += failed7d + aggregate_skipped_7d += skipped7d + + class_rates.append( + ImportGroundedClassRateItem( + classSectionId=class_key, + total24h=total24h, + failed24h=failed24h, + skipped24h=skipped24h, + failureRate24h=(failed24h / total24h) if total24h else 0.0, + skippedRate24h=(skipped24h / total24h) if total24h else 0.0, + total7d=total7d, + failed7d=failed7d, + skipped7d=skipped7d, + failureRate7d=(failed7d / total7d) if total7d else 0.0, + skippedRate7d=(skipped7d / total7d) if total7d else 0.0, + ) + ) + + flow_usage: List[ImportGroundedFlowUsageItem] = [] + aggregate_eligible = 0 + aggregate_grounded = 0 + for flow, stats in sorted(flow_counter.items()): + eligible = int(stats["eligibleEvents"]) + grounded = int(stats["groundedEvents"]) + aggregate_eligible += eligible + aggregate_grounded += grounded + flow_usage.append( + ImportGroundedFlowUsageItem( + flow=flow, + totalEvents=int(stats["totalEvents"]), + eligibleEvents=eligible, + groundedEvents=grounded, + groundedUsageRatio=(grounded / eligible) if eligible else 0.0, + ) + ) + + top_errors = [ + ImportGroundedErrorReasonItem(normalizedErrorReason=reason, occurrences=count) + for reason, count in error_counter.most_common(20) + ] + + failure_rate_24h = (aggregate_failed_24h / aggregate_total_24h) if aggregate_total_24h else 0.0 + skipped_rate_7d = (aggregate_skipped_7d / aggregate_total_7d) if aggregate_total_7d else 0.0 + failure_rate_7d = (aggregate_failed_7d / aggregate_total_7d) if aggregate_total_7d else 0.0 + grounded_usage_ratio = (aggregate_grounded / aggregate_eligible) if aggregate_eligible else 0.0 + + threshold_reasons: List[str] = [] + if failure_rate_24h > 0.08: + threshold_reasons.append("Hold: failure_rate_24h exceeded 8% threshold.") + if failure_rate_7d > 0.05: + threshold_reasons.append("Hold: failure_rate_7d exceeded 5% threshold.") + if skipped_rate_7d > 0.10: + threshold_reasons.append("Hold: skipped_rate_7d exceeded 10% threshold.") + if aggregate_eligible > 0 and grounded_usage_ratio < 0.70: + threshold_reasons.append("Hold: grounded_usage_ratio below 70% for eligible events.") + + if total_events == 0: + deduped_warnings.append("No telemetry events found for the requested window/filter.") + + _write_access_audit_log( + request, + action="import_grounded_feedback_summary_read", + status="success", + class_section_id=normalized_class_section_id, + metadata={ + "lookbackDays": days, + "requestedLimit": limit, + "returnedEvents": total_events, + "warningsCount": len(deduped_warnings), + }, + ) + + return ImportGroundedTelemetrySummaryResponse( + success=True, + classSectionId=normalized_class_section_id, + lookbackDays=days, + totalEvents=total_events, + hourlyVolume=hourly_volume, + classRates=class_rates, + flowUsage=flow_usage, + topErrors=top_errors, + thresholds=ImportGroundedTelemetryThresholds( + go=len(threshold_reasons) == 0, + reasons=threshold_reasons, + ), + warnings=deduped_warnings, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Import-grounded feedback summary lookup failed: {e}") + raise HTTPException(status_code=500, detail=f"Import-grounded feedback summary error: {str(e)}") + + +@app.get("/api/import-grounded/access-audit", response_model=ImportGroundedAccessAuditResponse) +async def get_import_grounded_access_audit( + request: Request, + classSectionId: Optional[str] = Query(default=None), + days: int = Query(default=7, ge=1, le=30), + limit: int = Query(default=200, ge=1, le=1000), + export: str = Query(default="json"), +): + """ + Retrieve import-grounded access audit events for the authenticated teacher scope. + Supports JSON (default) and CSV export. + """ + try: + user = get_current_user(request) + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + export_mode = (export or "json").strip().lower() + if export_mode not in {"json", "csv"}: + raise HTTPException(status_code=400, detail="export must be one of: json, csv") + + normalized_class_section_id = (classSectionId or "").strip() or None + lookback_start = datetime.now(timezone.utc) - timedelta(days=days) + warnings: List[str] = [] + + query = ( + firebase_firestore.client() + .collection("accessAuditLogs") + .where("teacherId", "==", user.uid) + ) + try: + docs = ( + query + .order_by("createdAt", direction=FIRESTORE_QUERY_DESCENDING) + .limit(min(limit * 4, 4000)) + .stream() + ) + except Exception: + warnings.append("Access-audit lookup used fallback query path without ordering.") + docs = query.limit(min(limit * 4, 4000)).stream() + + allowed_prefixes = ( + "class_records_", + "course_material_", + "risk_refresh_", + "import_grounded_", + ) + entries: List[ImportGroundedAccessAuditItem] = [] + by_action: Counter[str] = Counter() + by_status: Counter[str] = Counter() + + for doc in docs: + payload = doc.to_dict() or {} + action = str(payload.get("action") or "").strip() + if not action or not action.startswith(allowed_prefixes): + continue + + event_class_section_id = str(payload.get("classSectionId") or "").strip() or None + if normalized_class_section_id and event_class_section_id != normalized_class_section_id: + continue + + event_ts = _coerce_event_timestamp_utc(payload) + if event_ts is None: + warnings.append("Some audit events were excluded because timestamps were missing or invalid.") + continue + if event_ts < lookback_start: + continue + + status = str(payload.get("status") or "unknown").strip() or "unknown" + method = str(payload.get("method") or "").strip().upper() or "GET" + path = str(payload.get("path") or "").strip() or "unknown" + metadata_raw = payload.get("metadata") + metadata = metadata_raw if isinstance(metadata_raw, dict) else {} + created_at_iso = event_ts.isoformat() + + entry = ImportGroundedAccessAuditItem( + auditId=str(doc.id), + action=action, + status=status, + path=path, + method=method, + classSectionId=event_class_section_id, + createdAtIso=created_at_iso, + metadata=metadata, + ) + entries.append(entry) + by_action[action] += 1 + by_status[status] += 1 + + if len(entries) >= limit: + break + + if not entries: + warnings.append("No import-grounded access-audit events found for the requested window/filter.") + + deduped_warnings = sorted(set(warnings)) + summary = ImportGroundedAccessAuditSummary( + totalEvents=len(entries), + byAction=dict(by_action), + byStatus=dict(by_status), + ) + + _write_access_audit_log( + request, + action="import_grounded_access_audit_read", + status="success", + class_section_id=normalized_class_section_id, + metadata={ + "lookbackDays": days, + "requestedLimit": limit, + "returnedEvents": len(entries), + "export": export_mode, + "warningsCount": len(deduped_warnings), + }, + ) + + if export_mode == "csv": + header = [ + "auditId", + "createdAtIso", + "action", + "status", + "method", + "path", + "classSectionId", + "metadataJson", + ] + lines = [",".join(header)] + for item in entries: + lines.append( + ",".join( + [ + _csv_escape(item.auditId), + _csv_escape(item.createdAtIso), + _csv_escape(item.action), + _csv_escape(item.status), + _csv_escape(item.method), + _csv_escape(item.path), + _csv_escape(item.classSectionId or ""), + _csv_escape(_to_compact_json(item.metadata)), + ] + ) + ) + + date_tag = datetime.now(timezone.utc).strftime("%Y%m%d") + return Response( + content="\n".join(lines), + media_type="text/csv", + headers={ + "Content-Disposition": f'attachment; filename="import-grounded-access-audit-{date_tag}.csv"', + }, + ) + + return ImportGroundedAccessAuditResponse( + success=True, + classSectionId=normalized_class_section_id, + lookbackDays=days, + entries=entries, + summary=summary, + warnings=deduped_warnings, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Import-grounded access-audit lookup failed: {e}") + raise HTTPException(status_code=500, detail=f"Import-grounded access-audit error: {str(e)}") + + +# ─── Quiz Generation System Prompt ──────────────────────────── + +QUIZ_GENERATION_SYSTEM_PROMPT = """You are an expert math quiz generator for MathPulse AI, an educational platform. + +PURPOSE: +You are creating supplemental math assessments to support classroom learning, not replace teacher instruction. + +BLOOM'S TAXONOMY FRAMEWORK: +Generate questions following Bloom's Taxonomy levels to ensure comprehensive skill evaluation: +- Remember (recall): Recall facts, definitions, or formulas +- Understand (explain): Explain concepts in own words, interpret data +- Apply (use): Use formulas/methods to solve problems in new contexts +- Analyze (examine): Break down complex problems, compare approaches, identify patterns + +QUESTION TYPES: +- Identification: Define or identify mathematical concepts, properties, or theorems +- Enumeration: List steps in a process, properties of a shape, or related concepts +- Multiple Choice: Standard multiple-choice with 4 options (one correct) +- Word Problem: Real-world context-based problems relatable to students' experiences +- Equation-Based: Solve equations, manipulate expressions, prove identities + +GRAPH QUESTIONS (when requested): +- Use ONLY identification-type questions about graphs +- Ask students to identify key features: intercepts, slopes, vertex locations, asymptotes, domain/range, transformations +- Describe the graph in text (do NOT attempt to render images) +- Format: "Given a graph of [description with key coordinates]... Identify [feature]." +- Note: Graph questions use identification format as graphing is challenging for students + +GUIDELINES: +- Make questions context-based and relatable to students' real-world experiences +- Generate clear, unambiguous questions with definitive correct answers +- For each question, provide a detailed step-by-step explanation +- Ensure mathematical accuracy — verify all answers +- Match difficulty to the specified level (easy, medium, hard) +- Distribute Bloom's Taxonomy levels as evenly as possible across the quiz + +RESPONSE FORMAT: +Respond ONLY with a valid JSON array of question objects. No markdown, no explanation outside: +[ + { + "questionType": "multiple_choice", + "question": "...", + "correctAnswer": "...", + "options": ["A) ...", "B) ...", "C) ...", "D) ..."], + "bloomLevel": "apply", + "difficulty": "medium", + "topic": "Linear Equations", + "points": 3, + "explanation": "Step 1: ... Step 2: ... Therefore the answer is ..." + } +] + +Points by difficulty: easy=1, medium=3, hard=5. +For non-multiple-choice questions, omit the "options" field or set to null. +Do NOT output chain-of-thought, planning notes, "Thinking Process", or any preamble. +If you cannot comply with constraints, still return the closest valid JSON array only. +""" + + +# ─── Quiz Generation Helpers ────────────────────────────────── + + +def _distribute_questions( + num_questions: int, + difficulty_distribution: Dict[str, int], + bloom_levels: List[str], + question_types: List[str], +) -> List[Dict[str, str]]: + """ + Pre-compute the distribution of questions by difficulty, Bloom level, + and question type so the LLM prompt can be very specific. + """ + distribution: List[Dict[str, str]] = [] + + # Compute counts per difficulty + difficulty_counts: Dict[str, int] = {} + remaining = num_questions + for i, (diff, pct) in enumerate(difficulty_distribution.items()): + if i == len(difficulty_distribution) - 1: + difficulty_counts[diff] = remaining + else: + count = max(1, round(num_questions * pct / 100)) + count = min(count, remaining) + difficulty_counts[diff] = count + remaining -= count + + idx = 0 + for diff, count in difficulty_counts.items(): + for j in range(count): + bloom = bloom_levels[idx % len(bloom_levels)] + qtype = question_types[idx % len(question_types)] + distribution.append({ + "difficulty": diff, + "bloomLevel": bloom, + "questionType": qtype, + }) + idx += 1 + + return distribution + + +def _parse_quiz_json(raw: str) -> List[Dict[str, Any]]: + """Robustly extract a JSON array of quiz questions from LLM output.""" + cleaned = raw.strip() + # Remove markdown fences + cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r"\n?```\s*$", "", cleaned) + cleaned = cleaned.strip() + + # Remove known reasoning wrappers and preambles that can precede JSON. + cleaned = re.sub(r"[\s\S]*?", "", cleaned, flags=re.IGNORECASE) + # Remove common reasoning preambles before JSON output. + cleaned = re.sub(r"^\s*thinking\s*process\s*:\s*", "", cleaned, flags=re.IGNORECASE) + cleaned = re.sub(r"^\s*json\s*[:\-]?\s*", "", cleaned, flags=re.IGNORECASE) + + def _extract_json_blocks(text: str) -> List[str]: + blocks: List[str] = [] + starts = [i for i, ch in enumerate(text) if ch in "[{"] + for start in starts: + opener = text[start] + closer = "]" if opener == "[" else "}" + depth = 0 + in_string = False + escaped = False + for idx in range(start, len(text)): + ch = text[idx] + if in_string: + if escaped: + escaped = False + elif ch == "\\": + escaped = True + elif ch == '"': + in_string = False + continue + + if ch == '"': + in_string = True + continue + + if ch == opener: + depth += 1 + elif ch == closer: + depth -= 1 + if depth == 0: + blocks.append(text[start : idx + 1]) + break + return blocks + + def _normalize_candidate(candidate: str) -> str: + normalized = candidate.strip().lstrip("\ufeff") + normalized = ( + normalized + .replace("\u201c", '"') + .replace("\u201d", '"') + .replace("\u2018", "'") + .replace("\u2019", "'") + ) + # Remove trailing commas before object/array closers. + normalized = re.sub(r",(\s*[}\]])", r"\1", normalized) + return normalized + + def _jsonish_loads(candidate: str) -> Any: + normalized = _normalize_candidate(candidate) + try: + return json.loads(normalized) + except json.JSONDecodeError: + pass + + # Fallback for Python-literal style payloads using single quotes / True / None. + python_like = re.sub(r"\btrue\b", "True", normalized, flags=re.IGNORECASE) + python_like = re.sub(r"\bfalse\b", "False", python_like, flags=re.IGNORECASE) + python_like = re.sub(r"\bnull\b", "None", python_like, flags=re.IGNORECASE) + try: + return ast.literal_eval(python_like) + except (ValueError, SyntaxError): + return None + + def _coerce_question_list(value: Any) -> List[Dict[str, Any]]: + if not isinstance(value, list): + return [] + return [item for item in value if isinstance(item, dict)] + + def _extract_question_list(payload: Any) -> List[Dict[str, Any]]: + if isinstance(payload, list): + return _coerce_question_list(payload) + + if isinstance(payload, dict): + for key in ("questions", "quiz", "items", "data"): + nested = payload.get(key) + if isinstance(nested, list): + return _coerce_question_list(nested) + + return [] + + # Try every balanced JSON-ish block and accept known payload wrappers. + for candidate in _extract_json_blocks(cleaned): + parsed = _jsonish_loads(candidate) + extracted = _extract_question_list(parsed) + if extracted: + return extracted + + # Try to find array brackets + arr_start = cleaned.find("[") + arr_end = cleaned.rfind("]") + 1 + if arr_start >= 0 and arr_end > arr_start: + parsed = _jsonish_loads(cleaned[arr_start:arr_end]) + extracted = _extract_question_list(parsed) + if extracted: + return extracted + + # Fallback: salvage individually parseable question-like objects. + objects: List[Dict[str, Any]] = [] + for candidate in _extract_json_blocks(cleaned): + if not candidate.lstrip().startswith("{"): + continue + parsed = _jsonish_loads(candidate) + if isinstance(parsed, dict) and str(parsed.get("question") or "").strip(): + objects.append(parsed) + + return objects + + +async def _repair_quiz_json_with_llm( + raw_output: str, + *, + num_questions: int, + timeout: int, + model_id: Optional[str], +) -> List[Dict[str, Any]]: + """Attempt to coerce malformed quiz output into a strict JSON array.""" + repair_messages = [ + { + "role": "system", + "content": ( + "You are a strict JSON formatter. Convert the user's draft into ONLY a valid JSON array " + "of quiz question objects. Do not include markdown, commentary, or chain-of-thought. " + f"Return exactly {num_questions} objects when possible." + ), + }, + { + "role": "user", + "content": ( + "Rewrite this into valid JSON array only. Keep and normalize useful question content.\n\n" + f"DRAFT:\n{raw_output}" + ), + }, + ] + + repaired_text = await call_hf_chat_async( + repair_messages, + max_tokens=min(4096, max(2048, num_questions * 260)), + temperature=0.0, + top_p=0.1, + timeout=timeout, + task_type="default", + model=model_id, + ) + return _parse_quiz_json(repaired_text) + + +async def _regenerate_quiz_json_strict( + *, + original_prompt: str, + num_questions: int, + timeout: int, + model_id: Optional[str], +) -> List[Dict[str, Any]]: + """Regenerate quiz questions using a strict JSON-only prompt path.""" + strict_messages = [ + { + "role": "system", + "content": ( + "You generate quizzes as strict JSON only. " + "Output must begin with '[' and end with ']'. " + "Do not output markdown, notes, or chain-of-thought." + ), + }, + { + "role": "user", + "content": ( + f"Generate exactly {num_questions} questions from this spec. " + "Return JSON array only.\n\n" + f"SPEC:\n{original_prompt}" + ), + }, + ] + + strict_text = await call_hf_chat_async( + strict_messages, + max_tokens=min(4096, max(2048, num_questions * 260)), + temperature=0.0, + top_p=0.1, + timeout=timeout, + task_type="default", + model=model_id, + ) + return _parse_quiz_json(strict_text) + + +def _validate_quiz_questions( + questions: List[Dict[str, Any]], + distribution: List[Dict[str, str]], + topic_provenance_map: Optional[Dict[str, Dict[str, Optional[str]]]] = None, +) -> List[QuizQuestion]: + """Validate and normalise each question from the LLM response.""" + validated: List[QuizQuestion] = [] + points_map = {"easy": 1, "medium": 3, "hard": 5} + + def _topic_key(value: str) -> str: + normalized = re.sub(r"[^a-z0-9\s]+", " ", (value or "").lower()) + normalized = re.sub(r"\s+", " ", normalized).strip() + return normalized + + def _normalize_options(raw_options: Any, fallback_obj: Dict[str, Any]) -> Optional[List[str]]: + def _coerce_list(value: Any) -> List[str]: + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + if isinstance(value, dict): + extracted = [str(item).strip() for item in value.values() if str(item).strip()] + return extracted + if isinstance(value, str): + text = value.strip() + if not text: + return [] + if text.startswith("[") and text.endswith("]"): + try: + parsed = json.loads(text) + if isinstance(parsed, list): + return [str(item).strip() for item in parsed if str(item).strip()] + except Exception: + pass + split_candidates = [part.strip() for part in re.split(r"\n+|\s*\|\s*", text) if part.strip()] + return split_candidates + return [] + + options_list = _coerce_list(raw_options) + if not options_list: + for alt_key in ("choices", "answerChoices", "answers"): + options_list = _coerce_list(fallback_obj.get(alt_key)) + if options_list: + break + + if not options_list: + return None + + return options_list[:4] + + for i, q in enumerate(questions): + if not isinstance(q, dict): + continue + dist = distribution[i] if i < len(distribution) else {} + + question_type = q.get("questionType", dist.get("questionType", "identification")) + if question_type not in VALID_QUESTION_TYPES: + question_type = "identification" + + difficulty = q.get("difficulty", dist.get("difficulty", "medium")) + if difficulty not in VALID_DIFFICULTY_LEVELS: + difficulty = "medium" + + bloom_level = q.get("bloomLevel", dist.get("bloomLevel", "understand")) + if bloom_level not in VALID_BLOOM_LEVELS: + bloom_level = "understand" + + options = None + if question_type == "multiple_choice": + options = _normalize_options(q.get("options"), q) + if not options: + # If options are malformed/missing, downgrade to identification to keep generation resilient. + question_type = "identification" + options = None + + question_topic = str(q.get("topic", "General")) + question_provenance = None + if topic_provenance_map: + question_provenance = topic_provenance_map.get(_topic_key(question_topic)) + + validated.append(QuizQuestion( + questionType=question_type, + question=str(q.get("question", "")), + correctAnswer=str(q.get("correctAnswer", "")), + options=options, + bloomLevel=bloom_level, + difficulty=difficulty, + topic=question_topic, + points=q.get("points", points_map.get(difficulty, 3)), + explanation=str(q.get("explanation", "No explanation provided.")), + provenance=question_provenance, + )) + + return validated + + +# ─── Quiz Generation Endpoints ──────────────────────────────── + + +@app.post("/api/quiz/generate", response_model=QuizResponse) +async def generate_quiz(http_request: Request, request: QuizGenerationRequest): + """ + Generate an AI-powered quiz via HF Serverless Inference. + Supports Bloom's Taxonomy integration, multiple question types, + and graph-based identification questions. + """ + try: + + normalized_exclude_topics = set(_canonicalize_topic_list(request.excludeTopics)) + # Filter out excluded topics (supports legacy topic labels via canonicalization) + effective_topics = _canonicalize_topic_list(request.topics) + effective_topics = [t for t in effective_topics if t not in normalized_exclude_topics] + import_grounding_enabled = ENABLE_IMPORT_GROUNDED_QUIZ + import_warnings: List[str] = [] + if not import_grounding_enabled and request.preferImportedTopics: + import_warnings.append( + "Import-grounded quiz generation is disabled by rollout flag; using provided and curriculum topics only." + ) + + imported_topics_payload: Dict[str, Any] = {"topics": [], "materials": [], "warnings": []} + imported_topic_titles: List[str] = [] + if import_grounding_enabled and (request.preferImportedTopics or not effective_topics): + imported_topics_payload = _load_persisted_course_material_topics( + http_request, + class_section_id=request.classSectionId, + material_id=request.materialId, + limit_materials=15, + ) + import_warnings.extend(imported_topics_payload.get("warnings", [])) + imported_topic_titles = [ + str(topic.get("title", "")).strip() + for topic in imported_topics_payload.get("topics", []) + if str(topic.get("title", "")).strip() and _canonicalize_topic_label(str(topic.get("title", "")).strip()) not in normalized_exclude_topics + ] + imported_topic_titles = _canonicalize_topic_list(imported_topic_titles) + + if imported_topic_titles: + if request.preferImportedTopics: + merged_topics = imported_topic_titles + [topic for topic in effective_topics if topic not in imported_topic_titles] + else: + merged_topics = effective_topics + [topic for topic in imported_topic_titles if topic not in effective_topics] + effective_topics = merged_topics + + if not effective_topics: + raise HTTPException( + status_code=400, + detail="All requested topics are in the exclude list. Please provide at least one topic to cover.", + ) + + # ── Enforce request limits ── + if len(effective_topics) > MAX_TOPICS_LIMIT: + logger.warning( + f"Trimming topics from {len(effective_topics)} to {MAX_TOPICS_LIMIT} (request limit)" + ) + effective_topics = effective_topics[:MAX_TOPICS_LIMIT] + + if request.numQuestions > MAX_QUESTIONS_LIMIT: + logger.warning( + f"Clamping numQuestions from {request.numQuestions} to {MAX_QUESTIONS_LIMIT} (request limit)" + ) + request.numQuestions = MAX_QUESTIONS_LIMIT + + # Pre-compute question distribution + distribution = _distribute_questions( + request.numQuestions, + request.difficultyDistribution, + request.bloomLevels, + request.questionTypes, + ) + + # Build per-question specifications + spec_lines: List[str] = [] + for i, d in enumerate(distribution): + topic = effective_topics[i % len(effective_topics)] + graph_note = "" + if request.includeGraphs and d["questionType"] == "identification": + graph_note = " (GRAPH-BASED: describe a graph and ask the student to identify a feature)" + spec_lines.append( + f"Q{i+1}: type={d['questionType']}, difficulty={d['difficulty']}, " + f"bloom={d['bloomLevel']}, topic={topic}{graph_note}" + ) + + graph_instruction = "" + if request.includeGraphs: + graph_instruction = ( + "\n\nGRAPH QUESTIONS: For any identification questions, make them graph-based. " + "Describe the graph verbally (e.g., 'Given a parabola with vertex at (2,3) opening upward...') " + "and ask the student to identify key features such as intercepts, axis of symmetry, " + "slopes, asymptotes, domain, range, or transformations. " + "Do NOT attempt to render an actual image." + ) + + prompt = f"""Generate exactly {request.numQuestions} math quiz questions for {request.gradeLevel} students. + +Topics to cover: {', '.join(effective_topics)} + +Question specifications: +{chr(10).join(spec_lines)} +{graph_instruction} + +Remember: +- Points: easy=1, medium=3, hard=5 +- Each question must have a step-by-step explanation +- Multiple choice must have exactly 4 options +- All math must be accurate +- Make problems relatable to students' real-world experiences""" + + messages = [ + {"role": "system", "content": QUIZ_GENERATION_SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] + + logger.info(f"Generating quiz: {request.numQuestions} questions, topics={effective_topics}") + + # Scale max_tokens based on requested questions and allow larger completions + # for higher-count quizzes while keeping provider-side limits reasonable. + max_tokens = min(8192, max(3072, request.numQuestions * 320)) + # Use longer HTTP timeout for larger quiz payloads. + http_timeout = min(300, max(120, request.numQuestions * 10)) + + parsed_questions: List[Dict[str, Any]] = [] + raw_content = "" # Will be set inside the loop + max_attempts = 3 # Extra retry helps with larger quiz sizes + + for attempt in range(max_attempts): + raw_content = await call_hf_chat_async( + messages, max_tokens=max_tokens, temperature=0.1, top_p=0.9, + timeout=http_timeout, + task_type="quiz_generation", + model=HF_QUIZ_MODEL_ID, + ) + logger.info(f"Raw quiz response length: {len(raw_content)} chars (attempt {attempt + 1})") + + parsed_questions = _parse_quiz_json(raw_content) + + if not parsed_questions: + logger.error(f"Failed to parse quiz JSON (attempt {attempt + 1}). Raw content:\n{raw_content[:500]}") + try: + repaired_questions = await _repair_quiz_json_with_llm( + raw_content, + num_questions=request.numQuestions, + timeout=http_timeout, + model_id=HF_QUIZ_JSON_REPAIR_MODEL_ID, + ) + except Exception as repair_exc: + logger.warning(f"Quiz JSON repair pass failed (attempt {attempt + 1}): {repair_exc}") + repaired_questions = [] + + if repaired_questions: + parsed_questions = repaired_questions + logger.info( + "Recovered quiz JSON via repair pass: %s questions (attempt %s)", + len(parsed_questions), + attempt + 1, + ) + + if not parsed_questions: + try: + strict_questions = await _regenerate_quiz_json_strict( + original_prompt=prompt, + num_questions=request.numQuestions, + timeout=http_timeout, + model_id=HF_QUIZ_JSON_REPAIR_MODEL_ID, + ) + except Exception as strict_exc: + logger.warning(f"Strict quiz regeneration failed (attempt {attempt + 1}): {strict_exc}") + strict_questions = [] + + if strict_questions: + parsed_questions = strict_questions + logger.info( + "Recovered quiz JSON via strict regeneration: %s questions (attempt %s)", + len(parsed_questions), + attempt + 1, + ) + + if not parsed_questions: + if attempt < max_attempts - 1: + logger.info("Retrying quiz generation...") + continue + raise HTTPException( + status_code=500, + detail="Failed to parse quiz questions from AI response. Please try again.", + ) + + # If we got at least 70% of requested questions, accept the result + if len(parsed_questions) >= request.numQuestions * 0.7: + break + + # Otherwise retry with a stronger nudge + if attempt < max_attempts - 1: + logger.warning( + f"LLM generated only {len(parsed_questions)}/{request.numQuestions} questions " + f"(attempt {attempt + 1}). Retrying with reinforced prompt..." + ) + # Add an assistant + user turn to push the LLM harder + messages = [ + {"role": "system", "content": QUIZ_GENERATION_SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": raw_content}, + { + "role": "user", + "content": ( + f"You only generated {len(parsed_questions)} questions but I need " + f"exactly {request.numQuestions}. Please generate ALL " + f"{request.numQuestions} questions in a single JSON array. " + f"Do not stop early." + ), + }, + ] + + # Warn if the LLM still generated fewer questions than requested + if len(parsed_questions) < request.numQuestions: + logger.warning( + f"LLM generated {len(parsed_questions)}/{request.numQuestions} questions " + f"after {max_attempts} attempts (raw length={len(raw_content)} chars)." + ) + + topic_provenance_map: Dict[str, Dict[str, Optional[str]]] = {} + for imported_topic in (imported_topics_payload.get("topics") or []): + title = str(imported_topic.get("title") or "").strip() + if not title: + continue + normalized_title = re.sub(r"[^a-z0-9\s]+", " ", title.lower()) + normalized_title = re.sub(r"\s+", " ", normalized_title).strip() + topic_provenance_map[normalized_title] = { + "topicId": str(imported_topic.get("topicId") or "") or None, + "title": title, + "materialId": str(imported_topic.get("materialId") or "") or None, + "sourceFile": str(imported_topic.get("sourceFile") or "") or None, + "sectionId": str(imported_topic.get("sectionId") or "") or None, + } + + validated = _validate_quiz_questions( + parsed_questions, + distribution, + topic_provenance_map=topic_provenance_map, + ) + total_points = sum(q.points for q in validated) + + # Build metadata + topic_counts: Dict[str, int] = {} + difficulty_counts: Dict[str, int] = {} + bloom_counts: Dict[str, int] = {} + for q in validated: + topic_counts[q.topic] = topic_counts.get(q.topic, 0) + 1 + difficulty_counts[q.difficulty] = difficulty_counts.get(q.difficulty, 0) + 1 + bloom_counts[q.bloomLevel] = bloom_counts.get(q.bloomLevel, 0) + 1 + + metadata: Dict[str, Any] = { + "topicsCovered": topic_counts, + "difficultyBreakdown": difficulty_counts, + "bloomTaxonomyDistribution": bloom_counts, + "questionTypeBreakdown": dict(Counter(q.questionType for q in validated)), + "gradeLevel": request.gradeLevel, + "totalQuestions": len(validated), + "includesGraphQuestions": request.includeGraphs, + "classSectionId": request.classSectionId, + "className": request.className, + "materialId": request.materialId, + "importGroundingEnabled": import_grounding_enabled, + "usedImportedTopics": bool(imported_topic_titles), + "importedMaterialsCount": len(imported_topics_payload.get("materials", [])), + "importedTopicCount": len(imported_topics_payload.get("topics", [])), + "importWarnings": import_warnings, + "topicProvenance": [ + { + "topicId": topic.get("topicId"), + "title": topic.get("title"), + "materialId": topic.get("materialId"), + "sourceFile": topic.get("sourceFile"), + "sectionId": topic.get("sectionId"), + } + for topic in (imported_topics_payload.get("topics") or []) + ][:20], + "supplementalPurpose": ( + "This quiz is designed to supplement classroom instruction, " + "not replace teacher-led learning." + ), + "bloomTaxonomyRationale": ( + "Ensures questions assess different cognitive levels from basic recall " + "to complex analysis, providing comprehensive skill evaluation." + ), + "recommendedTeacherActions": [ + "Review questions before assigning to students", + "Use difficulty breakdown to identify areas needing re-teaching", + "Focus on topics where students score below 60%", + "Use Bloom analysis to ensure higher-order thinking is practiced", + ], + } + + if request.includeGraphs: + metadata["graphQuestionNote"] = ( + "Graph questions use identification format as graphing is " + "challenging for students. Graphs are described in text." + ) + + logger.info(f"Quiz generated: {len(validated)} questions, {total_points} total points") + + return QuizResponse( + questions=validated, + totalPoints=total_points, + metadata=metadata, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Quiz generation error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Quiz generation error: {str(e)}") + + +@app.post("/api/quiz/preview", response_model=QuizResponse) +async def preview_quiz(http_request: Request, request: QuizGenerationRequest): + """ + Generate a 3-question preview quiz for teachers to verify AI question + quality before assigning a full quiz to students. + """ + # Override to produce only 3 questions + request.numQuestions = 3 + return await generate_quiz(http_request, request) + + +@app.post("/api/quiz/generate-async", response_model=AsyncTaskSubmitResponse) +async def generate_quiz_async(http_request: Request, request: QuizGenerationRequest): + if not ENABLE_ASYNC_GENERATION: + raise HTTPException(status_code=404, detail="Async generation is disabled") + + enforce_rate_limit(http_request, "generate_quiz_async", 20, 60) + user = get_current_user(http_request) + if user.role not in TEACHER_OR_ADMIN: + raise HTTPException(status_code=403, detail="Forbidden for this role") + + task_id = _create_async_task( + owner_uid=user.uid, + task_kind="quiz_generation", + payload=request.model_dump(), + ) + + async def _quiz_runner() -> QuizResponse: + _update_async_task( + task_id, + progressPercent=30.0, + progressStage="preparing", + progressMessage="Preparing quiz prompt and topic distribution.", + ) + await asyncio.sleep(0) + _update_async_task( + task_id, + progressPercent=70.0, + progressStage="generating", + progressMessage="Generating quiz questions with the AI model.", + ) + response = await generate_quiz(http_request, request) + _update_async_task( + task_id, + progressPercent=92.0, + progressStage="finalizing", + progressMessage="Validating generated quiz and finalizing output.", + ) + return response + + _start_async_task_in_thread(task_id, _quiz_runner) + + with _async_tasks_lock: + task = dict(_async_tasks.get(task_id, {})) + + return AsyncTaskSubmitResponse( + success=True, + taskId=task_id, + status=str(task.get("status") or "queued"), + taskKind="quiz_generation", + createdAt=str(task.get("createdAt") or _utc_now_iso()), + ) + + +@app.get("/api/tasks/{task_id}", response_model=AsyncTaskStatusResponse) +async def get_async_task_status(http_request: Request, task_id: str): + user = get_current_user(http_request) + + with _async_tasks_lock: + _prune_async_tasks() + task = _async_tasks.get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + owner_uid = str(task.get("ownerUid") or "") + if user.role != "admin" and owner_uid != user.uid: + raise HTTPException(status_code=403, detail="Forbidden for this task") + task_data = dict(task) + + return AsyncTaskStatusResponse( + success=True, + taskId=str(task_data.get("taskId") or task_id), + taskKind=str(task_data.get("taskKind") or "unknown"), + status=str(task_data.get("status") or "queued"), + createdAt=str(task_data.get("createdAt") or _utc_now_iso()), + startedAt=cast(Optional[str], task_data.get("startedAt")), + completedAt=cast(Optional[str], task_data.get("completedAt")), + progressPercent=float(task_data.get("progressPercent") or 0.0), + progressStage=str(task_data.get("progressStage") or "queued"), + progressMessage=cast(Optional[str], task_data.get("progressMessage")), + result=cast(Optional[Dict[str, Any]], task_data.get("result")), + error=task_data.get("error"), + ) + + +@app.get("/api/tasks", response_model=AsyncTaskListResponse) +async def list_async_tasks( + http_request: Request, + limit: int = Query(default=50, ge=1, le=200), + status: Optional[str] = Query(default=None), + include_results: bool = Query(default=False), +): + user = get_current_user(http_request) + status_filter = (status or "").strip().lower() + + with _async_tasks_lock: + _prune_async_tasks() + raw_tasks = list(_async_tasks.values()) + + filtered: List[Dict[str, Any]] = [] + for task in raw_tasks: + owner_uid = str(task.get("ownerUid") or "") + if user.role != "admin" and owner_uid != user.uid: + continue + task_status = str(task.get("status") or "queued").strip().lower() + if status_filter and task_status != status_filter: + continue + filtered.append(dict(task)) + + filtered.sort(key=lambda item: str(item.get("createdAt") or ""), reverse=True) + selected = filtered[:limit] + + task_items: List[AsyncTaskStatusResponse] = [] + for item in selected: + task_items.append( + AsyncTaskStatusResponse( + success=True, + taskId=str(item.get("taskId") or ""), + taskKind=str(item.get("taskKind") or "unknown"), + status=str(item.get("status") or "queued"), + createdAt=str(item.get("createdAt") or _utc_now_iso()), + startedAt=cast(Optional[str], item.get("startedAt")), + completedAt=cast(Optional[str], item.get("completedAt")), + progressPercent=float(item.get("progressPercent") or 0.0), + progressStage=str(item.get("progressStage") or "queued"), + progressMessage=cast(Optional[str], item.get("progressMessage")), + result=cast(Optional[Dict[str, Any]], item.get("result") if include_results else None), + error=item.get("error"), + ) + ) + + return AsyncTaskListResponse(success=True, count=len(task_items), tasks=task_items) + + +@app.post("/api/tasks/{task_id}/cancel", response_model=AsyncTaskCancelResponse) +async def cancel_async_task(http_request: Request, task_id: str): + user = get_current_user(http_request) + + with _async_tasks_lock: + _prune_async_tasks() + task = _async_tasks.get(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + + owner_uid = str(task.get("ownerUid") or "") + if user.role != "admin" and owner_uid != user.uid: + raise HTTPException(status_code=403, detail="Forbidden for this task") + + current_status = str(task.get("status") or "queued").strip().lower() + if current_status in {"completed", "failed", "cancelled"}: + return AsyncTaskCancelResponse( + success=False, + taskId=task_id, + status=current_status, + message="Task is already in a terminal state.", + ) + + task["cancelRequested"] = True + if current_status == "queued": + task["status"] = "cancelled" + task["completedAt"] = _utc_now_iso() + task["progressPercent"] = 100.0 + task["progressStage"] = "cancelled" + task["progressMessage"] = "Task was cancelled before execution." + task["error"] = {"message": "Task cancelled before execution."} + else: + task["status"] = "cancelling" + task["progressStage"] = "cancelling" + task["progressMessage"] = "Cancellation requested. Waiting for task to stop." + + updated_status = str(task.get("status") or "queued") + + return AsyncTaskCancelResponse( + success=True, + taskId=task_id, + status=updated_status, + message="Cancellation request accepted.", + ) + + +@app.get("/api/ops/inference-metrics", response_model=InferenceMetricsResponse) +async def get_inference_metrics(http_request: Request): + user = get_current_user(http_request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + client = get_inference_client() + metrics_snapshot = client.snapshot_metrics() + metrics_snapshot["pro_enabled"] = bool(getattr(client, "pro_enabled", False)) + return InferenceMetricsResponse(success=True, metrics=metrics_snapshot) + + +@app.get("/api/hf/monitoring", response_model=HFMonitoringDataResponse) +async def get_hf_monitoring(http_request: Request): + """ + Aggregates DeepSeek AI status, model config, and latency probe. + Returns distilled data safe for frontend consumption. + + Requires admin authentication. + """ + user = get_current_user(http_request) + if user.role != "admin": + raise HTTPException(status_code=403, detail="Forbidden for this role") + + _ensure_deepseek_available() + + try: + generation_model_id = get_model_for_task("chat") + except Exception: + generation_model_id = CHAT_MODEL + + embedding_model_id = os.getenv("EMBEDDING_MODEL", "BAAI/bge-small-en-v1.5") + + runtime_config = get_current_runtime_config() + + task_resolved: dict[str, str] = {} + for task in [ + "chat", "verify_solution", "lesson_generation", "quiz_generation", + "learning_path", "daily_insight", "risk_classification", "risk_narrative", + "rag_lesson", "rag_problem", "rag_analysis_context", + ]: + try: + task_resolved[task] = get_model_for_task(task) + except Exception: + task_resolved[task] = generation_model_id + + result: Dict[str, Any] = { + "modelId": generation_model_id, + "modelStatus": "Operational", + "avgResponseTimeMs": 0, + "embeddingModelId": embedding_model_id, + "embeddingModelStatus": "Operational", + "inferenceBalance": 0.0, + "totalPeriodCost": 0.0, + "hubApiCallsUsed": 0, + "hubApiCallsLimit": 2500, + "zeroGpuMinutesUsed": 0, + "zeroGpuMinutesLimit": 25, + "publicStorageUsedTB": 0.0, + "publicStorageLimitTB": 11.2, + "lastChecked": datetime.now(timezone.utc).isoformat(), + "periodStart": "", + "periodEnd": "", + "activeProfile": runtime_config.get("profile") or os.getenv("MODEL_PROFILE", "dev"), + "runtimeOverridesActive": len(runtime_config.get("overrides", {})) > 0, + "resolvedModels": task_resolved, + "provider": "deepseek", + "apiBaseUrl": os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com"), + } + + try: + client = get_deepseek_client() + latency_start = time.time() + probe_response = client.chat.completions.create( + model=str(CHAT_MODEL), + messages=[{"role": "user", "content": "Hi"}], + max_tokens=1, + temperature=0.0, + ) + latency_ms = int((time.time() - latency_start) * 1000) + result["avgResponseTimeMs"] = latency_ms + result["modelStatus"] = "Operational" + except Exception as e: + logger.warning(f"DeepSeek latency probe failed: {e}") + result["modelStatus"] = "Degraded" + + return HFMonitoringDataResponse(success=True, data=result) + + +@app.get("/api/quiz/topics") +async def get_quiz_topics(response: Response, gradeLevel: Optional[str] = None): + """ + Return structured list of SHS math topics organised by grade level. + Only Grade 11 and Grade 12 are supported. + If gradeLevel is provided, return topics for that grade only. + """ + response.headers["Cache-Control"] = "public, max-age=300, stale-while-revalidate=900" + + if gradeLevel: + key = _resolve_grade_level_key(gradeLevel) + if key: + return {"gradeLevel": key, "topics": MATH_TOPICS_BY_GRADE[key]} + raise HTTPException( + status_code=404, + detail=f"Grade level '{gradeLevel}' not found. Available: {list(MATH_TOPICS_BY_GRADE.keys())}", + ) + + # Return all SHS topics organized by grade + return { + "gradeLevels": list(MATH_TOPICS_BY_GRADE.keys()), + "allTopics": MATH_TOPICS_BY_GRADE, + } + + +# ─── Student Competency Assessment ──────────────────────────── + + +@app.post("/api/quiz/student-competency", response_model=StudentCompetencyResponse) +async def student_competency(request: StudentCompetencyRequest): + """ + Assess a student's competency per topic based on their quiz history. + Returns efficiency scores, competency levels, and recommendations. + """ + try: + history = request.quizHistory or [] + + if not history: + # No history — return empty competency with recommendation to start + return StudentCompetencyResponse( + studentId=request.studentId, + competencies=[], + recommendedTopics=["Start with foundational topics to build a learning profile"], + excludeTopics=[], + ) + + # Aggregate scores per topic + topic_data: Dict[str, List[Dict[str, Any]]] = {} + for entry in history: + topic = _canonicalize_topic_label(str(entry.get("topic", "Unknown"))) + if topic not in topic_data: + topic_data[topic] = [] + topic_data[topic].append(entry) + + # Compute competency per topic + competencies: List[TopicCompetency] = [] + recommended: List[str] = [] + exclude: List[str] = [] + + for topic, entries in topic_data.items(): + scores = [e.get("score", 0) / max(e.get("total", 1), 1) * 100 for e in entries] + avg_score = sum(scores) / len(scores) if scores else 0 + + # Factor in time efficiency (faster with correct answers = more efficient) + time_factors = [] + for e in entries: + if e.get("timeTaken") and e.get("total"): + time_per_q = e["timeTaken"] / e["total"] + # Normalise: < 30s per question = efficient, > 120s = slow + efficiency = max(0, min(100, 100 - (time_per_q - 30) * (100 / 90))) + time_factors.append(efficiency) + + time_efficiency = sum(time_factors) / len(time_factors) if time_factors else 50 + efficiency_score = round(avg_score * 0.7 + time_efficiency * 0.3, 1) + + if efficiency_score >= 85: + level = "advanced" + perspective = f"Student demonstrates strong mastery of {topic}. Consistently scores well with efficient problem-solving." + exclude.append(topic) + elif efficiency_score >= 65: + level = "proficient" + perspective = f"Student has solid understanding of {topic} but may benefit from challenging practice problems." + elif efficiency_score >= 40: + level = "developing" + perspective = f"Student shows foundational knowledge of {topic} but needs more practice to build fluency." + recommended.append(topic) + else: + level = "beginner" + perspective = f"Student is still building understanding of {topic}. Recommend focused review and guided practice." + recommended.insert(0, topic) # High-priority + + competencies.append(TopicCompetency( + topic=topic, + efficiencyScore=efficiency_score, + competencyLevel=level, + perspective=perspective, + )) + + # If the AI is available, enhance perspectives + if competencies: + try: + summary = ", ".join( + f"{c.topic}: {c.competencyLevel} ({c.efficiencyScore}%)" + for c in competencies + ) + ai_prompt = f"""Based on this student competency profile, provide a brief (2-3 sentence) overall assessment: +{summary} + +Focus on actionable recommendations. Be encouraging yet honest.""" + + overall_perspective = await call_hf_chat_async( + messages=[ + {"role": "system", "content": "You are an educational assessment expert. Be concise and supportive."}, + {"role": "user", "content": ai_prompt}, + ], + max_tokens=200, + temperature=0.3, + ) + if overall_perspective: + # Add to recommended as a note + recommended.append(f"AI Insight: {overall_perspective.strip()}") + except Exception as e: + logger.warning(f"AI competency enhancement failed: {e}") + + competencies.sort(key=lambda c: c.efficiencyScore) + + return StudentCompetencyResponse( + studentId=request.studentId, + competencies=competencies, + recommendedTopics=recommended, + excludeTopics=exclude, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Student competency error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Competency assessment error: {str(e)}") + + +# ─── Calculator / Symbolic Math ─────────────────────────────── + +# Allowed names for safe expression evaluation via SymPy +_SAFE_SYMPY_NAMES: Optional[Dict[str, Any]] = None + + +def _get_sympy_safe_dict() -> Dict[str, Any]: + """Lazily build allowlist of SymPy names for safe eval.""" + global _SAFE_SYMPY_NAMES + if _SAFE_SYMPY_NAMES is not None: + return _SAFE_SYMPY_NAMES + + import sympy # type: ignore[import-untyped] + + _SAFE_SYMPY_NAMES = { + # Symbols + "x": sympy.Symbol("x"), + "y": sympy.Symbol("y"), + "z": sympy.Symbol("z"), + "t": sympy.Symbol("t"), + "n": sympy.Symbol("n"), + # Constants + "pi": sympy.pi, + "e": sympy.E, + "E": sympy.E, + "I": sympy.I, + "oo": sympy.oo, + "inf": sympy.oo, + # Functions + "sin": sympy.sin, + "cos": sympy.cos, + "tan": sympy.tan, + "asin": sympy.asin, + "acos": sympy.acos, + "atan": sympy.atan, + "sinh": sympy.sinh, + "cosh": sympy.cosh, + "tanh": sympy.tanh, + "log": sympy.log, + "ln": sympy.log, + "exp": sympy.exp, + "sqrt": sympy.sqrt, + "Abs": sympy.Abs, + "abs": sympy.Abs, + "factorial": sympy.factorial, + "binomial": sympy.binomial, + "ceiling": sympy.ceiling, + "floor": sympy.floor, + # Operations + "diff": sympy.diff, + "integrate": sympy.integrate, + "limit": sympy.limit, + "solve": sympy.solve, + "simplify": sympy.simplify, + "expand": sympy.expand, + "factor": sympy.factor, + "Rational": sympy.Rational, + "Matrix": sympy.Matrix, + "Sum": sympy.Sum, + "Product": sympy.Product, + "Derivative": sympy.Derivative, + "Integral": sympy.Integral, + "Limit": sympy.Limit, + } + return _SAFE_SYMPY_NAMES + + +_DANGEROUS_PATTERNS = re.compile( + r"(__\w+__|import\s|exec\s*\(|eval\s*\(|open\s*\(|os\.|sys\.|subprocess|shutil|__builtins__|globals|locals|compile|getattr|setattr|delattr)", + re.IGNORECASE, +) + + +@app.post("/api/calculator/evaluate", response_model=CalculatorResponse) +async def calculator_evaluate(request: CalculatorRequest): + """ + Evaluate a mathematical expression symbolically using SymPy. + Supports arithmetic, algebra, trigonometry, and calculus. + """ + try: + import sympy # type: ignore[import-untyped] + + expr_str = request.expression.strip() + + # Safety validation + if _DANGEROUS_PATTERNS.search(expr_str): + raise HTTPException( + status_code=400, + detail="Expression contains disallowed patterns. Only mathematical expressions are permitted.", + ) + if len(expr_str) > 500: + raise HTTPException(status_code=400, detail="Expression too long (max 500 characters).") + + safe_dict = _get_sympy_safe_dict() + steps: List[str] = [f"Input expression: {expr_str}"] + + # Parse expression + try: + parsed = sympy.sympify(expr_str, locals=safe_dict) + steps.append(f"Parsed as: {parsed}") + except Exception as parse_err: + raise HTTPException( + status_code=400, + detail=f"Could not parse expression: {str(parse_err)}", + ) + + # Simplify + simplified = sympy.simplify(parsed) + if simplified != parsed: + steps.append(f"Simplified: {simplified}") + + # Try numeric evaluation + try: + numeric = float(simplified.evalf()) + if numeric == int(numeric): + result_str = str(int(numeric)) + else: + result_str = str(round(numeric, 10)) + steps.append(f"Numerical result: {result_str}") + except Exception: + result_str = str(simplified) + steps.append(f"Symbolic result: {result_str}") + + # LaTeX representation + try: + latex_str = sympy.latex(simplified) + except Exception: + latex_str = None + + return CalculatorResponse( + expression=expr_str, + result=result_str, + steps=steps, + simplified=str(simplified) if simplified != parsed else None, + latex=latex_str, + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Calculator error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Calculator error: {str(e)}") + + +# ─── ML-Powered Student Analytics Endpoints ────────────────── + + +@app.post("/api/student/competency-analysis", response_model=CompetencyAnalysisResponse) +async def student_competency_analysis(request: CompetencyAnalysisRequest): + """ + Analyse student competency per topic using IRT (Item Response Theory). + Calculates efficiency scores, mastery percentages, learning velocity, + and theta (ability) estimates. + """ + try: + logger.info(f"Competency analysis requested for student {request.studentId}") + + # Fetch quiz history from Firestore + quiz_history = await fetch_student_quiz_history(request.studentId) + + result = await compute_competency_analysis( + student_id=request.studentId, + quiz_history=quiz_history, + topic_filter=request.topicId, + ) + + # Store results if successful + if result.status == "success": + await store_competency_analysis( + request.studentId, + { + "analyses": [a.dict() for a in result.analyses], + "overallCompetency": result.overallCompetency, + "thetaEstimate": result.thetaEstimate, + }, + ) + + return result + + except Exception as e: + logger.error(f"Competency analysis error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Competency analysis error: {str(e)}") + + +@app.post("/api/risk/train-model", response_model=RiskTrainResponse) +async def train_risk_classification_model(request: RiskTrainRequest): + """ + Train a supervised ML model (XGBoost/Random Forest) for student risk prediction. + Admin-only endpoint. Collects historical data from Firestore, trains the model, + and saves it to disk. + """ + try: + logger.info(f"Risk model training requested (forceRetrain={request.forceRetrain})") + result = await train_risk_model(force_retrain=request.forceRetrain) + return result + except Exception as e: + logger.error(f"Risk model training error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Model training error: {str(e)}") + + +@app.post("/api/predict-risk/enhanced", response_model=EnhancedRiskPrediction) +async def predict_risk_ml(data: EnhancedRiskRequest): + """ + Enhanced student risk prediction using trained ML model with SHAP explanations. + Falls back to rule-based heuristics if no trained model is available. + Returns risk probabilities for all classes and top contributing factors. + """ + try: + logger.info(f"Enhanced risk prediction for student {data.studentId}") + result = await predict_risk_enhanced(data) + if ENABLE_LLM_RISK_RECOMMENDATIONS: + try: + llm_recommendations = await _generate_risk_recommendations_llm(data, result) + if llm_recommendations: + result.recommendations = llm_recommendations + result.modelUsed = "ml_model+llm_narrative" + except Exception as llm_exc: + logger.warning(f"Risk recommendation generation via LLM failed: {llm_exc}") + return result + except Exception as e: + logger.error(f"Enhanced risk prediction error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Risk prediction error: {str(e)}") + + +@app.post("/api/quiz/calibrate-difficulty", response_model=CalibrateDifficultyResponse) +async def calibrate_quiz_difficulty(request: CalibrateDifficultyRequest): + """ + Calculate IRT difficulty parameters for a question based on student responses. + Uses 3-Parameter Logistic model to estimate difficulty (b), discrimination (a), + and guessing (c) parameters. + """ + try: + logger.info(f"Calibrating difficulty for question {request.questionId}") + result = await calibrate_question_difficulty(request) + return result + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Difficulty calibration error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Calibration error: {str(e)}") + + +@app.post("/api/quiz/adaptive-select") +async def adaptive_quiz_selection(request: AdaptiveQuizSelectRequest): + """ + Select questions adaptively based on student ability level using IRT. + Adjusts difficulty distribution to target ~70-75% success rate. + Uses student competency data to personalize quiz difficulty. + """ + try: + logger.info(f"Adaptive quiz selection for student {request.studentId}, topic {request.topicId}") + result = await select_adaptive_quiz(request) + return result + except Exception as e: + logger.error(f"Adaptive quiz selection error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Adaptive selection error: {str(e)}") + + +@app.post("/api/learning/recommend-topics", response_model=TopicRecommendationResponse) +async def recommend_learning_topics(request: TopicRecommendationRequest): + """ + Recommend topics for a student based on competency gaps, prerequisites, + recency of practice, and peer performance patterns. + Returns ranked list with reasoning and estimated time to mastery. + """ + try: + logger.info(f"Topic recommendation for student {request.studentId}") + result = await recommend_topics(request) + return result + except Exception as e: + logger.error(f"Topic recommendation error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Recommendation error: {str(e)}") + + +@app.get("/api/analytics/student-summary", response_model=StudentSummaryResponse) +async def student_analytics_summary(request: Request, studentId: str = Query(..., description="Firebase user ID")): + """ + Aggregate all ML-powered metrics for a student: + competency distribution, risk assessment, recommendations, + learning velocity trends, efficiency scores, predicted performance, + and engagement pattern analysis. + """ + try: + require_student_self_or_staff(request, studentId) + logger.info(f"Student summary requested for {studentId}") + result = await get_student_summary(studentId) + return result + except Exception as e: + logger.error(f"Student summary error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Analytics error: {str(e)}") + + +@app.post("/api/analytics/class-insights", response_model=ClassInsightsResponse) +async def class_analytics_insights(request: ClassInsightsRequest): + """ + Aggregate class-wide ML analytics for teacher dashboards. + Includes risk distribution, common weak topics, learning velocity, + engagement patterns, and intervention recommendations. + """ + try: + logger.info(f"Class insights requested by teacher {request.teacherId}") + result = await get_class_insights(request) + return result + except Exception as e: + logger.error(f"Class insights error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Class insights error: {str(e)}") + + +@app.post("/api/analytics/refresh-cache", response_model=RefreshCacheResponse) +async def refresh_analytics_cache(): + """ + Force clear and refresh all ML analytics caches. + Use when student data has been updated and fresh analysis is needed. + """ + try: + result = refresh_all_caches() + logger.info("Analytics caches refreshed") + return result + except Exception as e: + logger.error(f"Cache refresh error: {e}") + raise HTTPException(status_code=500, detail=f"Cache refresh error: {str(e)}") + + +@app.post("/api/dev/generate-mock-data") +async def generate_mock_data(request: MockDataRequest): + """ + Generate realistic mock student data for testing ML features. + Development/testing endpoint only. + Generates students with varied archetypes: perfect, struggling, + inconsistent, improving, declining, and average performers. + """ + try: + logger.info(f"Generating mock data: {request.numStudents} students, {request.numQuizzes} quizzes") + data = generate_mock_student_data( + num_students=request.numStudents, + num_quizzes=request.numQuizzes, + seed=request.seed, + ) + return data + except Exception as e: + logger.error(f"Mock data generation error: {e}") + raise HTTPException(status_code=500, detail=f"Mock data error: {str(e)}") + + +TESTING_RESET_BATCH_SIZE = 400 + + +def _testing_reset_timestamp_value() -> Any: + if FIRESTORE_SERVER_TIMESTAMP is not None: + return FIRESTORE_SERVER_TIMESTAMP + return datetime.now(timezone.utc) + + +def _testing_reset_delete_by_field( + client: Any, + collection_name: str, + field_name: str, + field_value: str, +) -> int: + normalized_value = str(field_value or "").strip() + if not normalized_value: + return 0 + + docs = list( + client.collection(collection_name).where(field_name, "==", normalized_value).stream() + ) + if not docs: + return 0 + + deleted_docs = 0 + batch = client.batch() + pending_writes = 0 + + for doc_snapshot in docs: + batch.delete(doc_snapshot.reference) + deleted_docs += 1 + pending_writes += 1 + + if pending_writes >= TESTING_RESET_BATCH_SIZE: + batch.commit() + batch = client.batch() + pending_writes = 0 + + if pending_writes > 0: + batch.commit() + + return deleted_docs + + +def _testing_reset_try_delete_by_field( + client: Any, + collection_name: str, + field_name: str, + field_value: str, +) -> int: + try: + return _testing_reset_delete_by_field( + client, + collection_name, + field_name, + field_value, + ) + except Exception as err: + logger.warning( + "Testing reset skipped delete for %s (%s == %s): %s", + collection_name, + field_name, + field_value, + err, + ) + return 0 + + +def _testing_reset_try_delete_doc(doc_ref: Any, label: str) -> int: + try: + doc_ref.delete() + return 1 + except Exception as err: + logger.warning("Testing reset skipped delete for %s: %s", label, err) + return 0 + + +def _testing_reset_try_set_doc(doc_ref: Any, payload: Dict[str, Any], label: str, merge: bool = False) -> int: + try: + if merge: + doc_ref.set(payload, merge=True) + else: + doc_ref.set(payload) + return 1 + except Exception as err: + logger.warning("Testing reset skipped update for %s: %s", label, err) + return 0 + + +def _reset_student_testing_data_admin( + client: Any, + uid: str, + lrn: Optional[str], +) -> Tuple[int, int]: + deleted_docs = 0 + updated_docs = 0 + effective_lrn = (lrn or uid or "").strip() or uid + timestamp_value = _testing_reset_timestamp_value() + + updated_docs += _testing_reset_try_set_doc( + client.collection("progress").document(uid), + { + "userId": uid, + "subjects": {}, + "lessons": {}, + "quizAttempts": [], + "totalLessonsCompleted": 0, + "totalQuizzesCompleted": 0, + "averageScore": 0, + "updatedAt": timestamp_value, + }, + f"progress/{uid}", + merge=False, + ) + + updated_docs += _testing_reset_try_set_doc( + client.collection("users").document(uid), + { + "level": 1, + "currentXP": 0, + "totalXP": 0, + "streak": 0, + "streakHistory": [], + "atRiskSubjects": [], + "hasTakenDiagnostic": False, + "iarAssessmentState": "not_started", + "learningPathState": "unlocked", + "remediationState": "not_required", + "subjectBadges": {}, + "riskClassifications": {}, + "overallRisk": "Low", + "updatedAt": timestamp_value, + }, + f"users/{uid}", + merge=True, + ) + + deleted_docs += _testing_reset_try_delete_by_field(client, "notifications", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatSessions", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatMessages", "userId", uid) + + if effective_lrn != uid: + deleted_docs += _testing_reset_try_delete_by_field(client, "notifications", "userId", effective_lrn) + + updated_docs += _testing_reset_try_set_doc( + client.collection("achievements").document(uid), + { + "userId": uid, + "achievements": [], + "totalAchievements": 0, + "updatedAt": timestamp_value, + }, + f"achievements/{uid}", + merge=True, + ) + + return deleted_docs, updated_docs + + +def _reset_teacher_testing_data_admin(client: Any, uid: str) -> Tuple[int, int]: + deleted_docs = 0 + updated_docs = 0 + timestamp_value = _testing_reset_timestamp_value() + + classroom_docs = list(client.collection("classrooms").where("teacherId", "==", uid).stream()) + classroom_ids = [doc_snapshot.id for doc_snapshot in classroom_docs] + + deleted_docs += _testing_reset_try_delete_by_field(client, "notifications", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatSessions", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatMessages", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "announcements", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "classSectionOwnership", "ownerTeacherId", uid) + + deleted_docs += _testing_reset_try_delete_by_field(client, "managedStudents", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "classrooms", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "normalizedClassRecords", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "classRecordImports", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "courseMaterials", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "riskRefreshEvents", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "riskRefreshJobs", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "importGroundedFeedbackEvents", "teacherId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "accessAuditLogs", "actorUid", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "accessAuditLogs", "teacherId", uid) + + for classroom_id in classroom_ids: + deleted_docs += _testing_reset_try_delete_by_field(client, "managedStudents", "classroomId", classroom_id) + deleted_docs += _testing_reset_try_delete_by_field(client, "activities", "classroomId", classroom_id) + deleted_docs += _testing_reset_try_delete_by_field(client, "announcements", "classroomId", classroom_id) + + deleted_docs += _testing_reset_try_delete_doc( + client.collection("riskRefreshStats").document(uid), + f"riskRefreshStats/{uid}", + ) + + updated_docs += _testing_reset_try_set_doc( + client.collection("users").document(uid), + { + "testingResetAt": timestamp_value, + "updatedAt": timestamp_value, + }, + f"users/{uid}", + merge=True, + ) + + return deleted_docs, updated_docs + + +def _reset_admin_testing_data_admin(client: Any, uid: str) -> Tuple[int, int]: + deleted_docs = 0 + updated_docs = 0 + timestamp_value = _testing_reset_timestamp_value() + + deleted_docs += _testing_reset_try_delete_by_field(client, "notifications", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatSessions", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "chatMessages", "userId", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "curriculumContent", "updatedBy", uid) + deleted_docs += _testing_reset_try_delete_by_field(client, "curriculumContent", "deletedBy", uid) + + updated_docs += _testing_reset_try_set_doc( + client.collection("users").document(uid), + { + "testingResetAt": timestamp_value, + "updatedAt": timestamp_value, + }, + f"users/{uid}", + merge=True, + ) + + return deleted_docs, updated_docs + + +@app.post("/api/testing/reset-data", response_model=TestingResetResponse) +async def reset_testing_data(request: Request, payload: TestingResetRequest): + user = get_current_user(request) + + requested_role = (payload.role or "").strip().lower() + if requested_role and requested_role not in VALID_ROLES: + raise HTTPException(status_code=400, detail="Invalid role for testing reset.") + if requested_role and requested_role != user.role: + raise HTTPException(status_code=403, detail="Reset role does not match authenticated user role.") + + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + try: + client = firebase_firestore.client() + if user.role == "student": + deleted_docs, updated_docs = _reset_student_testing_data_admin(client, user.uid, payload.lrn) + elif user.role == "teacher": + deleted_docs, updated_docs = _reset_teacher_testing_data_admin(client, user.uid) + else: + deleted_docs, updated_docs = _reset_admin_testing_data_admin(client, user.uid) + except Exception as firestore_err: + if _is_adc_missing_error(cast(Exception, firestore_err)): + raise HTTPException( + status_code=503, + detail=( + "Firestore ADC is not configured. Set FIREBASE_SERVICE_ACCOUNT_JSON " + "or FIREBASE_SERVICE_ACCOUNT_FILE, or set GOOGLE_APPLICATION_CREDENTIALS." + ), + ) + logger.error( + "Testing reset failed for uid=%s role=%s: %s", + user.uid, + user.role, + firestore_err, + ) + raise HTTPException(status_code=500, detail="Failed to reset testing data.") + + summary = ( + f"{user.role} reset complete: {deleted_docs} records deleted, " + f"{updated_docs} records reset." + ) + return TestingResetResponse( + role=user.role, + deletedDocs=deleted_docs, + updatedDocs=updated_docs, + summary=summary, + ) + + +@app.get("/api/analytics/config") +async def get_analytics_config(): + """Return current ML analytics configuration parameters.""" + return { + "riskModelPath": RISK_MODEL_PATH, + "riskModelExists": os.path.exists(RISK_MODEL_PATH), + "competencyThresholds": COMPETENCY_THRESHOLDS, + "minQuizAttemptsForCompetency": MIN_QUIZ_ATTEMPTS_FOR_COMPETENCY, + "cacheTTLSeconds": 3600, + "topicPrerequisites": fetch_topic_dependencies(), + } + + +@app.get("/api/analytics/imported-class-overview") +async def get_imported_class_overview( + request: Request, + classSectionId: Optional[str] = Query(default=None), + limit: int = Query(default=3000, ge=1, le=5000), +): + """Return teacher-facing class/student aggregates derived from imported normalized records.""" + try: + user = get_current_user(request) + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + normalized_class_section_id = (classSectionId or "").strip().lower() or None + client = firebase_firestore.client() + try: + query = client.collection("normalizedClassRecords").where("teacherId", "==", user.uid) + docs = list(query.limit(limit).stream()) + except Exception as firestore_err: + if _is_adc_missing_error(cast(Exception, firestore_err)): + raise HTTPException( + status_code=503, + detail=( + "Firestore ADC is not configured. Set FIREBASE_SERVICE_ACCOUNT_JSON " + "or FIREBASE_SERVICE_ACCOUNT_FILE, or set GOOGLE_APPLICATION_CREDENTIALS." + ), + ) + raise + + class_agg: Dict[str, Dict[str, Any]] = {} + student_agg: Dict[str, Dict[str, Any]] = {} + + for doc in docs: + row = doc.to_dict() or {} + resolved_class_section_id, resolved_class_name, grade, section = _resolve_import_class_context( + class_section_id=str(row.get("classSectionId") or "").strip() or None, + class_name=str(row.get("className") or "").strip() or None, + ) + row_class_metadata = _build_class_metadata( + class_section_id=resolved_class_section_id, + class_name=resolved_class_name, + grade=grade, + section=section, + owner_teacher_id=user.uid, + owner_teacher_name=user.email, + adviser_teacher_id=user.uid, + adviser_teacher_name=user.email, + manager_id=user.uid, + manager_name=user.email, + ) + + if normalized_class_section_id and resolved_class_section_id != normalized_class_section_id: + continue + + student_identity = ( + str(row.get("studentId") or "").strip() + or str(row.get("lrn") or "").strip() + or str(row.get("email") or "").strip().lower() + or re.sub(r"\s+", "_", str(row.get("name") or "").strip().lower()) + ) + if not student_identity: + continue + + student_key = f"{resolved_class_section_id}|{student_identity}" + avg_quiz = float(row.get("avgQuizScore") or 0.0) + attendance = float(row.get("attendance") or 0.0) + engagement = float(row.get("engagementScore") or 0.0) + completion = float(row.get("assignmentCompletion") or 0.0) + weakest_topic = _pick_weakest_topic(row.get("unknownFields") or {}) + has_topic_signal = weakest_topic not in {"", "Foundational Skills"} + inference = _infer_student_state( + avg_quiz=avg_quiz, + attendance=attendance, + engagement=engagement, + defaulted_metrics=set(), + has_topic_signal=has_topic_signal, + ) + risk_level = str(inference["riskLevel"]) + + existing_student = student_agg.get(student_key) + if not existing_student: + student_agg[student_key] = { + "id": hashlib.sha1(f"{user.uid}|{student_key}".encode("utf-8")).hexdigest()[:36], + "lrn": str(row.get("lrn") or "").strip() or None, + "name": str(row.get("name") or "Imported Student").strip() or "Imported Student", + "email": str(row.get("email") or "").strip(), + "classSectionId": resolved_class_section_id, + "className": resolved_class_name, + "grade": grade, + "gradeLevel": row_class_metadata.get("gradeLevel"), + "classification": row_class_metadata.get("classification"), + "strand": row_class_metadata.get("strand"), + "section": section, + "managerId": row_class_metadata.get("managerId"), + "managerName": row_class_metadata.get("managerName"), + "classMetadata": row_class_metadata, + "scores": [avg_quiz], + "attendanceValues": [attendance], + "engagementValues": [engagement], + "completionValues": [completion], + "riskLevel": risk_level, + "weakestTopic": weakest_topic, + "inferredState": { + "state": inference["state"], + "confidence": inference["confidence"], + "signals": inference["signals"], + "explanation": inference["explanation"], + "fallbackUsed": inference["fallbackUsed"], + }, + } + else: + existing_student["scores"].append(avg_quiz) + existing_student["attendanceValues"].append(attendance) + existing_student["engagementValues"].append(engagement) + existing_student["completionValues"].append(completion) + if existing_student.get("riskLevel") != "High" and risk_level == "High": + existing_student["riskLevel"] = "High" + elif existing_student.get("riskLevel") == "Low" and risk_level == "Medium": + existing_student["riskLevel"] = "Medium" + if existing_student.get("weakestTopic") in {"", "Foundational Skills"} and weakest_topic: + existing_student["weakestTopic"] = weakest_topic + if float(inference["confidence"]) < float(existing_student.get("inferredState", {}).get("confidence") or 1.0): + existing_student["inferredState"] = { + "state": inference["state"], + "confidence": inference["confidence"], + "signals": inference["signals"], + "explanation": inference["explanation"], + "fallbackUsed": inference["fallbackUsed"], + } + + existing_class = class_agg.get(resolved_class_section_id) + if not existing_class: + class_agg[resolved_class_section_id] = { + "id": resolved_class_section_id, + "name": resolved_class_name, + "classSectionId": resolved_class_section_id, + "grade": grade, + "gradeLevel": row_class_metadata.get("gradeLevel"), + "classification": row_class_metadata.get("classification"), + "strand": row_class_metadata.get("strand"), + "section": section, + "managerId": row_class_metadata.get("managerId"), + "managerName": row_class_metadata.get("managerName"), + "classMetadata": row_class_metadata, + "schedule": "Mon-Fri", + "students": set(), + "scoreValues": [], + "atRiskStudents": set(), + } + + class_agg[resolved_class_section_id]["students"].add(student_key) + class_agg[resolved_class_section_id]["scoreValues"].append(avg_quiz) + if risk_level == "High": + class_agg[resolved_class_section_id]["atRiskStudents"].add(student_key) + + students_payload: List[Dict[str, Any]] = [] + for student in student_agg.values(): + scores = student.pop("scores", []) + attendance_values = student.pop("attendanceValues", []) + engagement_values = student.pop("engagementValues", []) + completion_values = student.pop("completionValues", []) + students_payload.append( + { + **student, + "avgQuizScore": round(float(sum(scores) / max(len(scores), 1)), 1), + "attendance": round(float(sum(attendance_values) / max(len(attendance_values), 1)), 1), + "engagementScore": round(float(sum(engagement_values) / max(len(engagement_values), 1)), 1), + "assignmentCompletion": round(float(sum(completion_values) / max(len(completion_values), 1)), 1), + "stateConfidence": float((student.get("inferredState") or {}).get("confidence") or 0.0), + "stateSignals": list((student.get("inferredState") or {}).get("signals") or []), + } + ) + + classrooms_payload: List[Dict[str, Any]] = [] + for classroom in class_agg.values(): + score_values = classroom.get("scoreValues") or [] + classrooms_payload.append( + { + "id": classroom["id"], + "name": classroom["name"], + "classSectionId": classroom["classSectionId"], + "grade": classroom.get("grade"), + "gradeLevel": classroom.get("gradeLevel"), + "classification": classroom.get("classification"), + "strand": classroom.get("strand"), + "section": classroom.get("section"), + "managerId": classroom.get("managerId"), + "managerName": classroom.get("managerName"), + "classMetadata": classroom.get("classMetadata"), + "schedule": classroom["schedule"], + "studentCount": len(classroom.get("students") or set()), + "avgScore": round(float(sum(score_values) / max(len(score_values), 1)), 1) if score_values else 0.0, + "atRiskCount": len(classroom.get("atRiskStudents") or set()), + } + ) + + classrooms_payload.sort(key=lambda item: item.get("name") or "") + students_payload.sort(key=lambda item: item.get("name") or "") + + warnings: List[str] = [] + if len(docs) >= limit: + warnings.append("Imported class overview reached the maximum row scan limit; results may be partial.") + + inferred_count = sum(1 for item in students_payload if item.get("inferredState")) + inferred_coverage_pct = round((float(inferred_count) / float(max(len(students_payload), 1))) * 100.0, 1) + + return { + "success": True, + "classSectionId": normalized_class_section_id, + "classrooms": classrooms_payload, + "students": students_payload, + "inferredStateCoverage": { + "inferredRows": inferred_count, + "studentRows": len(students_payload), + "coveragePct": inferred_coverage_pct, + }, + "warnings": warnings, + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Imported class overview error: {e}") + raise HTTPException(status_code=500, detail=f"Imported class overview error: {str(e)}") + + +# ─── Topic Mastery Analytics ────────────────────────────────── + +# SHS topic data for fallback/mock generation +_SHS_TOPICS = { + "gen-math": { + "name": "General Mathematics", + "topics": [ + ("Functions and Relations", "Functions and Their Graphs"), + ("Evaluating Functions", "Functions and Their Graphs"), + ("Operations on Functions", "Functions and Their Graphs"), + ("Composite Functions", "Functions and Their Graphs"), + ("Inverse Functions", "Functions and Their Graphs"), + ("Rational Functions", "Functions and Their Graphs"), + ("Exponential Functions", "Functions and Their Graphs"), + ("Logarithmic Functions", "Functions and Their Graphs"), + ("Simple Interest", "Business Mathematics"), + ("Compound Interest", "Business Mathematics"), + ("Annuities", "Business Mathematics"), + ("Loans and Amortization", "Business Mathematics"), + ("Stocks and Bonds", "Business Mathematics"), + ("Propositions and Connectives", "Logic"), + ("Truth Tables", "Logic"), + ("Logical Equivalence", "Logic"), + ("Valid Arguments and Fallacies", "Logic"), + ], + }, + "stats-prob": { + "name": "Statistics and Probability", + "topics": [ + ("Random Variables", "Random Variables"), + ("Discrete Probability Distributions", "Random Variables"), + ("Mean and Variance of Discrete RV", "Random Variables"), + ("Normal Distribution", "Normal Distribution"), + ("Standard Normal Distribution and Z-scores", "Normal Distribution"), + ("Areas Under the Normal Curve", "Normal Distribution"), + ("Sampling Distributions", "Sampling and Estimation"), + ("Central Limit Theorem", "Sampling and Estimation"), + ("Point Estimation", "Sampling and Estimation"), + ("Confidence Intervals", "Sampling and Estimation"), + ("Hypothesis Testing Concepts", "Hypothesis Testing"), + ("T-test", "Hypothesis Testing"), + ("Z-test", "Hypothesis Testing"), + ("Correlation and Regression", "Correlation and Regression"), + ], + }, + "pre-calc": { + "name": "Pre-Calculus", + "topics": [ + ("Conic Sections - Parabola", "Analytic Geometry"), + ("Conic Sections - Ellipse", "Analytic Geometry"), + ("Conic Sections - Hyperbola", "Analytic Geometry"), + ("Conic Sections - Circle", "Analytic Geometry"), + ("Systems of Nonlinear Equations", "Analytic Geometry"), + ("Sequences and Series", "Series and Induction"), + ("Arithmetic Sequences", "Series and Induction"), + ("Geometric Sequences", "Series and Induction"), + ("Mathematical Induction", "Series and Induction"), + ("Binomial Theorem", "Series and Induction"), + ("Angles and Unit Circle", "Trigonometry"), + ("Trigonometric Functions", "Trigonometry"), + ("Trigonometric Identities", "Trigonometry"), + ("Sum and Difference Formulas", "Trigonometry"), + ("Inverse Trigonometric Functions", "Trigonometry"), + ("Polar Coordinates", "Trigonometry"), + ], + }, + "basic-calc": { + "name": "Basic Calculus", + "topics": [ + ("Limits of Functions", "Limits"), + ("Limit Theorems", "Limits"), + ("One-Sided Limits", "Limits"), + ("Infinite Limits and Limits at Infinity", "Limits"), + ("Continuity of Functions", "Limits"), + ("Definition of the Derivative", "Derivatives"), + ("Differentiation Rules", "Derivatives"), + ("Chain Rule", "Derivatives"), + ("Implicit Differentiation", "Derivatives"), + ("Higher-Order Derivatives", "Derivatives"), + ("Related Rates", "Derivatives"), + ("Extrema and the First Derivative Test", "Derivatives"), + ("Concavity and the Second Derivative Test", "Derivatives"), + ("Optimization Problems", "Derivatives"), + ("Antiderivatives and Indefinite Integrals", "Integration"), + ("Definite Integrals and the FTC", "Integration"), + ("Integration by Substitution", "Integration"), + ("Area Under a Curve", "Integration"), + ], + }, +} + + +@app.get("/api/analytics/topic-mastery") +async def topic_mastery_analytics( + request: Request, + teacherId: str = Query(..., description="Teacher UID"), + classId: Optional[str] = Query(None, description="Optional class ID filter"), + classSectionId: Optional[str] = Query(None, description="Optional class section ID filter"), +): + """ + Aggregate per-topic mastery statistics for a teacher's class. + Returns topic-level averages, attempt counts, and mastery status. + """ + try: + user = get_current_user(request) + if user.role != "admin" and teacherId != user.uid: + raise HTTPException(status_code=403, detail="Forbidden for this teacher scope") + if not (_firebase_ready and firebase_firestore): + raise HTTPException(status_code=503, detail="Firestore unavailable") + + effective_class_section_id = (classSectionId or classId or "").strip().lower() or None + + topic_lookup: Dict[str, Dict[str, str]] = {} + for subject_id, subject_meta in _SHS_TOPICS.items(): + for topic_name, unit in subject_meta.get("topics", []): + canonical_name = _canonicalize_topic_label(topic_name) + if not canonical_name: + continue + topic_lookup[_normalize_topic_key(canonical_name)] = { + "topicName": canonical_name, + "subjectId": subject_id, + "unit": unit, + } + + def resolve_topic_meta(raw_topic: str) -> Dict[str, str]: + canonical = _canonicalize_topic_label(raw_topic) + key = _normalize_topic_key(canonical) + if key in topic_lookup: + return topic_lookup[key] + return { + "topicName": canonical or "General Performance", + "subjectId": "gen-math", + "unit": "Imported Curriculum", + } + + def extract_topics_from_record(row: Dict[str, Any]) -> List[str]: + candidates: List[str] = [] + assessment_name = _canonicalize_topic_label(str(row.get("assessmentName") or "").strip()) + if assessment_name and assessment_name.lower() != "general-assessment": + candidates.append(assessment_name) + + unknown_fields = row.get("unknownFields") or {} + if isinstance(unknown_fields, dict): + for key, value in unknown_fields.items(): + if any(token in str(key).lower() for token in ("topic", "unit", "lesson", "skill", "competency")): + fragments = re.split(r"[,;|]+", str(value or "")) + for fragment in fragments: + label = _canonicalize_topic_label(fragment.strip()) + if label: + candidates.append(label) + + deduped: List[str] = [] + seen_keys: Set[str] = set() + for candidate in candidates: + key = _normalize_topic_key(candidate) + if not key or key in seen_keys: + continue + seen_keys.add(key) + deduped.append(candidate) + + if not deduped: + return ["General Performance"] + return deduped + + client = firebase_firestore.client() + records_query = client.collection("normalizedClassRecords").where("teacherId", "==", teacherId) + material_query = client.collection("courseMaterials").where("teacherId", "==", teacherId) + + docs = list(records_query.limit(3000).stream()) + student_ids: Set[str] = set() + topic_stats: Dict[str, Dict[str, Any]] = {} + fallback_topic_rows = 0 + + for doc in docs: + row = doc.to_dict() or {} + resolved_class_section_id, _, _, _ = _resolve_import_class_context( + class_section_id=str(row.get("classSectionId") or "").strip() or None, + class_name=str(row.get("className") or "").strip() or None, + ) + if effective_class_section_id and resolved_class_section_id != effective_class_section_id: + continue + + student_identity = ( + str(row.get("studentId") or "").strip() + or str(row.get("lrn") or "").strip() + or str(row.get("email") or "").strip().lower() + or re.sub(r"\s+", "_", str(row.get("name") or "").strip().lower()) + ) + if student_identity: + student_ids.add(student_identity) + + avg_quiz_score = float(row.get("avgQuizScore") or 0.0) + extracted_topics = extract_topics_from_record(row) + if extracted_topics == ["General Performance"]: + fallback_topic_rows += 1 + + for topic_label in extracted_topics: + metadata = resolve_topic_meta(topic_label) + topic_key = _normalize_topic_key(metadata["topicName"]) + if topic_key not in topic_stats: + topic_stats[topic_key] = { + "topicName": metadata["topicName"], + "subjectId": metadata["subjectId"], + "unit": metadata["unit"], + "scores": [], + "students": set(), + "studentsAbove85": 0, + } + + topic_stats[topic_key]["scores"].append(avg_quiz_score) + if student_identity: + if student_identity not in topic_stats[topic_key]["students"] and avg_quiz_score >= 85: + topic_stats[topic_key]["studentsAbove85"] += 1 + topic_stats[topic_key]["students"].add(student_identity) + + material_docs = list(material_query.limit(200).stream()) + for doc in material_docs: + data = doc.to_dict() or {} + resolved_material_section_id, _, _, _ = _resolve_import_class_context( + class_section_id=str(data.get("classSectionId") or "").strip() or None, + class_name=str(data.get("className") or "").strip() or None, + ) + if effective_class_section_id and resolved_material_section_id != effective_class_section_id: + continue + + for topic in data.get("topics") or []: + topic_title = str((topic or {}).get("title") or "").strip() + if not topic_title: + continue + metadata = resolve_topic_meta(topic_title) + topic_key = _normalize_topic_key(metadata["topicName"]) + if topic_key not in topic_stats: + topic_stats[topic_key] = { + "topicName": metadata["topicName"], + "subjectId": metadata["subjectId"], + "unit": metadata["unit"], + "scores": [], + "students": set(), + "studentsAbove85": 0, + } + + total_students = len(student_ids) + topics_payload: List[Dict[str, Any]] = [] + mastered_count = 0 + needs_attention_count = 0 + + for topic_data in topic_stats.values(): + scores = topic_data.get("scores") or [] + students_attempted = len(topic_data.get("students") or set()) + class_average = float(sum(scores) / len(scores)) if scores else 0.0 + mastery_percentage = ( + float(topic_data.get("studentsAbove85") or 0) / float(total_students) * 100.0 + if total_students > 0 + else 0.0 + ) + + if students_attempted == 0: + mastery_status = "no_data" + elif mastery_percentage >= 75.0: + mastery_status = "mastered" + elif class_average >= 65.0 or mastery_percentage >= 40.0: + mastery_status = "on_track" + else: + mastery_status = "needs_attention" + + if mastery_status == "mastered": + mastered_count += 1 + if mastery_status == "needs_attention": + needs_attention_count += 1 + + topics_payload.append( + { + "topicName": topic_data["topicName"], + "subjectId": topic_data["subjectId"], + "unit": topic_data["unit"], + "classAverage": round(class_average, 1), + "studentsAttempted": students_attempted, + "totalStudents": total_students, + "studentsAbove85": int(topic_data.get("studentsAbove85") or 0), + "masteryPercentage": round(mastery_percentage, 1), + "masteryStatus": mastery_status, + "isExcluded": False, + } + ) + + topics_payload.sort(key=lambda item: (item["masteryStatus"], item["topicName"])) + + warnings: List[str] = [] + if fallback_topic_rows > 0: + warnings.append( + "Some records did not contain explicit topic/assessment columns; fallback topic context was applied." + ) + + return { + "topics": topics_payload, + "summary": { + "totalTopicsTracked": len(topics_payload), + "masteredCount": mastered_count, + "needsAttentionCount": needs_attention_count, + "excludedCount": 0, + "fallbackTopicRows": fallback_topic_rows, + }, + "warnings": warnings, + } + except Exception as e: + logger.error(f"Topic mastery analytics error: {e}") + raise HTTPException(status_code=500, detail=f"Topic mastery error: {str(e)}") + + +# ─── Automation Engine Endpoints ────────────────────────────── + + +@app.post("/api/automation/diagnostic-completed", response_model=AutomationResult) +async def automation_diagnostic_completed(payload: DiagnosticCompletionPayload): + """ + Trigger automation pipeline after a student completes the diagnostic. + Classifies risk per subject, generates learning path, creates + remedial quizzes, and produces teacher intervention recommendations. + """ + try: + logger.info(f"Automation trigger: diagnostic_completed for {payload.studentId}") + result = await automation_engine.handle_diagnostic_completion(payload) + return result + except Exception as e: + logger.error(f"Automation diagnostic error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Automation error: {str(e)}") + + +@app.post("/api/automation/quiz-submitted", response_model=AutomationResult) +async def automation_quiz_submitted(payload: QuizSubmissionPayload): + """ + Trigger automation after any quiz / assessment submission. + Recalculates risk for the subject and determines status changes. + """ + try: + logger.info(f"Automation trigger: quiz_submitted by {payload.studentId}") + result = await automation_engine.handle_quiz_submission(payload) + return result + except Exception as e: + logger.error(f"Automation quiz error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Automation error: {str(e)}") + + +@app.post("/api/automation/student-enrolled", response_model=AutomationResult) +async def automation_student_enrolled(payload: StudentEnrollmentPayload): + """ + Trigger automation when a new student account is created. + Initialises progress tracking, gamification, and flags diagnostic as pending. + """ + try: + logger.info(f"Automation trigger: student_enrolled for {payload.studentId}") + result = await automation_engine.handle_student_enrollment(payload) + return result + except Exception as e: + logger.error(f"Automation enrollment error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Automation error: {str(e)}") + + +@app.post("/api/automation/data-imported", response_model=AutomationResult) +async def automation_data_imported(payload: DataImportPayload): + """ + Trigger automation after a teacher uploads external data. + Recalculates risk for all affected students and flags status changes. + """ + try: + logger.info(f"Automation trigger: data_imported by teacher {payload.teacherId}") + result = await automation_engine.handle_data_import(payload) + return result + except Exception as e: + logger.error(f"Automation import error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Automation error: {str(e)}") + + +@app.post("/api/automation/content-updated", response_model=AutomationResult) +async def automation_content_updated(payload: ContentUpdatePayload): + """ + Trigger automation after admin CRUD on curriculum content. + Logs the change and notifies affected teachers. + """ + try: + logger.info(f"Automation trigger: content_updated by admin {payload.adminId}") + result = await automation_engine.handle_content_update(payload) + return result + except Exception as e: + logger.error(f"Automation content error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Automation error: {str(e)}") + + +# ─── Diagnostic Test Endpoints ───────────────────────────────── + +async def _generate_diagnostic_questions( + strand: str, + grade_level: str, + num_questions: int, +) -> List[DiagnosticQuestion]: + """Generate diagnostic test questions using LLM based on DepEd curriculum with RAG.""" + + topics = DEPD_ED_COMPETENCY_DOMAINS.get(strand, {}).get(grade_level, []) + if not topics: + topics = DEPD_ED_COMPETENCY_DOMAINS.get("STEM", {}).get("Grade 11", []) + + topic_list = "\n".join([f"- {t}" for t in topics[:10]]) + + curriculum_chunks = retrieve_curriculum_context( + query=f"{topics[0] if topics else strand} examples problems {grade_level}", + subject="General Mathematics", + top_k=3, + ) + + curriculum_context = "" + for chunk in curriculum_chunks: + source = chunk.get("source_file", "unknown") + content = chunk.get("content", "")[:500] + curriculum_context += f"[Source: {source}]\n{content}\n\n---\n\n" + + rag_instruction = "" + if curriculum_context: + rag_instruction = f"""CURRICULUM REFERENCE: +{curriculum_context} + +Use these examples as reference. Do not copy directly.""" + + prompt = f"""You are MathPulse AI's Diagnostic Test Generator. Generate {num_questions} multiple-choice questions for a Filipino Senior High School student (Strand: {strand}, Grade: {grade_level}). + +Based on these DepEd SHS curriculum competencies: +{topic_list} + +{rag_instruction} + +Generate questions in this strict JSON format (no other text): +[ + {{ + "question_id": "DX-", + "competency_code": "TOPIC-SUBTOPIC-01", + "domain": "Domain Name", + "topic": "Specific Topic", + "difficulty": "easy|medium|hard", + "bloom_level": "remembering|understanding|applying|analyzing", + "question_text": "Question text in Filipino context", + "options": {{"A": "...", "B": "...", "C": "...", "D": "..."}}, + "correct_answer": "A|B|C|D", + "solution_hint": "Brief solution hint (1-2 sentences)", + "curriculum_reference": "DepEd SHS [Strand] Q[X] - [Topic]" + }} +] + +Distribution: 40% easy, 40% medium, 20% hard. +Use Filipino real-life context (peso amounts, SSS/PhilHealth/BIR, local scenarios). +Distractors must be plausible but clearly wrong. +Return ONLY the JSON array, no other text.""" + + try: + messages = [ + {"role": "system", "content": "You are a math question generator. Return ONLY valid JSON."}, + {"role": "user", "content": prompt}, + ] + response = await call_hf_chat_async(messages, max_tokens=4096, temperature=0.3, task_type="quiz") + + import re + json_match = re.search(r'\[.*\]', response, re.DOTALL) + if json_match: + questions_data = json.loads(json_match.group()) + else: + questions_data = json.loads(response) + + questions = [] + for q in questions_data[:num_questions]: + questions.append(DiagnosticQuestion(**q)) + + return questions + except Exception as e: + logger.error(f"Diagnostic question generation error: {e}") + raise + + +async def _analyze_diagnostic_risk( + responses: List[Dict[str, Any]], + total_items: int, + total_score: int, +) -> Dict[str, Any]: + """Analyze student performance and generate risk profile.""" + domain_scores: Dict[str, Dict[str, Any]] = {} + domain_responses: Dict[str, List[Dict[str, Any]]] = {} + + for resp in responses: + domain = resp.get("domain", "Unknown") + if domain not in domain_responses: + domain_responses[domain] = [] + domain_responses[domain].append(resp) + + for domain, resp_list in domain_responses.items(): + correct = sum(1 for r in resp_list if r.get("is_correct", False)) + total = len(resp_list) + pct = (correct / total * 100) if total > 0 else 0 + + mastery = "mastered" if pct >= 80 else "developing" if pct >= 60 else "beginning" + domain_scores[domain] = { + "correct": correct, + "total": total, + "percentage": round(pct, 1), + "mastery_level": mastery, + } + + weak_domains = [ + d for d, data in domain_scores.items() + if data["percentage"] < 60 + ] + + critical_gaps = [] + competency_attempts: Dict[str, List[bool]] = {} + for resp in responses: + comp_code = resp.get("competency_code", "") + if comp_code not in competency_attempts: + competency_attempts[comp_code] = [] + competency_attempts[comp_code].append(resp.get("is_correct", False)) + + for comp_code, results in competency_attempts.items(): + correct_count = sum(1 for r in results if r) + if len(results) >= 2 and correct_count == 0: + critical_gaps.append(comp_code) + + overall_pct = (total_score / total_items * 100) if total_items > 0 else 0 + + if overall_pct >= 75 and len(critical_gaps) == 0: + overall_risk = "low" + elif overall_pct >= 55 or len(critical_gaps) <= 2: + overall_risk = "moderate" + elif overall_pct >= 40 or len(critical_gaps) <= 4: + overall_risk = "high" + else: + overall_risk = "critical" + + intervention_messages = { + "low": "Great job! You have a solid foundation. Keep practicing to maintain your skills!", + "moderate": "You're making good progress. Focus on the topics where you need more practice.", + "high": "Don't worry! With focused practice on your weak areas, you'll improve quickly.", + "critical": "Let's work on this together. Start with the basics and build up your confidence.", + } + + suggested_path = weak_domains[:3] if weak_domains else list(domain_scores.keys())[:3] + + return { + "overall_risk": overall_risk, + "overall_score_percent": round(overall_pct, 1), + "domain_scores": domain_scores, + "weak_domains": weak_domains, + "critical_gaps": critical_gaps, + "recommended_intervention": intervention_messages[overall_risk], + "suggested_learning_path": suggested_path, + } + + +def _save_diagnostic_to_firestore(result: DiagnosticResult) -> bool: + """Save diagnostic result to Firestore.""" + if not HAS_FIREBASE_ADMIN or not firebase_firestore: + logger.warning("Firebase not available for diagnostic save") + return False + + try: + db = firebase_firestore.client() + doc_ref = db.collection("diagnosticResults").document(result.user_id).collection("attempts").document(result.test_id) + doc_ref.set({ + "testId": result.test_id, + "takenAt": result.taken_at, + "strand": result.strand, + "gradeLevel": result.grade_level, + "totalItems": result.total_items, + "totalScore": result.total_score, + "percentageScore": result.percentage_score, + "responses": result.responses, + "domainScores": result.domain_scores, + "riskProfile": result.risk_profile, + }) + + latest_ref = db.collection("users").document(result.user_id) + latest_ref.set({"latestDiagnosticTestId": result.test_id}, merge=True) + + return True + except Exception as e: + logger.error(f"Firestore diagnostic save error: {e}") + return False + + +@app.post("/api/diagnostic/generate", response_model=DiagnosticGenerateResponse) +async def generate_diagnostic_test(request: DiagnosticGenerateRequest): + """ + Generate a personalized diagnostic assessment for a student. + Questions are based on DepEd Strengthened SHS Curriculum. + """ + try: + test_id = f"DX-{uuid.uuid4().hex[:12]}" + + questions = await _generate_diagnostic_questions( + request.strand, + request.gradeLevel, + request.numQuestions, + ) + + stripped_questions = [] + for q in questions: + stripped_questions.append(DiagnosticQuestion( + question_id=q.question_id, + competency_code=q.competency_code, + domain=q.domain, + topic=q.topic, + difficulty=q.difficulty, + bloom_level=q.bloom_level, + question_text=q.question_text, + options=q.options, + correct_answer=q.correct_answer, + solution_hint="", + curriculum_reference=q.curriculum_reference, + )) + + metadata = { + "strand": request.strand, + "grade_level": request.gradeLevel, + "num_questions": len(questions), + "generated_at": datetime.now(timezone.utc).isoformat(), + } + + return DiagnosticGenerateResponse( + questions=stripped_questions, + test_id=test_id, + metadata=metadata, + ) + except Exception as e: + logger.error(f"Diagnostic generation error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Diagnostic generation error: {str(e)}") + + +@app.post("/api/diagnostic/submit", response_model=DiagnosticSubmitResponse) +async def submit_diagnostic_test(request: DiagnosticSubmitRequest): + """ + Submit diagnostic test responses, score them, and generate risk profile. + Results are saved to Firestore for use by other subsystems. + """ + try: + total_items = len(request.responses) + total_score = 0 + scored_responses = [] + + for resp in request.responses: + is_correct = resp.get("student_answer", "") == resp.get("correct_answer", "") + if is_correct: + total_score += 1 + scored_responses.append({ + "question_id": resp.get("question_id"), + "competency_code": resp.get("competency_code"), + "domain": resp.get("domain"), + "topic": resp.get("topic"), + "difficulty": resp.get("difficulty"), + "bloom_level": resp.get("bloom_level"), + "student_answer": resp.get("student_answer"), + "correct_answer": resp.get("correct_answer"), + "is_correct": is_correct, + "time_spent_seconds": resp.get("time_spent_seconds", 0), + }) + + risk_profile = await _analyze_diagnostic_risk( + scored_responses, + total_items, + total_score, + ) + + domain_scores = risk_profile.get("domain_scores", {}) + + result = DiagnosticResult( + user_id=request.user_id, + test_id=request.test_id, + taken_at=datetime.now(timezone.utc), + strand=request.strand, + grade_level=request.grade_level, + total_items=total_items, + total_score=total_score, + percentage_score=round(total_score / total_items * 100, 1), + responses=scored_responses, + domain_scores=domain_scores, + risk_profile=risk_profile, + ) + + _save_diagnostic_to_firestore(result) + + return DiagnosticSubmitResponse( + success=True, + result=result, + risk_profile=risk_profile, + domain_scores=domain_scores, + redirect_to="/dashboard", + ) + except Exception as e: + logger.error(f"Diagnostic submit error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Diagnostic submit error: {str(e)}") + + +@app.get("/api/diagnostic/results/{user_id}", response_model=DiagnosticResultsResponse) +async def get_diagnostic_results(user_id: str): + """ + Fetch diagnostic test results for a student. + Returns all attempts with risk profiles. + """ + if not HAS_FIREBASE_ADMIN or not firebase_firestore: + return DiagnosticResultsResponse(success=False, results=[]) + + try: + db = firebase_firestore.client() + docs = db.collection("diagnosticResults").document(user_id).collection("attempts").stream() + + results = [] + for doc in docs: + data = doc.to_dict() + if data: + results.append(DiagnosticResult(**data)) + + results.sort(key=lambda x: x.taken_at, reverse=True) + + return DiagnosticResultsResponse(success=True, results=results) + except Exception as e: + logger.error(f"Diagnostic results fetch error: {e}") + return DiagnosticResultsResponse(success=False, results=[]) + + +# ─── DepEd Topic Registry for Lessons/Quizzes ───────────────────────────── + +DEPD_TOPIC_REGISTRY: Dict[str, Dict[str, str]] = { + "NA-WAGE-01": {"subject": "General Mathematics", "title": "Wages, Salaries, Overtime, Commissions, VAT", "quarter": "Q1"}, + "NA-SEQ-01": {"subject": "General Mathematics", "title": "Arithmetic Sequences and Series", "quarter": "Q1"}, + "NA-SEQ-02": {"subject": "General Mathematics", "title": "Geometric Sequences and Series", "quarter": "Q1"}, + "NA-SEQ-03": {"subject": "General Mathematics", "title": "Sigma Notation, Financial Applications", "quarter": "Q1"}, + "NA-FUNC-01": {"subject": "General Mathematics", "title": "Functions, Relations, Vertical Line Test", "quarter": "Q2"}, + "NA-FUNC-02": {"subject": "General Mathematics", "title": "Evaluating Functions, Operations, Composition", "quarter": "Q2"}, + "NA-FUNC-03": {"subject": "General Mathematics", "title": "One-to-One Functions, Inverse Functions", "quarter": "Q2"}, + "NA-FUNC-04": {"subject": "General Mathematics", "title": "Piecewise Functions", "quarter": "Q2"}, + "NA-EXP-01": {"subject": "General Mathematics", "title": "Exponential Functions, Equations, Inequalities", "quarter": "Q2"}, + "NA-LOG-01": {"subject": "General Mathematics", "title": "Logarithmic Functions", "quarter": "Q2"}, + "MG-TRIG-01": {"subject": "General Mathematics", "title": "Trigonometric Ratios, Right Triangles", "quarter": "Q3"}, + "MG-TRIG-02": {"subject": "General Mathematics", "title": "Oblique Triangles, Heron's Formula", "quarter": "Q3"}, + "MG-MEAS-01": {"subject": "General Mathematics", "title": "Unit Conversion, Surface Area, Volume", "quarter": "Q2"}, + "DP-STAT-01": {"subject": "Statistics", "title": "Types of Data, Levels of Measurement", "quarter": "Q2"}, + "DP-STAT-02": {"subject": "Statistics", "title": "Measures of Central Tendency and Variability", "quarter": "Q2"}, + "DP-RV-01": {"subject": "Statistics", "title": "Random Variables (Discrete & Continuous)", "quarter": "Q3"}, + "DP-RV-02": {"subject": "Statistics", "title": "Probability Distributions, Mean, Variance, SD", "quarter": "Q3"}, + "DP-NORM-01": {"subject": "Statistics", "title": "Normal Distribution, Properties", "quarter": "Q3"}, + "DP-NORM-02": {"subject": "Statistics", "title": "Z-Scores, Standard Normal Table", "quarter": "Q3"}, + "DP-SAMP-01": {"subject": "Statistics", "title": "Sampling, Central Limit Theorem", "quarter": "Q3"}, + "DP-SAMP-02": {"subject": "Statistics", "title": "Sampling Distribution of Sample Means", "quarter": "Q3"}, + "NA-FIN-01": {"subject": "General Mathematics", "title": "Compound Interest, Maturity Value", "quarter": "Q4"}, + "NA-FIN-02": {"subject": "General Mathematics", "title": "Simple and General Annuities", "quarter": "Q4"}, + "NA-FIN-03": {"subject": "General Mathematics", "title": "Deferred Annuity, Fair Market Value", "quarter": "Q4"}, + "NA-FIN-04": {"subject": "General Mathematics", "title": "Business and Consumer Loans, Amortization", "quarter": "Q4"}, + "DP-HYP-01": {"subject": "Statistics", "title": "Hypothesis Testing: Null/Alternative, Types of Error", "quarter": "Q4"}, + "DP-HYP-02": {"subject": "Statistics", "title": "Z-Test and T-Test", "quarter": "Q4"}, + "DP-HYP-03": {"subject": "Statistics", "title": "Pearson r, Scatter Plots, Line of Best Fit", "quarter": "Q4"}, + "NA-LOGIC-01": {"subject": "General Mathematics", "title": "Logical Propositions, Connectives, Truth Tables", "quarter": "Q4"}, + "NA-LOGIC-02": {"subject": "General Mathematics", "title": "Conditional Propositions, Tautologies", "quarter": "Q4"}, + "BM-FDP-01": {"subject": "Business Mathematics", "title": "Fractions, Decimals, Percent Conversions", "quarter": "Q1"}, + "BM-FDP-02": {"subject": "Business Mathematics", "title": "Proportion: Direct, Inverse, Partitive", "quarter": "Q1"}, + "BM-BUS-01": {"subject": "Business Mathematics", "title": "Markup, Margin, Trade Discounts, VAT", "quarter": "Q1"}, + "BM-BUS-02": {"subject": "Business Mathematics", "title": "Profit, Loss, Break-even Point", "quarter": "Q1"}, + "BM-COMM-01": {"subject": "Business Mathematics", "title": "Straight Commission, Salary Plus Commission", "quarter": "Q2"}, + "BM-COMM-02": {"subject": "Business Mathematics", "title": "Commission on Cash and Installment Basis", "quarter": "Q2"}, + "BM-COMM-03": {"subject": "Business Mathematics", "title": "Down Payment, Gross Balance", "quarter": "Q2"}, + "BM-INT-01": {"subject": "Business Mathematics", "title": "Simple Interest, Compound Interest", "quarter": "Q2"}, + "BM-INT-02": {"subject": "Business Mathematics", "title": "Solving Problems with Interest and Commission", "quarter": "Q2"}, + "BM-SW-01": {"subject": "Business Mathematics", "title": "Salary vs. Wage, Income", "quarter": "Q2"}, + "BM-SW-02": {"subject": "Business Mathematics", "title": "Employee Benefits: Taxable vs. Nontaxable", "quarter": "Q2"}, + "BM-SW-03": {"subject": "Business Mathematics", "title": "Mandatory Deductions: SSS, PhilHealth, Pag-IBIG", "quarter": "Q2"}, + "BM-SW-04": {"subject": "Business Mathematics", "title": "Overtime Pay Computation (Labor Code)", "quarter": "Q2"}, + "BM-SW-05": {"subject": "Business Mathematics", "title": "E-Spreadsheet for Payroll", "quarter": "Q2"}, + "BM-MORT-01": {"subject": "Business Mathematics", "title": "Mortgage, Amortization, Monthly Payment", "quarter": "Q2"}, + "BM-DATA-01": {"subject": "Business Mathematics", "title": "Data Presentation: Tables, Bar, Line, Pie Charts", "quarter": "Q2"}, + "BM-DATA-02": {"subject": "Business Mathematics", "title": "Analyzing Business Data with Excel", "quarter": "Q2"}, + "SP-RV-01": {"subject": "Statistics & Probability", "title": "Random Variables, Discrete vs. Continuous", "quarter": "Q1"}, + "SP-RV-02": {"subject": "Statistics & Probability", "title": "Probability Distribution, Mean, Variance, SD", "quarter": "Q1"}, + "SP-NORM-01": {"subject": "Statistics & Probability", "title": "Normal Curve Properties", "quarter": "Q1"}, + "SP-NORM-02": {"subject": "Statistics & Probability", "title": "Z-Scores, Standard Normal Table", "quarter": "Q1"}, + "SP-NORM-03": {"subject": "Statistics & Probability", "title": "Applying Normal Distribution", "quarter": "Q1"}, + "SP-SAMP-01": {"subject": "Statistics & Probability", "title": "Types of Random Sampling", "quarter": "Q2"}, + "SP-SAMP-02": {"subject": "Statistics & Probability", "title": "Sampling Distribution of Sample Means", "quarter": "Q2"}, + "SP-SAMP-03": {"subject": "Statistics & Probability", "title": "Central Limit Theorem", "quarter": "Q2"}, + "SP-HYP-01": {"subject": "Statistics & Probability", "title": "Hypothesis Testing: H0 and Ha", "quarter": "Q2"}, + "SP-HYP-02": {"subject": "Statistics & Probability", "title": "Level of Significance, Type I and II Errors", "quarter": "Q2"}, + "SP-HYP-03": {"subject": "Statistics & Probability", "title": "Z-Test for Known Variance", "quarter": "Q2"}, + "SP-HYP-04": {"subject": "Statistics & Probability", "title": "T-Test for Unknown Variance", "quarter": "Q2"}, + "SP-HYP-05": {"subject": "Statistics & Probability", "title": "Z-Test and T-Test for Proportion", "quarter": "Q2"}, + "SP-CORR-01": {"subject": "Statistics & Probability", "title": "Pearson r, Scatter Plots", "quarter": "Q2"}, + "SP-CORR-02": {"subject": "Statistics & Probability", "title": "Line of Best Fit, Regression", "quarter": "Q2"}, +} + + +# ─── Diagnostic-Integrated Lesson Generation ───────────────────── + +class DiagnosticLessonRequest(BaseModel): + student_id: str + topic_id: str + mastery_level: str = Field(default="beginning") + strand: str = Field(default="STEM") + grade_level: str = Field(default="Grade 11") + + +class DiagnosticLessonSection(BaseModel): + type: str + title: Optional[str] = None + content: str + formula: Optional[str] = None + visual_hint: Optional[str] = None + problem: Optional[str] = None + solution_steps: Optional[List[Dict[str, Any]]] = None + final_answer: Optional[str] = None + prompt: Optional[str] = None + hint: Optional[str] = None + answer: Optional[str] = None + + +class DiagnosticLessonResponse(BaseModel): + lesson_id: str + topic_id: str + subject: str + title: str + grade_level: str + strand: str + estimated_minutes: int + mastery_target: str + learning_objectives: List[str] + sections: List[DiagnosticLessonSection] + summary: str + real_life_connection: str + next_topic_id: Optional[str] + prerequisite_topic_ids: List[str] + + +@app.post("/api/lesson/diagnostic", response_model=DiagnosticLessonResponse) +async def generate_diagnostic_lesson(request: DiagnosticLessonRequest): + """ + Generate personalized lesson based on diagnostic test results. + Adjusts content difficulty based on student's mastery level. + Uses RAG to inject DepEd curriculum content. + """ + try: + topic_info = DEPD_TOPIC_REGISTRY.get(request.topic_id, {}) + subject = topic_info.get("subject", "General Mathematics") + title = topic_info.get("title", request.topic_id) + + curriculum_chunks = retrieve_curriculum_context( + query=f"{title} {request.topic_id} examples problems exercises", + subject=subject, + top_k=4, + ) + + curriculum_context = "" + for chunk in curriculum_chunks: + source = chunk.get("source_file", "unknown") + content = chunk.get("content", "")[:800] + curriculum_context += f"[Source: {source}]\n{content}\n\n---\n\n" + + mastery_adjustments = { + "beginning": "Use extra-simple language, 3 worked examples, more hints.", + "developing": "Standard pacing, 2 worked examples.", + "mastered": "Fast-track with 1 worked example and a challenge problem.", + } + + rag_instruction = "" + if curriculum_context: + rag_instruction = f"""REFERENCE CURRICULUM CONTENT (from DepEd modules): +{curriculum_context} + +IMPORTANT: Base your lesson STRICTLY on the curriculum content above. Do not invent formulas or examples.""" + + prompt = f"""Generate a complete lesson for topic {request.topic_id}: {title}. + +Student Context: +- Strand: {request.strand} +- Grade: {request.grade_level} +- Mastery Level: {request.mastery_level} ({mastery_adjustments.get(request.mastery_level, '')}) + +{rag_instruction} + +Use Filipino context (₱, local scenarios). +Follow SDO Navotas step-by-step: "Given → Formula → Substitute → Compute → Conclude" + +Return ONLY this exact JSON (no other text): +{{ + "lesson_id": "LSN-{uuid.uuid4().hex[:8]}", + "topic_id": "{request.topic_id}", + "subject": "{subject}", + "title": "{title}", + "grade_level": "{request.grade_level}", + "strand": "{request.strand}", + "estimated_minutes": 20, + "mastery_target": "mastered", + "learning_objectives": ["By the end, you will be able to..."], + "sections": [ + {{"type": "hook", "content": "Relatable Filipino intro (2-3 sentences)"}}, + {{"type": "concept", "title": "...", "content": "Core explanation", "formula": "LaTeX or null", "visual_hint": "description or null"}}, + {{"type": "worked_example", "title": "Example 1", "problem": "...", "solution_steps": [{{"step": 1, "explanation": "...", "math": "LaTeX or null"}}], "final_answer": "..."}}, + {{"type": "try_it", "prompt": "Your turn!", "problem": "...", "hint": "Think about...", "answer": "...", "solution_steps": []}} + ], + "summary": "3-sentence recap", + "real_life_connection": "1 sentence to Filipino career", + "next_topic_id": "next topic ID or null", + "prerequisite_topic_ids": ["prereq topic IDs"] +}}""" + + messages = [ + {"role": "system", "content": "You are a DepEd curriculum lesson designer. Return ONLY valid JSON."}, + {"role": "user", "content": prompt}, + ] + response = await call_hf_chat_async(messages, max_tokens=4096, temperature=0.3, task_type="lesson") + + import re + json_match = re.search(r'\{.*\}', response, re.DOTALL) + if json_match: + lesson_data = json.loads(json_match.group()) + else: + lesson_data = json.loads(response) + + return DiagnosticLessonResponse(**lesson_data) + except Exception as e: + logger.error(f"Diagnostic lesson generation error: {e}") + raise HTTPException(status_code=500, detail=f"Lesson generation error: {str(e)}") + + +# ─── Consolidated Lesson Generator (reads from diagnostic) ───────────── + +class LessonsGenerateRequest(BaseModel): + student_id: str + topic_id: str + strand: str = Field(default="STEM") + grade_level: str = Field(default="Grade 11") + + +@app.post("/api/lessons/generate", response_model=DiagnosticLessonResponse) +async def generate_lesson_from_diagnostic(request: LessonsGenerateRequest): + """ + Generate a personalized lesson by reading mastery_level from the + student's diagnostic results in Firestore. Falls back to 'beginning' + if no diagnostic data exists. + """ + mastery_level = "beginning" + + if HAS_FIREBASE_ADMIN and firebase_firestore: + try: + db = firebase_firestore.client() + user_doc = db.collection("users").document(request.student_id).get() + if user_doc.exists: + user_data = user_doc.to_dict() or {} + diag_id = user_data.get("latestDiagnosticTestId", "") + if diag_id: + diag_doc = ( + db.collection("diagnosticResults") + .document(request.student_id) + .collection("attempts") + .document(diag_id) + .get() + ) + if diag_doc.exists: + diag_data = diag_doc.to_dict() or {} + domain_scores = diag_data.get("domainScores", {}) + for domain, score_data in domain_scores.items(): + ml = score_data.get("mastery_level", "") + if ml: + mastery_level = ml + break + except Exception as diag_err: + logger.debug(f"Could not read diagnostic mastery for lesson: {diag_err}") + + return await generate_diagnostic_lesson( + DiagnosticLessonRequest( + student_id=request.student_id, + topic_id=request.topic_id, + mastery_level=mastery_level, + strand=request.strand, + grade_level=request.grade_level, + ) + ) + + +# ─── Progress Evaluation Endpoint ───────────────────────────────── + +class ProgressEvaluateRequest(BaseModel): + student_id: str + quiz_id: str + topic_id: str + mastery_level_before: str + items: List[Dict[str, Any]] + previous_attempts: int = Field(default=0) + current_streak_days: int = Field(default=0) + + +class ProgressEvaluateResponse(BaseModel): + new_mastery_level: str + mastery_changed: bool + score_percent: float + xp_earned: int + xp_breakdown: Dict[str, int] + badges_unlocked: List[str] + performance_feedback: str + error_analysis: List[Dict[str, Any]] + next_action: str + next_topic_id: Optional[str] + motivational_message: str + teacher_flag: Optional[Dict[str, Any]] + + +@app.post("/api/progress/evaluate", response_model=ProgressEvaluateResponse) +async def evaluate_progress(request: ProgressEvaluateRequest): + """ + Evaluate quiz performance, update mastery, award XP. + Called after every quiz submission. + """ + try: + total_items = len(request.items) + correct_count = sum(1 for item in request.items if item.get("is_correct", False)) + score_percent = (correct_count / total_items * 100) if total_items > 0 else 0 + + mastery_changed = False + new_level = request.mastery_level_before + prev = request.mastery_level_before + + applying_level_correct = sum( + 1 for item in request.items + if item.get("is_correct", False) and item.get("bloom_level", "") in ("applying", "analyzing", "evaluating") + ) + analyzing_level_correct = sum( + 1 for item in request.items + if item.get("is_correct", False) and item.get("bloom_level", "") in ("analyzing", "evaluating", "creating") + ) + + if prev == "beginning" and score_percent >= 60 and applying_level_correct >= 2: + new_level = "developing" + mastery_changed = True + elif prev == "developing" and score_percent >= 80 and analyzing_level_correct >= 1: + new_level = "mastered" + mastery_changed = True + + xp_base = 0 + xp_streak = 0 + xp_mastery = 0 + xp_other = 0 + + for item in request.items: + diff = item.get("difficulty", "easy") + if item.get("is_correct", False): + if diff == "easy": + xp_base += 5 + elif diff == "medium": + xp_base += 10 + elif diff == "hard": + xp_base += 20 + + xp_streak = min(30, 5 * request.current_streak_days) + + if mastery_changed: + xp_mastery = 50 + + if score_percent == 100 and request.previous_attempts == 0: + xp_other += 30 + + if request.previous_attempts >= 1 and score_percent > 60: + xp_other += 15 + + xp_total = xp_base + xp_streak + xp_mastery + xp_other + + error_analysis = [] + for item in request.items: + if not item.get("is_correct", False): + error_analysis.append({ + "item_id": item.get("item_id", ""), + "student_answer": item.get("student_answer", ""), + "correct_answer": item.get("correct_answer", ""), + "explanation": "Check your steps for this type of problem.", + }) + + next_action = "continue_learning_path" + if score_percent < 40 and request.previous_attempts >= 3: + next_action = "teacher_flag" + elif score_percent < 60: + next_action = "retry_quiz" + + next_topics = list(DEPD_TOPIC_REGISTRY.keys()) + current_idx = next_topics.index(request.topic_id) if request.topic_id in next_topics else 0 + next_topic_id = next_topics[current_idx + 1] if current_idx + 1 < len(next_topics) else None + + messages = { + "low": "Keep practicing! You're building momentum.", + "moderate": "Good progress! Focus on your weak areas.", + "high": "You're improving! Stay consistent.", + "critical": "Don't give up! One step at a time.", + } + motivational = messages.get(new_level, messages["low"]) + + if mastery_changed: + if new_level == "developing": + motivational = "Kaya mo yan! You're moving up!" + elif new_level == "mastered": + motivational = "Congratulations! Topic mastered!" + + teacher_flag = None + if score_percent < 40 and request.previous_attempts >= 3: + teacher_flag = {"reason": f"Score {score_percent}% after 3+ attempts", "severity": "high"} + + if HAS_FIREBASE_ADMIN and firebase_firestore: + try: + db = firebase_firestore.client() + topic_progress_ref = db.collection("studentProgress").document(request.student_id).collection("topics").document(request.topic_id) + topic_progress_ref.set({ + "mastery_level": new_level, + "quiz_attempts": firebase_firestore.Increment(1), + "best_score": max(score_percent, 0), + "xp_earned": firebase_firestore.Increment(xp_total), + "last_activity": firebase_firestore.SERVER_TIMESTAMP, + "error_patterns": [e.get("explanation", "") for e in error_analysis], + "teacher_flagged": teacher_flag is not None, + }, merge=True) + + stats_ref = db.collection("studentProgress").document(request.student_id).collection("stats").document("summary") + stats_ref.set({ + "total_xp": firebase_firestore.Increment(xp_total), + "current_streak_days": request.current_streak_days, + "topics_mastered": firebase_firestore.Increment(1) if mastery_changed else firebase_firestore.Increment(0), + }, merge=True) + except Exception as fs_err: + logger.warning(f"Firestore progress save failed: {fs_err}") + + return ProgressEvaluateResponse( + new_mastery_level=new_level, + mastery_changed=mastery_changed, + score_percent=round(score_percent, 1), + xp_earned=xp_total, + xp_breakdown={"base": xp_base, "mastery_bonus": xp_mastery, "streak_bonus": xp_streak, "other": xp_other}, + badges_unlocked=[], + performance_feedback=f"You got {correct_count}/{total_items} correct.", + error_analysis=error_analysis, + next_action=next_action, + next_topic_id=next_topic_id, + motivational_message=motivational, + teacher_flag=teacher_flag, + ) + except Exception as e: + logger.error(f"Progress evaluation error: {e}") + raise HTTPException(status_code=500, detail=f"Progress evaluation error: {str(e)}") + + +# ─── Adaptive Quiz Endpoint ───────────────────────────────────── + +class AdaptiveQuizRequest(BaseModel): + student_id: str + topic_id: str + recent_lesson_id: Optional[str] = None + strand: str = Field(default="STEM") + + +class AdaptiveQuizItem(BaseModel): + item_id: str + type: str + bloom_level: str + difficulty: str + question: str + options: Optional[Dict[str, str]] = None + correct_answer: str + acceptable_range: Optional[List[float]] = None + solution_hint: str + competency_code: str + curriculum_reference: str + + +class DiagnosticQuizResponse(BaseModel): + quiz_id: str + topic_id: str + mastery_target_after: str + items: List[AdaptiveQuizItem] + prev_score: Optional[float] + difficulty_distribution: Dict[str, int] + + +async def _resolve_mastery_and_prev_score( + student_id: str, + topic_id: str, +) -> tuple[str, Optional[float]]: + """Read mastery_level and prev_score from Firestore diagnostic and studentProgress.""" + mastery = "beginning" + prev_score: Optional[float] = None + + if not HAS_FIREBASE_ADMIN or not firebase_firestore: + return mastery, prev_score + + try: + db = firebase_firestore.client() + + topic_progress_doc = ( + db.collection("studentProgress") + .document(student_id) + .collection("topics") + .document(topic_id) + .get() + ) + if topic_progress_doc.exists: + tp_data = topic_progress_doc.to_dict() or {} + tp_mastery = str(tp_data.get("mastery_level", "")).strip() + if tp_mastery in ("beginning", "developing", "mastered"): + mastery = tp_mastery + prev_score_raw = tp_data.get("best_score") + if isinstance(prev_score_raw, (int, float)): + prev_score = float(prev_score_raw) + + user_doc = db.collection("users").document(student_id).get() + if user_doc.exists: + user_data = user_doc.to_dict() or {} + diag_id = user_data.get("latestDiagnosticTestId", "") + if diag_id: + diag_doc = ( + db.collection("diagnosticResults") + .document(student_id) + .collection("attempts") + .document(diag_id) + .get() + ) + if diag_doc.exists: + diag_data = diag_doc.to_dict() or {} + domain_scores = diag_data.get("domainScores", {}) + if not topic_progress_doc.exists: + for domain, score_data in domain_scores.items(): + ml = score_data.get("mastery_level", "") + if ml and ml in ("beginning", "developing", "mastered"): + mastery = ml + break + except Exception as e: + logger.debug(f"Could not resolve mastery/prev_score: {e}") + + return mastery, prev_score + + +def _calibrate_quiz_params(mastery_level: str, prev_score: Optional[float]) -> dict: + """Return item count and difficulty distribution based on mastery and history.""" + if mastery_level == "mastered": + count = 10 + distribution = {"easy": 10, "medium": 40, "hard": 50} + elif mastery_level == "developing": + count = 8 + distribution = {"easy": 30, "medium": 50, "hard": 20} + else: + count = 5 + distribution = {"easy": 60, "medium": 40, "hard": 0} + + if prev_score is not None and prev_score < 50: + distribution = { + "easy": min(80, distribution["easy"] + 20), + "medium": distribution["medium"], + "hard": max(0, distribution["hard"] - 20), + } + + return {"count": count, "distribution": distribution} + + +@app.post("/api/quiz/adaptive") +async def generate_adaptive_quiz(request: AdaptiveQuizRequest): + """ + Generate an adaptive practice quiz calibrated to the student's mastery level. + Reads mastery_level and prev_score from Firestore, auto-calibrates difficulty. + """ + try: + mastery, prev_score = await _resolve_mastery_and_prev_score( + request.student_id, + request.topic_id, + ) + + params = _calibrate_quiz_params(mastery, prev_score) + count = params["count"] + distribution = params["distribution"] + topic_info = DEPD_TOPIC_REGISTRY.get(request.topic_id, {}) + subject = topic_info.get("subject", "General Mathematics") + title = topic_info.get("title", request.topic_id) + + curriculum_chunks = retrieve_curriculum_context( + query=f"{title} {request.topic_id} practice problems exercises", + subject=subject, + top_k=3, + ) + curriculum_context = "" + for chunk in curriculum_chunks: + source = chunk.get("source_file", "unknown") + content = chunk.get("content", "")[:500] + curriculum_context += f"[Source: {source}]\n{content}\n\n---\n\n" + + quiz_id = f"QZ-{uuid.uuid4().hex[:12]}" + + rag_instr = "" + if curriculum_context: + rag_instr = f"""REFERENCE CURRICULUM: +{curriculum_context} + +Base questions on this content. Do not copy directly.""" + + items_json = json.dumps([]) + + try: + quiz_prompt = f"""Generate {count} quiz items for topic "{title}" (ID: {request.topic_id}). + +Mastery Level: {mastery} +Difficulty Distribution: Easy={distribution['easy']}%, Medium={distribution['medium']}%, Hard={distribution['hard']}% +Item types: mix multiple_choice, fill_in_the_blank, and word_problem. + +{rag_instr} + +Use Filipino context. +Return ONLY this strict JSON array: +[ + {{ + "type": "multiple_choice|fill_in_the_blank|word_problem", + "bloom_level": "remembering|understanding|applying|analyzing", + "difficulty": "easy|medium|hard", + "question": "...", + "options": {{"A": "...", "B": "...", "C": "...", "D": "..."}}, + "correct_answer": "B", + "acceptable_range": null, + "solution_hint": "Short hint", + "competency_code": "{request.topic_id}", + "curriculum_reference": "DepEd SHS" + }} +]""" + messages = [ + {"role": "system", "content": "You are a quiz generator. Return ONLY valid JSON."}, + {"role": "user", "content": quiz_prompt}, + ] + response = await call_hf_chat_async(messages, max_tokens=4096, temperature=0.3, task_type="quiz") + items_json = response + except Exception as llm_err: + logger.error(f"Adaptive quiz LLM error: {llm_err}") + + import re + json_match = re.search(r'\[.*\]', items_json, re.DOTALL) + if json_match: + raw_items = json.loads(json_match.group()) + else: + raw_items = json.loads(items_json) if items_json.strip().startswith('[') else [] + + items: List[AdaptiveQuizItem] = [] + for i, qi in enumerate(raw_items[:count]): + items.append(AdaptiveQuizItem( + item_id=f"QI-{uuid.uuid4().hex[:8]}", + type=qi.get("type", "multiple_choice"), + bloom_level=qi.get("bloom_level", "understanding"), + difficulty=qi.get("difficulty", "medium"), + question=qi.get("question", ""), + options=qi.get("options"), + correct_answer=qi.get("correct_answer", ""), + acceptable_range=qi.get("acceptable_range"), + solution_hint=qi.get("solution_hint", ""), + competency_code=qi.get("competency_code", request.topic_id), + curriculum_reference=qi.get("curriculum_reference", "DepEd SHS"), + )) + + return DiagnosticQuizResponse( + quiz_id=quiz_id, + topic_id=request.topic_id, + mastery_target_after="mastered" if mastery == "developing" else "developing" if mastery == "beginning" else "mastered", + items=items, + prev_score=prev_score, + difficulty_distribution=distribution, + ) + except Exception as e: + logger.error(f"Adaptive quiz generation error: {e}\n{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Adaptive quiz error: {str(e)}") + + +# ─── Learning Path Endpoint ──────────────────────────────────── + +class DiagnosticLearningPathRequest(BaseModel): + student_id: str + strand: str = Field(default="STEM") + grade_level: str = Field(default="Grade 11") + + +class DiagnosticLearningPathTopic(BaseModel): + topic_id: str + title: str + mastery_level: str + estimated_minutes: int + + +class DiagnosticLearningPathResponse(BaseModel): + student_id: str + topics: List[DiagnosticLearningPathTopic] + total_estimated_hours: float + + +@app.post("/api/learning/path", response_model=DiagnosticLearningPathResponse) +async def generate_learning_path(request: DiagnosticLearningPathRequest): + """ + Generate personalized learning path based on student's diagnostic results. + """ + try: + if not HAS_FIREBASE_ADMIN or not firebase_firestore: + topics = [] + for tid, info in DEPD_TOPIC_REGISTRY.items(): + topics.append(DiagnosticLearningPathTopic( + topic_id=tid, + title=info["title"], + mastery_level="beginning", + estimated_minutes=20, + )) + return DiagnosticLearningPathResponse( + student_id=request.student_id, + topics=topics[:10], + total_estimated_hours=3.3, + ) + + db = firebase_firestore.client() + doc = db.collection("diagnosticResults").document(request.student_id).collection("attempts").limit(1).get() + + suggested_path = [] + if doc: + data = doc[0].to_dict() if doc else {} + suggested_path = data.get("riskProfile", {}).get("suggested_learning_path", []) + + path_topics = [] + if suggested_path: + for tid in suggested_path[:10]: + info = DEPD_TOPIC_REGISTRY.get(tid, {}) + path_topics.append(DiagnosticLearningPathTopic( + topic_id=tid, + title=info.get("title") or tid, + mastery_level="beginning", + estimated_minutes=20, + )) + else: + strand_topics = DEPD_ED_COMPETENCY_DOMAINS.get(request.strand, {}).get(request.grade_level, []) + for i, t in enumerate(strand_topics[:10]): + tid = f"NA-{(i+1):02d}-01" + path_topics.append(DiagnosticLearningPathTopic( + topic_id=tid, + title=t, + mastery_level="beginning", + estimated_minutes=20, + )) + + total_minutes = sum(t.estimated_minutes for t in path_topics) + + return DiagnosticLearningPathResponse( + student_id=request.student_id, + topics=path_topics, + total_estimated_hours=round(total_minutes / 60, 1), + ) + except Exception as e: + logger.error(f"Learning path generation error: {e}") + raise HTTPException(status_code=500, detail=f"Learning path error: {str(e)}") + + +# ─── Main ────────────────────────────────────────────────────── + +if __name__ == "__main__": + port = int(os.environ.get("PORT", 7860)) + uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)