Spaces:
Running
Running
| import joblib | |
| import os | |
| from fastapi import HTTPException | |
# Directory holding the serialized artifacts: the ``models`` folder one level
# above the directory that contains this module.
MODELS_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'models')

# Module-level store of deserialized objects. The keys used in this file are
# 'vectorizer', 'logistic', 'lgbm', 'xgb' and 'xgb_encoder'.
_cache: dict = {}
def _path(filename: str) -> str:
    """Return the location of ``filename`` inside ``MODELS_DIR``.

    The result is simply ``MODELS_DIR`` joined with ``filename``; no check is
    made that the file actually exists.
    """
    location = os.path.join(MODELS_DIR, filename)
    return location
def _available(filename: str) -> bool:
    """Tell whether ``filename`` is present in the models directory."""
    candidate = _path(filename)
    return os.path.exists(candidate)
def _load(filename: str, key: str) -> bool:
    """Deserialize ``filename`` with joblib and store it in the cache.

    Args:
        filename: Name of the artifact inside the models directory.
        key: Cache key under which the loaded object is stored.

    Returns:
        True when the object was loaded and cached. False when the file is
        missing or when loading raised; in the latter case a warning is
        printed and the exception is not propagated.
    """
    if _available(filename):
        try:
            _cache[key] = joblib.load(_path(filename))
        except Exception as err:
            # Best-effort: a broken optional artifact must not abort startup.
            print(f"WARNING: could not load {filename}: {err}")
        else:
            return True
    return False
def get_available_models() -> list[str]:
    """List the model identifiers that currently have an entry in the cache.

    Returns:
        The subset of ``'logistic'``, ``'lgbm'`` and ``'xgb'`` found in the
        cache, always in that order.
    """
    known_ids = ('logistic', 'lgbm', 'xgb')
    return [model_id for model_id in known_ids if model_id in _cache]
def load_models() -> None:
    """Fill the cache with the vectorizer and every model found on disk.

    Entries that are already cached are not reloaded, so the function can be
    called more than once. The vectorizer is mandatory; the logistic,
    LightGBM and XGBoost models are each optional, but at least one of them
    must end up loaded.

    XGBoost is only kept when both ``best_xgb.pkl`` and ``xgb_encoder.pkl``
    load. If only one of the two loads, both cache entries are cleared and a
    warning is printed, so the cache never holds half of the pair.

    Raises:
        RuntimeError: If ``tfidf_vectorizer.pkl`` is missing, or if none of
            the three models could be loaded.
    """
    if 'vectorizer' not in _cache:
        if not _available('tfidf_vectorizer.pkl'):
            raise RuntimeError(
                "tfidf_vectorizer.pkl not found in models/. "
                "Run scripts/train_all_models.py first."
            )
        # Unlike the optional models, a failure here is allowed to propagate.
        _cache['vectorizer'] = joblib.load(_path('tfidf_vectorizer.pkl'))
        print("Vectorizer loaded.")

    if 'logistic' not in _cache and _load('best_logistic.pkl', 'logistic'):
        print("Logistic model loaded.")

    if 'lgbm' not in _cache and _load('best_lgbm.pkl', 'lgbm'):
        print("LightGBM model loaded.")

    if 'xgb' not in _cache:
        model_ok = _load('best_xgb.pkl', 'xgb')
        encoder_ok = _load('xgb_encoder.pkl', 'xgb_encoder')
        if model_ok and encoder_ok:
            print("XGBoost model loaded.")
        else:
            # Keep the pair consistent: never leave a model without its
            # encoder, nor an orphan encoder without its model.
            _cache.pop('xgb', None)
            _cache.pop('xgb_encoder', None)
            if model_ok != encoder_ok:
                # Exactly one artifact loaded; say why XGBoost is unavailable
                # instead of dropping it silently.
                print(
                    "WARNING: XGBoost disabled: best_xgb.pkl and "
                    "xgb_encoder.pkl must both load."
                )

    available = get_available_models()
    if not available:
        raise RuntimeError(
            "No models loaded. Run scripts/train_all_models.py first."
        )
    print(f"Models ready: {available}")
def predict_claim(text: str, model_id: str = 'logistic') -> dict:
    """Classify ``text`` with the selected cached model.

    Args:
        text: The raw text to classify; it is passed to the cached
            vectorizer as a one-element list.
        model_id: ``'lgbm'`` or ``'xgb'`` select those models. Any other
            value, including the default, selects the logistic model.

    Returns:
        A dict with ``"prediction"`` (the predicted label as a string) and
        ``"probabilities"`` (a mapping from each class label, as a string,
        to its probability as a float).

    Raises:
        HTTPException: With status 503 when the vectorizer or the selected
            model is not in the cache.
    """
    if 'vectorizer' not in _cache:
        raise HTTPException(status_code=503, detail="Models not loaded.")
    features = _cache['vectorizer'].transform([text])

    # Resolve the cache key; everything that is not lgbm/xgb means logistic.
    if model_id == 'lgbm':
        cache_key = 'lgbm'
    elif model_id == 'xgb':
        cache_key = 'xgb'
    else:
        cache_key = 'logistic'

    if cache_key not in _cache:
        unavailable = {
            'lgbm': "LightGBM model not available.",
            'xgb': "XGBoost model not available.",
            'logistic': "Logistic model not available.",
        }
        raise HTTPException(status_code=503, detail=unavailable[cache_key])

    classifier = _cache[cache_key]
    if cache_key == 'xgb':
        # The XGBoost model predicts encoded labels; the cached encoder maps
        # them back and also supplies the class names for the probabilities.
        label_source = _cache['xgb_encoder']
        encoded = classifier.predict(features)[0]
        prediction = str(label_source.inverse_transform([encoded])[0])
    else:
        label_source = classifier
        prediction = str(classifier.predict(features)[0])

    scores = classifier.predict_proba(features)[0]
    probabilities = {
        str(label): float(score)
        for label, score in zip(label_source.classes_, scores)
    }
    return {"prediction": prediction, "probabilities": probabilities}