# ALM_LLM/tools/explain_tool.py
# Author: AshenH — "Update tools/explain_tool.py" (commit e4818d5, verified)
# space/tools/explain_tool.py
import os
import io
import json
import base64
import logging
from typing import Dict, Optional
import shap
import pandas as pd
import matplotlib
matplotlib.use('Agg') # Non-interactive backend
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download
from utils.config import AppConfig
from utils.tracing import Tracer
logger = logging.getLogger(__name__)
# Constants controlling SHAP sampling and image output.
MAX_SAMPLE_SIZE = 1000      # hard upper bound on rows fed to SHAP, regardless of caller request
MIN_SAMPLE_SIZE = 10        # at or below this row count, no sampling is performed
DEFAULT_SAMPLE_SIZE = 500   # default number of rows sampled for SHAP computation
MAX_IMAGE_SIZE_MB = 5       # rendered PNGs larger than this trigger a warning (not a failure)
class ExplainToolError(Exception):
    """Raised when SHAP explanation generation fails (model load, SHAP
    computation, or image conversion errors)."""
class ExplainTool:
    """
    Generates SHAP-based model explanations with global visualizations.

    The model and its feature metadata are downloaded lazily from the
    HuggingFace Hub on first use. Inputs are sampled (see the module-level
    *_SAMPLE_SIZE constants) so SHAP stays CPU-friendly on large datasets.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        """
        Args:
            cfg: Application config; must expose ``hf_model_repo``.
            tracer: Optional tracer used to emit telemetry events.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None          # populated lazily by _ensure_model()
        self._feature_order = None  # list[str] from feature_metadata.json, or None
        logger.info("ExplainTool initialized (lazy loading)")

    def _ensure_model(self):
        """Lazy load model and metadata from HuggingFace.

        Raises:
            ExplainToolError: If the repo is not configured or the model
                cannot be downloaded/unpickled. Missing feature metadata is
                tolerated (logged as a warning; all columns are used then).
        """
        if self._model is not None:
            return
        try:
            token = os.getenv("HF_TOKEN")
            repo = self.cfg.hf_model_repo
            if not repo:
                raise ExplainToolError("HF_MODEL_REPO not configured")
            logger.info(f"Loading model for explanations from: {repo}")
            # Download and load the pickled model; any failure here is fatal.
            try:
                model_path = hf_hub_download(
                    repo_id=repo,
                    filename="model.pkl",
                    token=token
                )
                self._model = joblib.load(model_path)
                logger.info(f"Model loaded: {type(self._model).__name__}")
            except Exception as e:
                raise ExplainToolError(f"Failed to load model: {e}") from e
            # Feature metadata is optional: on failure fall back to "use all columns".
            try:
                meta_path = hf_hub_download(
                    repo_id=repo,
                    filename="feature_metadata.json",
                    token=token
                )
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta = json.load(f) or {}
                self._feature_order = meta.get("feature_order")
                logger.info(f"Loaded feature order: {len(self._feature_order or [])} features")
            except Exception as e:
                logger.warning(f"Could not load feature metadata: {e}")
                self._feature_order = None
        except ExplainToolError:
            raise
        except Exception as e:
            # Wrap anything unexpected so callers only need to catch ExplainToolError.
            raise ExplainToolError(f"Model initialization failed: {e}") from e

    def _validate_data(self, df: pd.DataFrame) -> tuple[bool, str]:
        """
        Validate input dataframe.

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if df is None or df.empty:
            return False, "Input dataframe is empty"
        if len(df.columns) == 0:
            return False, "Dataframe has no columns"
        return True, ""

    def _prepare_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Prepare the numeric feature matrix for SHAP analysis.

        Selects and orders columns according to the model's feature metadata
        when available; otherwise uses all columns. Non-numeric columns are
        always dropped.

        Raises:
            ExplainToolError: If none of the required features are present,
                or no numeric columns remain.
        """
        if self._feature_order:
            # Keep only the model's expected features, in metadata order.
            available_features = [col for col in self._feature_order if col in df.columns]
            missing_features = [col for col in self._feature_order if col not in df.columns]
            if missing_features:
                logger.warning(
                    f"Missing {len(missing_features)} features for explanation: "
                    f"{missing_features[:5]}"
                )
            if not available_features:
                raise ExplainToolError(
                    f"No required features found in dataframe. "
                    f"Required: {self._feature_order}, "
                    f"Available: {list(df.columns)}"
                )
            X = df[available_features].copy()
            logger.info(f"Using {len(available_features)} features for explanation")
        else:
            # No metadata — take everything and filter to numeric below.
            X = df.copy()
            logger.warning("No feature order specified - using all columns")
        # SHAP needs a numeric matrix; drop anything else.
        numeric_cols = X.select_dtypes(include=['number']).columns
        if len(numeric_cols) < len(X.columns):
            dropped = set(X.columns) - set(numeric_cols)
            logger.warning(f"Dropping {len(dropped)} non-numeric columns: {list(dropped)[:5]}")
            X = X[numeric_cols]
        if X.empty or len(X.columns) == 0:
            raise ExplainToolError("No numeric features available for explanation")
        return X

    def _sample_data(self, X: pd.DataFrame, sample_size: int = DEFAULT_SAMPLE_SIZE) -> pd.DataFrame:
        """
        Sample rows for SHAP analysis to keep computation manageable.

        The effective sample size is clamped to [MIN_SAMPLE_SIZE,
        MAX_SAMPLE_SIZE]; datasets at or below the minimum are returned whole.
        Sampling is seeded (random_state=42) for reproducible explanations.
        """
        n = len(X)
        if n <= MIN_SAMPLE_SIZE:
            logger.info(f"Using all {n} rows (below minimum sample size)")
            return X
        # Clamp the requested size into the allowed band.
        target_size = min(sample_size, MAX_SAMPLE_SIZE)
        target_size = max(target_size, MIN_SAMPLE_SIZE)
        if n <= target_size:
            logger.info(f"Using all {n} rows (below target sample size)")
            return X
        try:
            sample = X.sample(n=target_size, random_state=42)
            logger.info(f"Sampled {target_size} rows from {n} total")
            return sample
        except Exception as e:
            # Best-effort fallback: a head() slice still gives SHAP something to chew on.
            logger.warning(f"Sampling failed: {e}, using head()")
            return X.head(target_size)

    @staticmethod
    def _to_data_uri(fig) -> str:
        """
        Convert a matplotlib figure to a base64 PNG data URI.

        The figure is ALWAYS closed (even when savefig/encoding fails) so
        Agg figures do not accumulate in a long-running process.

        Raises:
            ExplainToolError: If rendering or encoding fails.
        """
        try:
            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
            png_bytes = buf.getvalue()
            # Large images are allowed but flagged — they bloat API responses.
            size_mb = len(png_bytes) / (1024 * 1024)
            if size_mb > MAX_IMAGE_SIZE_MB:
                logger.warning(f"Generated image is large: {size_mb:.2f} MB")
            data_uri = "data:image/png;base64," + base64.b64encode(png_bytes).decode()
            logger.debug(f"Generated data URI of size: {len(data_uri)} chars")
            return data_uri
        except Exception as e:
            logger.error(f"Failed to convert figure to data URI: {e}")
            raise ExplainToolError(f"Image conversion failed: {e}") from e
        finally:
            # BUG FIX: the original only closed the figure on the success path,
            # leaking it whenever savefig raised.
            plt.close(fig)

    def _generate_shap_values(self, X: pd.DataFrame) -> shap.Explanation:
        """
        Compute SHAP values for the (already sampled) feature matrix.

        Raises:
            ExplainToolError: If explainer construction or evaluation fails.
        """
        try:
            logger.info("Creating SHAP explainer...")
            # shap.Explainer auto-selects an algorithm appropriate for the model.
            explainer = shap.Explainer(self._model, X)
            logger.info("Computing SHAP values...")
            shap_values = explainer(X)
            logger.info(f"SHAP values computed: shape={shap_values.values.shape}")
            return shap_values
        except Exception as e:
            raise ExplainToolError(f"SHAP computation failed: {e}") from e

    def _create_bar_plot(self, shap_values: shap.Explanation) -> str:
        """Create the global feature-importance bar plot (mean |SHAP|).

        Returns the plot as a data URI, or "" on failure — plot errors are
        non-fatal so one broken visualization does not discard the other.
        """
        try:
            logger.info("Creating bar plot...")
            fig = plt.figure(figsize=(10, 6))
            try:
                shap.plots.bar(shap_values, show=False, max_display=20)
                plt.title("Feature Importance (Global)", fontsize=14, pad=20)
                plt.xlabel("Mean |SHAP value|", fontsize=12)
                plt.tight_layout()
            except Exception:
                # BUG FIX: close the figure before bailing out; the original
                # leaked it whenever the shap plotting call raised.
                plt.close(fig)
                raise
            uri = self._to_data_uri(fig)
            logger.info("Bar plot created successfully")
            return uri
        except Exception as e:
            logger.error(f"Bar plot creation failed: {e}")
            # Return empty data URI rather than failing completely
            return ""

    def _create_beeswarm_plot(self, shap_values: shap.Explanation) -> str:
        """Create the beeswarm plot showing per-feature effect distributions.

        Returns the plot as a data URI, or "" on failure (non-fatal).
        """
        try:
            logger.info("Creating beeswarm plot...")
            fig = plt.figure(figsize=(10, 8))
            try:
                shap.plots.beeswarm(shap_values, show=False, max_display=20)
                plt.title("Feature Effects Distribution", fontsize=14, pad=20)
                plt.tight_layout()
            except Exception:
                # BUG FIX: close the figure before bailing out (same leak as
                # the bar plot path).
                plt.close(fig)
                raise
            uri = self._to_data_uri(fig)
            logger.info("Beeswarm plot created successfully")
            return uri
        except Exception as e:
            logger.error(f"Beeswarm plot creation failed: {e}")
            return ""

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """
        Generate SHAP explanations for input data.

        Args:
            df: Input dataframe with features.

        Returns:
            Dictionary mapping plot names ("global_bar", "beeswarm") to
            base64 PNG data URIs. Empty dict for invalid/empty input.

        Raises:
            ExplainToolError: If explanation generation fails.
        """
        try:
            # Invalid input is a soft failure: log and return nothing.
            is_valid, error_msg = self._validate_data(df)
            if not is_valid:
                logger.warning(f"Invalid input: {error_msg}")
                return {}
            self._ensure_model()
            X = self._prepare_features(df)
            logger.info(f"Prepared features: {X.shape}")
            # Sample rows so SHAP stays tractable on CPU.
            sample = self._sample_data(X)
            shap_values = self._generate_shap_values(sample)
            # Each plot is optional; only successfully rendered ones are returned.
            result = {}
            bar_uri = self._create_bar_plot(shap_values)
            if bar_uri:
                result["global_bar"] = bar_uri
            bee_uri = self._create_beeswarm_plot(shap_values)
            if bee_uri:
                result["beeswarm"] = bee_uri
            logger.info(f"Generated {len(result)} explanation visualizations")
            if self.tracer:
                self.tracer.trace_event("explain", {
                    "rows": len(sample),
                    "features": len(X.columns),
                    "visualizations": len(result)
                })
            return result
        except ExplainToolError:
            raise
        except Exception as e:
            error_msg = f"Explanation generation failed: {str(e)}"
            logger.error(error_msg)
            if self.tracer:
                self.tracer.trace_event("explain_error", {"error": error_msg})
            raise ExplainToolError(error_msg) from e