Triple-R-Checker / api /inference.py
=Apyhtml20
Add RAG module: PostgreSQL + pgvector in same container
45fa780
Raw
History Blame Contribute Delete
3.52 kB
import joblib
import os
from fastapi import HTTPException
# Directory holding the trained artefacts: <directory of this file>/../models.
MODELS_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'models')

# In-process store of loaded artefacts, keyed by short id
# ('vectorizer', 'logistic', 'lgbm', 'xgb', 'xgb_encoder').
_cache: dict = {}
def _path(filename: str) -> str:
    """Return the location of *filename* inside ``MODELS_DIR``."""
    resolved = os.path.join(MODELS_DIR, filename)
    return resolved
def _available(filename: str) -> bool:
    """Report whether *filename* exists under the models directory."""
    candidate = _path(filename)
    return os.path.exists(candidate)
def _load(filename: str, key: str) -> bool:
    """Best-effort load of a joblib artefact into ``_cache[key]``.

    Returns True when the file exists and was loaded. Returns False when the
    file is absent, or when loading raised; in the latter case a warning is
    printed and the cache is left untouched.
    """
    if _available(filename):
        try:
            loaded = joblib.load(_path(filename))
        except Exception as exc:
            # Deliberately broad: one unreadable model must not abort start-up.
            print(f"WARNING: could not load {filename}: {exc}")
            return False
        _cache[key] = loaded
        return True
    return False
def get_available_models() -> list[str]:
    """Return the ids of loaded classifiers, always ordered logistic, lgbm, xgb."""
    return [model_id for model_id in ('logistic', 'lgbm', 'xgb') if model_id in _cache]
def load_models() -> None:
    """Load the TF-IDF vectorizer and every available classifier into ``_cache``.

    Anything already present in the cache is skipped, so repeated calls only
    retry what is missing.

    The vectorizer is mandatory. Classifiers are best-effort: a missing or
    unreadable file is skipped (``_load`` prints a warning on read errors).
    XGBoost is registered only when both the model and its label encoder
    load; otherwise neither half is kept in the cache.

    Raises:
        RuntimeError: if ``tfidf_vectorizer.pkl`` is missing, or if no
            classifier ends up loaded.
    """
    if 'vectorizer' not in _cache:
        if not _available('tfidf_vectorizer.pkl'):
            raise RuntimeError(
                "tfidf_vectorizer.pkl not found in models/. "
                "Run scripts/train_all_models.py first."
            )
        # Loaded directly rather than via _load so that a read error
        # propagates instead of being reduced to a warning.
        _cache['vectorizer'] = joblib.load(_path('tfidf_vectorizer.pkl'))
        print("Vectorizer loaded.")

    if 'logistic' not in _cache:
        if _load('best_logistic.pkl', 'logistic'):
            print("Logistic model loaded.")

    if 'lgbm' not in _cache:
        if _load('best_lgbm.pkl', 'lgbm'):
            print("LightGBM model loaded.")

    if 'xgb' not in _cache:
        model_ok = _load('best_xgb.pkl', 'xgb')
        encoder_ok = _load('xgb_encoder.pkl', 'xgb_encoder')
        if model_ok and encoder_ok:
            print("XGBoost model loaded.")
        else:
            # predict_claim needs both halves: drop whichever one did load so
            # the cache never holds a partial XGBoost setup.
            _cache.pop('xgb', None)
            _cache.pop('xgb_encoder', None)
            if model_ok or encoder_ok:
                print(
                    "WARNING: XGBoost needs both best_xgb.pkl and "
                    "xgb_encoder.pkl; skipping XGBoost."
                )

    available = get_available_models()
    if not available:
        raise RuntimeError(
            "No models loaded. Run scripts/train_all_models.py first."
        )
    print(f"Models ready: {available}")
def _cached_or_503(key: str, detail: str):
    """Return ``_cache[key]``, or raise HTTP 503 with *detail* if it is absent."""
    if key not in _cache:
        raise HTTPException(status_code=503, detail=detail)
    return _cache[key]


def _probabilities(model, vector, classes) -> dict:
    """Pair each label in *classes* (as ``str``) with the model's probability for it."""
    proba_values = model.predict_proba(vector)[0]
    return {str(c): float(p) for c, p in zip(classes, proba_values)}


def predict_claim(text: str, model_id: str = 'logistic') -> dict:
    """Classify *text* with the requested model.

    Args:
        text: raw claim text, transformed with the cached TF-IDF vectorizer.
        model_id: 'logistic' (default), 'lgbm' or 'xgb'. Any other value
            silently falls back to the logistic model.

    Returns:
        ``{"prediction": <label as str>,
           "probabilities": {<label as str>: float, ...}}``

    Raises:
        HTTPException: status 503 if the vectorizer, the requested model, or
            (for 'xgb') its label encoder is not loaded.
    """
    vectorizer = _cached_or_503('vectorizer', "Models not loaded.")
    vector = vectorizer.transform([text])

    if model_id == 'lgbm':
        model = _cached_or_503('lgbm', "LightGBM model not available.")
        prediction = str(model.predict(vector)[0])
        classes = model.classes_
    elif model_id == 'xgb':
        model = _cached_or_503('xgb', "XGBoost model not available.")
        # Guarded like the model itself: a missing encoder is a 503, not a
        # KeyError surfacing as a 500.
        encoder = _cached_or_503('xgb_encoder', "XGBoost model not available.")
        # The model's predict() output is decoded back to a label with the
        # saved encoder, and encoder.classes_ labels the probability columns.
        pred_enc = model.predict(vector)[0]
        prediction = str(encoder.inverse_transform([pred_enc])[0])
        classes = encoder.classes_
    else:  # logistic (default, also used for unrecognised ids)
        model = _cached_or_503('logistic', "Logistic model not available.")
        prediction = str(model.predict(vector)[0])
        classes = model.classes_

    return {"prediction": prediction, "probabilities": _probabilities(model, vector, classes)}