Ahmad Hathim bin Ahmad Azman commited on
Commit
f3ce8a7
·
1 Parent(s): 10b33a5
Files changed (2) hide show
  1. app.py +20 -19
  2. model_inference.py +1 -0
app.py CHANGED
@@ -1,35 +1,34 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
- from model_inference import load_model, predict_from_input, ensure_model_file
5
  import joblib
 
6
  from transformers import AutoTokenizer
 
7
 
8
- app = FastAPI(title="SQB Predictor API")
 
 
 
 
9
 
10
  app.add_middleware(
11
  CORSMiddleware,
12
  allow_origins=["*"],
 
13
  allow_methods=["*"],
14
  allow_headers=["*"],
15
  )
16
 
17
- @app.on_event("startup")
18
- def load_all_resources():
19
- print("🚀 Downloading model and dependencies...")
20
-
21
- model_path = ensure_model_file("best_checkpoint_regression.pt")
22
- encoder_path = ensure_model_file("onehot_encoder.pkl")
23
- scaler_path = ensure_model_file("scaler.pkl")
24
-
25
- global model, device, encoder, scaler, tok_mcq, tok_clin
26
 
27
- model, device = load_model(model_path)
28
- encoder = joblib.load(encoder_path)
29
- scaler = joblib.load(scaler_path)
30
 
31
- tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
32
- tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
33
 
34
  class QuestionInput(BaseModel):
35
  StemText: str
@@ -43,11 +42,13 @@ class QuestionInput(BaseModel):
43
  BloomLevel: str
44
 
45
  @app.get("/health")
46
- def health_check():
47
  return {"status": "ok"}
48
 
49
  @app.post("/predict")
50
  def predict(input_data: QuestionInput):
51
- return predict_from_input(
52
- input_data.dict(), model, device, tok_mcq, tok_clin, encoder, scaler
 
53
  )
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
+ import torch
5
  import joblib
6
+ import os
7
  from transformers import AutoTokenizer
8
+ from model_inference import load_model, predict_from_input
9
 
10
+ # FIX: Set Hugging Face cache to a writable directory
11
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
13
+
14
+ app = FastAPI(title="Question Difficulty/Discrimination Predictor")
15
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
19
+ allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
+ # ✅ Load model on startup
25
+ model, device = load_model()
 
 
 
 
 
 
 
26
 
27
+ encoder = joblib.load("assets/onehot_encoder.pkl")
28
+ scaler = joblib.load("assets/scaler.pkl")
 
29
 
30
+ tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
31
+ tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
32
 
33
  class QuestionInput(BaseModel):
34
  StemText: str
 
42
  BloomLevel: str
43
 
44
  @app.get("/health")
45
+ def health():
46
  return {"status": "ok"}
47
 
48
  @app.post("/predict")
49
  def predict(input_data: QuestionInput):
50
+ pred = predict_from_input(
51
+ input_data.dict(), model, device,
52
+ tok_mcq, tok_clin, encoder, scaler
53
  )
54
+ return pred
model_inference.py CHANGED
@@ -4,6 +4,7 @@ import textstat
4
  from utils.preprocess import compute_text_features
5
  from model_architecture import EnsembleBertBiLSTMRegressor
6
  from huggingface_hub import hf_hub_download
 
7
 
8
  HF_REPO = "hathimazman/sqb-predict"
9
 
 
4
  from utils.preprocess import compute_text_features
5
  from model_architecture import EnsembleBertBiLSTMRegressor
6
  from huggingface_hub import hf_hub_download
7
+ import os
8
 
9
  HF_REPO = "hathimazman/sqb-predict"
10