|
|
|
|
|
import os |
|
|
import io |
|
|
import json |
|
|
import base64 |
|
|
import logging |
|
|
from typing import Dict, Optional |
|
|
|
|
|
import shap |
|
|
import pandas as pd |
|
|
import matplotlib |
|
|
# Select the non-interactive Agg backend BEFORE pyplot is imported below,
# so figures can be rendered in headless/server environments (no display).
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt |
|
|
import joblib |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
from utils.config import AppConfig |
|
|
from utils.tracing import Tracer |
|
|
|
|
|
# Module-level logger; configuration is left to the application.
logger = logging.getLogger(__name__)


# --- SHAP sampling bounds (keep CPU cost manageable) -------------------

# Hard upper cap on the number of rows fed into SHAP.
MAX_SAMPLE_SIZE = 1000

# At or below this many rows, sampling is skipped and all rows are used.
MIN_SAMPLE_SIZE = 10

# Target sample size when the caller does not specify one.
DEFAULT_SAMPLE_SIZE = 500

# Soft limit for generated plot images; exceeding it only logs a warning.
MAX_IMAGE_SIZE_MB = 5
|
|
|
|
|
|
|
|
class ExplainToolError(Exception):
    """Raised when model loading or SHAP explanation generation fails."""
|
|
|
|
|
|
|
|
class ExplainTool:
    """
    Generates SHAP-based model explanations with global visualizations.

    The model and its feature metadata are downloaded lazily from the
    HuggingFace Hub on first use. CPU-friendly: large inputs are sampled
    before SHAP values are computed.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """
        Args:
            cfg: Application configuration; must expose ``hf_model_repo``.
            tracer: Tracing sink for observability events (may be falsy).
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None          # populated on first use by _ensure_model()
        self._feature_order = None  # ordered feature names from metadata, if any

        logger.info("ExplainTool initialized (lazy loading)")

    def _ensure_model(self):
        """Lazy-load the model and feature metadata from HuggingFace Hub.

        Idempotent: returns immediately once the model is cached. The model
        file is mandatory; feature metadata is best-effort — a missing or
        malformed ``feature_metadata.json`` only logs a warning and leaves
        ``self._feature_order`` as ``None``.

        Raises:
            ExplainToolError: If the repo is not configured or the model
                cannot be downloaded/deserialized.
        """
        if self._model is not None:
            return

        try:
            token = os.getenv("HF_TOKEN")
            repo = self.cfg.hf_model_repo

            if not repo:
                raise ExplainToolError("HF_MODEL_REPO not configured")

            logger.info("Loading model for explanations from: %s", repo)

            # --- mandatory model artifact ---
            try:
                model_path = hf_hub_download(
                    repo_id=repo,
                    filename="model.pkl",
                    token=token,
                )
                # NOTE(review): joblib.load deserializes pickled data — the
                # configured repo must be trusted.
                self._model = joblib.load(model_path)
                logger.info("Model loaded: %s", type(self._model).__name__)
            except Exception as e:
                raise ExplainToolError(f"Failed to load model: {e}") from e

            # --- optional feature metadata ---
            try:
                meta_path = hf_hub_download(
                    repo_id=repo,
                    filename="feature_metadata.json",
                    token=token,
                )
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f) or {}

                self._feature_order = meta.get("feature_order")
                logger.info(
                    "Loaded feature order: %d features",
                    len(self._feature_order or []),
                )

            except Exception as e:
                # Best-effort: fall back to using all dataframe columns.
                logger.warning("Could not load feature metadata: %s", e)
                self._feature_order = None

        except ExplainToolError:
            raise
        except Exception as e:
            raise ExplainToolError(f"Model initialization failed: {e}") from e

    def _validate_data(self, df: pd.DataFrame) -> tuple[bool, str]:
        """
        Validate the input dataframe.

        Returns:
            Tuple ``(is_valid, error_message)``; the message is ``""``
            when the dataframe is usable.
        """
        if df is None or df.empty:
            return False, "Input dataframe is empty"

        # Defensive: df.empty already covers a zero-column frame, but keep
        # the distinct message in case of unusual DataFrame subclasses.
        if len(df.columns) == 0:
            return False, "Dataframe has no columns"

        return True, ""

    def _prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Build the feature matrix for SHAP analysis.

        Selects and orders columns according to the model's expected
        feature order when metadata is available, then drops any
        non-numeric columns (SHAP needs numeric input here).

        Raises:
            ExplainToolError: If no required features are present, or no
                numeric features remain after filtering.
        """
        if self._feature_order:
            available_features = [c for c in self._feature_order if c in df.columns]
            missing_features = [c for c in self._feature_order if c not in df.columns]

            if missing_features:
                logger.warning(
                    "Missing %d features for explanation: %s",
                    len(missing_features),
                    missing_features[:5],
                )

            if not available_features:
                raise ExplainToolError(
                    f"No required features found in dataframe. "
                    f"Required: {self._feature_order}, "
                    f"Available: {list(df.columns)}"
                )

            X = df[available_features].copy()
            logger.info("Using %d features for explanation", len(available_features))
        else:
            # No metadata — fall back to every column of the input.
            X = df.copy()
            logger.warning("No feature order specified - using all columns")

        # Keep only numeric columns; warn about what gets dropped.
        numeric_cols = X.select_dtypes(include=['number']).columns
        if len(numeric_cols) < len(X.columns):
            dropped = set(X.columns) - set(numeric_cols)
            logger.warning(
                "Dropping %d non-numeric columns: %s",
                len(dropped),
                list(dropped)[:5],
            )
            X = X[numeric_cols]

        if X.empty or len(X.columns) == 0:
            raise ExplainToolError("No numeric features available for explanation")

        return X

    def _sample_data(self, X: pd.DataFrame, sample_size: int = DEFAULT_SAMPLE_SIZE) -> pd.DataFrame:
        """
        Sample rows for SHAP analysis to keep computation manageable.

        The effective target is clamped to [MIN_SAMPLE_SIZE, MAX_SAMPLE_SIZE].
        Small inputs are returned unchanged; sampling uses a fixed seed so
        results are reproducible.

        Args:
            X: Prepared numeric feature matrix.
            sample_size: Desired number of rows (clamped).

        Returns:
            A (possibly sampled) dataframe with at most the target row count.
        """
        n = len(X)

        if n <= MIN_SAMPLE_SIZE:
            logger.info("Using all %d rows (below minimum sample size)", n)
            return X

        # Clamp the requested size into the supported range.
        target_size = min(sample_size, MAX_SAMPLE_SIZE)
        target_size = max(target_size, MIN_SAMPLE_SIZE)

        if n <= target_size:
            logger.info("Using all %d rows (below target sample size)", n)
            return X

        try:
            # Fixed seed keeps explanations reproducible across runs.
            sample = X.sample(n=target_size, random_state=42)
            logger.info("Sampled %d rows from %d total", target_size, n)
            return sample
        except Exception as e:
            # Best-effort fallback: deterministic head() slice.
            logger.warning("Sampling failed: %s, using head()", e)
            return X.head(target_size)

    @staticmethod
    def _to_data_uri(fig) -> str:
        """
        Render a matplotlib figure to a base64 PNG data URI.

        The figure is always closed — even when rendering fails — to avoid
        leaking matplotlib state. Oversized images only trigger a warning.

        Raises:
            ExplainToolError: If rendering or encoding fails.
        """
        try:
            buf = io.BytesIO()
            try:
                fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
            finally:
                # Close unconditionally: the original leaked the figure
                # when savefig raised.
                plt.close(fig)

            # Read the buffer once (previously it was materialized twice:
            # getvalue() for sizing, then read() for encoding).
            payload = buf.getvalue()

            size_mb = len(payload) / (1024 * 1024)
            if size_mb > MAX_IMAGE_SIZE_MB:
                logger.warning("Generated image is large: %.2f MB", size_mb)

            data_uri = "data:image/png;base64," + base64.b64encode(payload).decode()
            logger.debug("Generated data URI of size: %d chars", len(data_uri))

            return data_uri

        except Exception as e:
            logger.error("Failed to convert figure to data URI: %s", e)
            raise ExplainToolError(f"Image conversion failed: {e}") from e

    def _generate_shap_values(self, X: pd.DataFrame) -> shap.Explanation:
        """
        Compute SHAP values for the (already sampled) feature matrix.

        Relies on ``shap.Explainer``'s automatic algorithm selection for
        the loaded model.

        Raises:
            ExplainToolError: If explainer construction or evaluation fails.
        """
        try:
            logger.info("Creating SHAP explainer...")
            explainer = shap.Explainer(self._model, X)

            logger.info("Computing SHAP values...")
            shap_values = explainer(X)

            logger.info("SHAP values computed: shape=%s", shap_values.values.shape)
            return shap_values

        except Exception as e:
            raise ExplainToolError(f"SHAP computation failed: {e}") from e

    def _create_bar_plot(self, shap_values: shap.Explanation) -> str:
        """Create the global feature-importance bar plot.

        Best-effort: returns "" on failure so other plots can still be
        produced by the caller.
        """
        try:
            logger.info("Creating bar plot...")
            fig = plt.figure(figsize=(10, 6))
            # shap.plots.bar draws onto the current figure when show=False.
            shap.plots.bar(shap_values, show=False, max_display=20)
            plt.title("Feature Importance (Global)", fontsize=14, pad=20)
            plt.xlabel("Mean |SHAP value|", fontsize=12)
            plt.tight_layout()

            uri = self._to_data_uri(fig)
            logger.info("Bar plot created successfully")
            return uri

        except Exception as e:
            logger.error("Bar plot creation failed: %s", e)
            return ""

    def _create_beeswarm_plot(self, shap_values: shap.Explanation) -> str:
        """Create the beeswarm plot showing per-feature effect distributions.

        Best-effort: returns "" on failure so other plots can still be
        produced by the caller.
        """
        try:
            logger.info("Creating beeswarm plot...")
            fig = plt.figure(figsize=(10, 8))
            shap.plots.beeswarm(shap_values, show=False, max_display=20)
            plt.title("Feature Effects Distribution", fontsize=14, pad=20)
            plt.tight_layout()

            uri = self._to_data_uri(fig)
            logger.info("Beeswarm plot created successfully")
            return uri

        except Exception as e:
            logger.error("Beeswarm plot creation failed: %s", e)
            return ""

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """
        Generate SHAP explanations for input data.

        Pipeline: validate input -> lazy-load model -> select/clean
        features -> sample rows -> compute SHAP values -> render plots.

        Args:
            df: Input dataframe with features.

        Returns:
            Dictionary mapping plot names ("global_bar", "beeswarm") to
            base64 PNG data URIs. Empty dict for invalid/empty input.

        Raises:
            ExplainToolError: If explanation generation fails.
        """
        try:
            # Invalid input is not an error condition: return no plots.
            is_valid, error_msg = self._validate_data(df)
            if not is_valid:
                logger.warning("Invalid input: %s", error_msg)
                return {}

            self._ensure_model()

            X = self._prepare_features(df)
            logger.info("Prepared features: %s", X.shape)

            sample = self._sample_data(X)

            shap_values = self._generate_shap_values(sample)

            # Each plot is best-effort; only successful ones are included.
            result = {}

            bar_uri = self._create_bar_plot(shap_values)
            if bar_uri:
                result["global_bar"] = bar_uri

            bee_uri = self._create_beeswarm_plot(shap_values)
            if bee_uri:
                result["beeswarm"] = bee_uri

            logger.info("Generated %d explanation visualizations", len(result))

            if self.tracer:
                self.tracer.trace_event("explain", {
                    "rows": len(sample),
                    "features": len(X.columns),
                    "visualizations": len(result)
                })

            return result

        except ExplainToolError:
            raise
        except Exception as e:
            error_msg = f"Explanation generation failed: {str(e)}"
            # logger.exception preserves the traceback (the original used
            # logger.error and lost it).
            logger.exception(error_msg)
            if self.tracer:
                self.tracer.trace_event("explain_error", {"error": error_msg})
            raise ExplainToolError(error_msg) from e