# NOTE: removed Hugging Face file-viewer scrape residue (page chrome, commit hashes, gutter line numbers).
"""
FastAPI server inside Hugging Face Space
POST /predict -> zero-shot subject prediction + save to TiDB
"""
import os
import time
from contextlib import asynccontextmanager
import mysql.connector
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ---------- load model ONCE ----------
# NLI-style zero-shot checkpoint (DeBERTa-v3-large); loaded in lifespan().
MODEL_NAME = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
# Candidate subjects the /predict endpoint chooses between.
LABELS = [
    "Mathematics", "Physics", "Chemistry", "Biology",
    "History", "Geography", "Literature", "Computer-Science"
]
# Populated at startup by lifespan(): {"tokenizer": ..., "model": ...}; cleared at shutdown.
ml_models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the zero-shot model once at startup,
    move it to GPU when available, and drop the references at shutdown."""
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()  # inference only — disable dropout etc.
    if torch.cuda.is_available():
        classifier.cuda()
    ml_models.update(tokenizer=hf_tokenizer, model=classifier)
    yield
    # shutdown: release references so the model can be garbage-collected
    ml_models.clear()
# Wire the model-loading lifespan into the application.
app = FastAPI(lifespan=lifespan)
# ---------- DB helper ----------
def get_conn():
    """Open a fresh MySQL/TiDB connection from the DB_* environment variables."""
    cfg = {
        "host": os.getenv("DB_HOST"),
        "port": int(os.getenv("DB_PORT", 4000)),
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASS"),
        "database": os.getenv("DB_NAME"),
        # unset/empty CA path collapses to None (no TLS CA pinning)
        "ssl_ca": os.getenv("DB_SSL_CA_PATH") or None,
    }
    return mysql.connector.connect(**cfg)
# ---------- request schema ----------
class PredictRequest(BaseModel):
    """JSON body for POST /predict."""
    student_id: str  # caller-supplied student identifier (written to the log table)
    text: str  # free-form text to classify into one of LABELS
# ---------- API endpoint ----------
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify req.text into one of LABELS via NLI zero-shot scoring,
    log the prediction to TiDB, and return {"subject": <label>}.

    Raises:
        HTTPException(400): empty/whitespace-only text.
        HTTPException(500): database write failed.
    """
    if not req.text.strip():
        raise HTTPException(400, "Empty text")
    tokenizer = ml_models["tokenizer"]
    model = ml_models["model"]
    # BUGFIX: this checkpoint is an NLI-style zero-shot model — its
    # classification head has a FIXED number of outputs (entailment /
    # not_entailment), not one logit per entry of the runtime LABELS list,
    # so argmax over the raw logits never indexed LABELS meaningfully.
    # Score each candidate label as a (premise, hypothesis) pair instead
    # and pick the label with the highest entailment probability.
    hypotheses = [f"This example is about {label}." for label in LABELS]
    tok = tokenizer(
        [req.text] * len(LABELS),
        hypotheses,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    if torch.cuda.is_available():
        tok = {k: v.cuda() for k, v in tok.items()}
    with torch.no_grad():
        # logits shape: (len(LABELS), num_nli_classes)
        logits = model(**tok).logits
    # Entailment class index taken from the model config (0 for the
    # MoritzLaurer zeroshot checkpoints — TODO confirm against model card).
    entail_id = model.config.label2id.get("entailment", 0)
    entail_probs = torch.softmax(logits, dim=-1)[:, entail_id]
    subject = LABELS[int(torch.argmax(entail_probs))]
    # save to DB — BUGFIX: close the connection even when execute/commit
    # raises (the original leaked conn on any cursor/execute failure).
    try:
        conn = get_conn()
        try:
            cur = conn.cursor()
            cur.execute(
                "INSERT INTO log_table (student_id, input_sample, subject, prediction_time) "
                "VALUES (%s, %s, %s, %s)",
                # NOTE(review): server-local time; confirm whether UTC is expected.
                (req.student_id, req.text, subject, time.strftime('%Y-%m-%d %H:%M:%S')),
            )
            conn.commit()
            cur.close()
        finally:
            conn.close()
    except Exception as e:
        print("DB error:", e)
        raise HTTPException(500, "DB write failed")
    return {"subject": subject}
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up."""
    return dict(message="Subject predictor is running")