# sqb-predict-api / app.py
# Last commit f3ce8a7 ("fixed os") by Ahmad Hathim bin Ahmad Azman.
# Hosting-page residue: raw | history | blame | 1.4 kB
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import joblib
import os
from transformers import AutoTokenizer
from model_inference import load_model, predict_from_input
# Point the Hugging Face cache environment variables at a writable directory.
# NOTE(review): transformers / huggingface_hub resolve their default cache
# paths when they are first imported, and transformers is imported above, so
# these variables may not redirect `from_pretrained` downloads on their own —
# pass `cache_dir` explicitly or set them before the imports. Confirm.
os.environ.update(
    TRANSFORMERS_CACHE="/tmp/hf_cache",
    HF_HOME="/tmp/hf_cache",
)
app = FastAPI(
    title="Question Difficulty/Discrimination Predictor",
)
# Fully open CORS: any origin, method and header may call this API.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# site send credentialed requests. The endpoints visible here use no cookies or
# auth, but narrow allow_origins if that ever changes.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Load the model, preprocessing artifacts and tokenizers once, at import time,
# so every request reuses the same objects.
#
# transformers resolves its default cache directory when it is imported (this
# module imports it at the top), so setting TRANSFORMERS_CACHE / HF_HOME later
# does not redirect tokenizer downloads. Pass cache_dir explicitly instead.
_HF_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE", "/tmp/hf_cache")
model, device = load_model()
# joblib.load unpickles its input: only ever point these at trusted, bundled files.
encoder = joblib.load("assets/onehot_encoder.pkl")
scaler = joblib.load("assets/scaler.pkl")
tok_mcq = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    cache_dir=_HF_CACHE_DIR,
)
tok_clin = AutoTokenizer.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    cache_dir=_HF_CACHE_DIR,
)
class QuestionInput(BaseModel):
    """Request body for ``POST /predict``: one question with stem, lead-in and options A-D.

    The model is converted to a dict and handed unchanged to
    ``predict_from_input``, so field names (and order) must stay in sync with
    what that function expects.
    """
    StemText: str
    LeadIn: str
    OptionA: str
    OptionB: str
    OptionC: str
    OptionD: str
    # presumably categorical metadata consumed by the one-hot encoder — confirm
    DepartmentName: str
    CourseName: str
    BloomLevel: str
@app.get("/health")
def health():
    """Return a fixed ``{"status": "ok"}`` payload.

    Models and artifacts are loaded at module import, so a response here means
    that loading completed.
    """
    return dict(status="ok")
@app.post("/predict")
def predict(input_data: QuestionInput):
    """Run the difficulty/discrimination predictor on one question.

    The request body is converted to a plain dict and passed, together with the
    module-level model, device, tokenizers, encoder and scaler, to
    ``predict_from_input``; its result is returned as the response.

    Kept as a plain ``def`` (not ``async def``) so FastAPI runs the blocking
    inference call in its threadpool instead of on the event loop.
    """
    # Pydantic v2 deprecates .dict() in favour of .model_dump(); fall back to
    # .dict() so the endpoint still works on Pydantic v1.
    payload = (
        input_data.model_dump()
        if hasattr(input_data, "model_dump")
        else input_data.dict()
    )
    return predict_from_input(
        payload, model, device, tok_mcq, tok_clin, encoder, scaler
    )