File size: 2,687 Bytes
54bcfd7
 
 
 
8585229
54bcfd7
 
5396566
54bcfd7
 
 
 
 
cd8e991
54bcfd7
 
 
 
 
 
5396566
54bcfd7
8585229
54bcfd7
 
 
 
 
 
 
 
 
 
 
 
 
8585229
54bcfd7
5396566
54bcfd7
 
 
 
fca57f4
54bcfd7
 
 
 
 
5396566
54bcfd7
 
8585229
 
 
54bcfd7
8585229
54bcfd7
 
 
 
 
 
 
 
5396566
54bcfd7
 
 
 
 
 
 
8585229
54bcfd7
 
 
 
 
d4f5f1b
54bcfd7
 
 
 
 
 
 
 
 
8585229
54bcfd7
8585229
54bcfd7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
FastAPI server inside Hugging Face Space
POST /predict  ->  zero-shot subject prediction + save to TiDB
"""
import os
import time
from contextlib import asynccontextmanager

import mysql.connector
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# ---------- load model ONCE ----------
# Hugging Face checkpoint used for zero-shot subject classification.
MODEL_NAME = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
# Candidate subjects; /predict maps the argmax logit index onto this list.
# NOTE(review): assumes the model's output label order/length matches this
# list exactly — confirm against the checkpoint's config.
LABELS = [
    "Mathematics", "Physics", "Chemistry", "Biology",
    "History", "Geography", "Literature", "Computer-Science"
]

# Populated with "tokenizer" and "model" at app startup (see lifespan);
# cleared again at shutdown.
ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer/model once at startup and release them on shutdown."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    clf = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    clf.eval()
    # Move to GPU when one is present; inference stays on CPU otherwise.
    if torch.cuda.is_available():
        clf.cuda()
    ml_models.update(tokenizer=tok, model=clf)
    yield
    # Shutdown: drop references so the weights can be garbage-collected.
    ml_models.clear()

# Application instance; model load/unload is handled by the lifespan hook.
app = FastAPI(lifespan=lifespan)

# ---------- DB helper ----------
def get_conn():
    """Open a fresh MySQL/TiDB connection configured from DB_* env vars."""
    env = os.getenv
    return mysql.connector.connect(
        host=env("DB_HOST"),
        port=int(env("DB_PORT", 4000)),
        user=env("DB_USER"),
        password=env("DB_PASS"),
        database=env("DB_NAME"),
        # TLS CA bundle is optional; unset or empty path means no custom CA
        ssl_ca=env("DB_SSL_CA_PATH") or None,
    )

# ---------- request schema ----------
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    # Identifier stored alongside the prediction in log_table.
    student_id: str
    # Free text to classify into one of LABELS.
    text: str

# ---------- API endpoint ----------
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify req.text into one of LABELS and log the result to TiDB.

    Returns {"subject": <label>}.
    Raises HTTPException 400 on empty/whitespace text, 500 on DB failure.
    """
    if not req.text.strip():
        raise HTTPException(400, "Empty text")

    tok = ml_models["tokenizer"](
        req.text,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    if torch.cuda.is_available():
        tok = {k: v.cuda() for k, v in tok.items()}
    with torch.no_grad():
        logits = ml_models["model"](**tok).logits
        probs = torch.softmax(logits, dim=-1)[0]
        # NOTE(review): assumes the model's label dimension lines up with
        # LABELS (same length and order) — confirm against the checkpoint.
        subject = LABELS[int(torch.argmax(probs))]

    # Save to DB. The original leaked the connection/cursor when execute()
    # or commit() raised (the except path never closed them); the nested
    # try/finally guarantees cleanup on every path.
    try:
        conn = get_conn()
        try:
            cur = conn.cursor()
            try:
                cur.execute(
                    "INSERT INTO log_table (student_id, input_sample, subject, prediction_time) "
                    "VALUES (%s, %s, %s, %s)",
                    # NOTE(review): naive local server time — confirm UTC isn't expected
                    (req.student_id, req.text, subject, time.strftime('%Y-%m-%d %H:%M:%S'))
                )
                conn.commit()
            finally:
                cur.close()
        finally:
            conn.close()
    except Exception as e:
        print("DB error:", e)
        # Chain the cause so the original DB error survives in tracebacks.
        raise HTTPException(500, "DB write failed") from e

    return {"subject": subject}

@app.get("/")
def root():
    return {"message": "Subject predictor is running"}