"""Load persisted classifiers from the models directory and serve predictions.

Artifacts are read with joblib into a module-level cache by ``load_models``;
``predict_claim`` then vectorizes a text and runs one of the cached models.
"""

import os

import joblib
from fastapi import HTTPException

MODELS_DIR = os.path.join(os.path.dirname(__file__), '..', 'models')

# Maps a cache key ('vectorizer', 'logistic', 'lgbm', 'xgb', 'xgb_encoder')
# to the object returned by joblib.load for the matching artifact.
_cache: dict = {}

# Model ids reported by get_available_models, in reporting order.
_MODEL_IDS = ('logistic', 'lgbm', 'xgb')

# Human-readable names used in the "not available" error details.
_MODEL_LABELS = {'logistic': 'Logistic', 'lgbm': 'LightGBM', 'xgb': 'XGBoost'}


def _path(filename: str) -> str:
    """Return the path of *filename* inside ``MODELS_DIR``."""
    return os.path.join(MODELS_DIR, filename)


def _available(filename: str) -> bool:
    """Return True if *filename* exists inside ``MODELS_DIR``."""
    return os.path.exists(_path(filename))


def _load(filename: str, key: str) -> bool:
    """Try to load *filename* into the cache under *key*.

    Loading is best-effort: a missing file or any error raised while loading
    results in ``False`` (with a warning printed for load errors) rather than
    an exception.

    Returns:
        True if the artifact was loaded and cached, False otherwise.
    """
    if not _available(filename):
        return False
    try:
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code; only artifacts from a trusted source belong in MODELS_DIR.
        _cache[key] = joblib.load(_path(filename))
    except Exception as e:
        # Deliberately broad: optional models must not abort startup.
        print(f"WARNING: could not load {filename}: {e}")
        return False
    return True


def get_available_models() -> list[str]:
    """Return the ids of the models currently cached, in a fixed order."""
    return [model_id for model_id in _MODEL_IDS if model_id in _cache]


def load_models() -> None:
    """Load the vectorizer and every model artifact that is present.

    Entries already in the cache are not reloaded, so the function can be
    called again to pick up artifacts that were missing earlier.

    Raises:
        RuntimeError: if ``tfidf_vectorizer.pkl`` is missing, or if no model
            ends up in the cache.
    """
    if 'vectorizer' not in _cache:
        if not _available('tfidf_vectorizer.pkl'):
            raise RuntimeError(
                "tfidf_vectorizer.pkl not found in models/. "
                "Run scripts/train_all_models.py first."
            )
        _cache['vectorizer'] = joblib.load(_path('tfidf_vectorizer.pkl'))
        print("Vectorizer loaded.")

    if 'logistic' not in _cache:
        if _load('best_logistic.pkl', 'logistic'):
            print("Logistic model loaded.")

    if 'lgbm' not in _cache:
        if _load('best_lgbm.pkl', 'lgbm'):
            print("LightGBM model loaded.")

    if 'xgb' not in _cache:
        # Both loads are always attempted so each failure prints its warning.
        model_ok = _load('best_xgb.pkl', 'xgb')
        encoder_ok = _load('xgb_encoder.pkl', 'xgb_encoder')
        if model_ok and encoder_ok:
            print("XGBoost model loaded.")
        else:
            # The model is only usable together with its encoder; drop
            # whichever half loaded so the cache never holds a partial pair.
            _cache.pop('xgb', None)
            _cache.pop('xgb_encoder', None)

    available = get_available_models()
    if not available:
        raise RuntimeError(
            "No models loaded. Run scripts/train_all_models.py first."
        )
    print(f"Models ready: {available}")


def predict_claim(text: str, model_id: str = 'logistic') -> dict:
    """Classify *text* with the selected model.

    Args:
        text: The text to vectorize and classify.
        model_id: ``'lgbm'`` or ``'xgb'`` select those models; any other
            value falls back to the logistic model.

    Returns:
        A dict with ``"prediction"`` (the predicted label as a string) and
        ``"probabilities"`` (label string -> float probability).

    Raises:
        HTTPException: with status 503 if the vectorizer or the selected
            model (or, for ``'xgb'``, its encoder) is not in the cache.
    """
    if 'vectorizer' not in _cache:
        raise HTTPException(status_code=503, detail="Models not loaded.")

    vector = _cache['vectorizer'].transform([text])

    # Unknown ids intentionally fall back to the default logistic model.
    key = model_id if model_id in ('lgbm', 'xgb') else 'logistic'
    uses_encoder = key == 'xgb'

    if key not in _cache or (uses_encoder and 'xgb_encoder' not in _cache):
        raise HTTPException(
            status_code=503,
            detail=f"{_MODEL_LABELS[key]} model not available.",
        )

    model = _cache[key]
    raw_prediction = model.predict(vector)[0]

    if uses_encoder:
        # The XGBoost output is passed through the cached encoder to recover
        # the label; the encoder's classes_ name the probability columns.
        encoder = _cache['xgb_encoder']
        prediction = str(encoder.inverse_transform([raw_prediction])[0])
        classes = encoder.classes_
    else:
        prediction = str(raw_prediction)
        classes = model.classes_

    proba_values = model.predict_proba(vector)[0]
    probabilities = {str(c): float(p) for c, p in zip(classes, proba_values)}

    return {"prediction": prediction, "probabilities": probabilities}