# Source: ai-nids-backend/utils/model_selector.py (CodebaseAi, commit ee1c180 "Fixed7")
import os
import joblib
import threading
import traceback
from huggingface_hub import hf_hub_download
import sklearn.utils
# --- CONFIGURATION ---
# Hugging Face Hub repository that missing model files are downloaded from
# (passed as `repo_id` to hf_hub_download in _get_model_path).
HF_REPO_ID = "CodebaseAi/netraids-ml-models" # Replace with your actual public repo ID
# ---------------------
# Key of the model bundle currently in use; changed via set_active_model()
# and read via get_active_model(). load_model() knows "bcc" and "cicids".
ACTIVE_MODEL = "bcc"
# Held by set_active_model() while ACTIVE_MODEL is being rewritten.
_ACTIVE_LOCK = threading.Lock()
# Maps a model key to the dict of artifacts load_model() built for it, so
# each bundle is loaded at most once per process.
_MODEL_CACHE = {}
# --- THE "BULLETPROOF" SKLEARN PATCH ---
# Re-exposes helpers on `sklearn.utils` for code that still looks them up
# there. This must run BEFORE any other ML imports or model loading.
# NOTE(review): presumably the pickled pipelines (or a library they pull in)
# import these names from `sklearn.utils` -- confirm which one needs them.
import importlib

import sklearn.utils

# Modules that may define `_get_column_indices`, probed in this order.
_GCI_MODULE_CANDIDATES = (
    "sklearn.utils._indexing",  # where upstream keeps it in scikit-learn >= 1.5
    "sklearn.utils._column_transformer",  # location the original patch tried
)


def _find_get_column_indices():
    """Return the first importable `_get_column_indices`, or None if absent."""
    for module_name in _GCI_MODULE_CANDIDATES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue
        helper = getattr(module, "_get_column_indices", None)
        if helper is not None:
            return helper
    return None


if hasattr(sklearn.utils, "_get_column_indices"):
    # The installed scikit-learn already exposes the helper here; never
    # overwrite a working implementation with a shim.
    print("[Patch] _get_column_indices already present; nothing to inject")
else:
    _located_gci = _find_get_column_indices()
    if _located_gci is not None:
        sklearn.utils._get_column_indices = _located_gci
        print("[Patch] Successfully injected _get_column_indices")
    else:
        def _get_column_indices(X, key):
            """Late-bound shim: retry the lookup at call time, else fail loudly."""
            gci = _find_get_column_indices()
            if gci is None:
                raise ImportError(
                    "_get_column_indices is not available in any of: "
                    + ", ".join(_GCI_MODULE_CANDIDATES)
                )
            return gci(X, key)

        # Installing the shim keeps `from sklearn.utils import
        # _get_column_indices` importable; only an actual call can fail.
        sklearn.utils._get_column_indices = _get_column_indices
        print(
            "[Patch] Manual injection fallback used: "
            "_get_column_indices not found in any known sklearn module"
        )

# Patch for parse_version
if not hasattr(sklearn.utils, "parse_version"):
    try:
        from packaging import version
    except ImportError:
        print("[Patch] 'packaging' library missing. Ensure it is in requirements.txt")
    else:
        sklearn.utils.parse_version = version.parse
        print("[Patch] Successfully injected parse_version")
# ---------------------------------------
# --------------------------------------------
# Path layout (example for a deployment rooted at /app):
#   this file        -> /app/utils/model_selector.py
#   its directory    -> /app/utils
#   parent of that   -> /app  (the project ROOT)
_UTILS_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(_UTILS_DIR)
ML_DIR = os.path.join(BASE_DIR, "ml_models")

# Create the local ml_models cache folder if nothing exists at that path yet.
if not os.path.exists(ML_DIR):
    os.makedirs(ML_DIR, exist_ok=True)

print(f"[model_selector] ROOT BASE_DIR: {BASE_DIR}")
print(f"[model_selector] ML_DIR: {ML_DIR}")
def _get_model_path(filename):
    """
    Resolve *filename* to a path on disk.

    A copy already present in the local 'ml_models' folder is returned
    as-is. Otherwise the file is downloaded from the Hugging Face Hub repo
    HF_REPO_ID into that folder. Returns None if the download raises.
    """
    cached_copy = os.path.join(ML_DIR, filename)
    if not os.path.exists(cached_copy):
        try:
            print(f"[model_selector] {filename} not found locally. Downloading from Hub...")
            # local_dir makes the Hub client place the file in our ml_models folder
            return hf_hub_download(
                repo_id=HF_REPO_ID,
                filename=filename,
                local_dir=ML_DIR,
            )
        except Exception as e:
            print(f"[model_selector] ERROR: Could not find/download {filename}: {e}")
            return None
    return cached_copy
def _try_load(filename):
    """
    Best-effort load of a serialized artifact with joblib.

    Returns the deserialized object, or None when the path cannot be
    resolved, the file is empty, or loading raises (the traceback is
    printed in that case).
    """
    path = _get_model_path(filename)
    if not (path and os.path.exists(path)):
        print(f"[model_selector] SKIP: {filename} path invalid.")
        return None

    try:
        # A zero-byte file cannot hold a model; report it instead of loading.
        if os.path.getsize(path) > 0:
            print(f"[model_selector] Attempting joblib.load for {filename}")
            # NOTE: joblib.load unpickles the file, which can execute arbitrary
            # code -- only point HF_REPO_ID at a repository you trust.
            return joblib.load(path)
        print(f"[model_selector] ERROR: {filename} is an empty file.")
        return None
    except Exception:
        print(f"[model_selector] CRITICAL FAILED to load {filename}")
        print(traceback.format_exc())  # full traceback so the cause shows up in the logs
        return None
def load_model(model_key):
    """
    Return the artifact bundle for *model_key*, loading it on first use.

    "bcc" yields a dict with "model", "scaler" and "encoder"; "cicids"
    yields a dict with "model" and "artifacts". Each value is whatever
    _try_load returned (None if that file could not be loaded). The bundle
    is stored in _MODEL_CACHE and reused on later calls.

    Raises:
        ValueError: if *model_key* is neither cached nor a known key.
    """
    if model_key not in _MODEL_CACHE:
        # Artifact role -> file name, per supported model key.
        bundle_files = {
            "bcc": {
                "model": "realtime_model.pkl",
                "scaler": "realtime_scaler.pkl",
                "encoder": "realtime_encoder.pkl",
            },
            "cicids": {
                "model": "rf_pipeline.joblib",
                "artifacts": "training_artifacts.joblib",
            },
        }
        if model_key not in bundle_files:
            raise ValueError(f"Unknown model_key: {model_key}")
        _MODEL_CACHE[model_key] = {
            role: _try_load(file_name)
            for role, file_name in bundle_files[model_key].items()
        }
    return _MODEL_CACHE[model_key]
def set_active_model(key: str):
    """
    Make *key* the module-wide active model.

    The key is stored as given, without validation. _ACTIVE_LOCK is held
    while ACTIVE_MODEL is rewritten, and the new value is printed.
    """
    global ACTIVE_MODEL
    with _ACTIVE_LOCK:
        ACTIVE_MODEL = key
        print(f"[model_selector] ACTIVE_MODEL set to: {ACTIVE_MODEL}")
def get_active_model():
    """Return the key currently stored in the module-level ACTIVE_MODEL."""
    return ACTIVE_MODEL