import torch
import numpy as np
import logging
import plotly.graph_objects as go
from typing import Tuple, Dict

# Advanced analysis imports
import shap
import lime
from lime.lime_text import LimeTextExplainer

from config import config
from models import ModelManager, handle_errors

logger = logging.getLogger(__name__)


class AdvancedAnalysisEngine:
    """Advanced analysis using SHAP and LIME with FIXED implementation"""

    def __init__(self):
        self.model_manager = ModelManager()

    def create_prediction_function(self, model, tokenizer, device):
        """Create FIXED prediction function for SHAP/LIME"""

        def predict_proba(texts):
            # Ensure texts is a list
            if isinstance(texts, str):
                texts = [texts]
            elif isinstance(texts, np.ndarray):
                texts = texts.tolist()

            # Convert all elements to strings
            texts = [str(text) for text in texts]

            results = []
            batch_size = 16  # Process in smaller batches

            for i in range(0, len(texts), batch_size):
                batch_texts = texts[i:i + batch_size]

                try:
                    with torch.no_grad():
                        # Tokenize batch
                        inputs = tokenizer(
                            batch_texts,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=config.MAX_TEXT_LENGTH
                        ).to(device)

                        # Batch inference
                        outputs = model(**inputs)
                        probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
                        results.extend(probs)

                except Exception as e:
                    logger.error(f"Prediction batch failed: {e}")
                    # Return neutral predictions for the failed batch
                    batch_size_actual = len(batch_texts)
                    if hasattr(model.config, 'num_labels') and model.config.num_labels == 3:
                        neutral_probs = np.array([[0.33, 0.34, 0.33]] * batch_size_actual)
                    else:
                        neutral_probs = np.array([[0.5, 0.5]] * batch_size_actual)
                    results.extend(neutral_probs)

            return np.array(results)

        return predict_proba

    @handle_errors(default_return=("Analysis failed", None, None))
    def analyze_with_shap(self, text: str, language: str = 'auto',
                          num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
        """FIXED SHAP analysis implementation"""
        if not text.strip():
            return "Please enter text for analysis", None, {}

        # Detect language and get model
        if language == 'auto':
            detected_lang = self.model_manager.detect_language(text)
        else:
            detected_lang = language

        model, tokenizer = self.model_manager.get_model(detected_lang)

        try:
            # Create FIXED prediction function
            predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)

            # Test the prediction function first
            test_pred = predict_fn([text])
            if test_pred is None or len(test_pred) == 0:
                return "Prediction function test failed", None, {}

            # Use a Text masker so shap.Explainer picks a text-aware explainer
            # rather than the generic one
            explainer = shap.Explainer(predict_fn, masker=shap.maskers.Text(tokenizer))

            # Get SHAP values with proper text input
            shap_values = explainer([text], max_evals=num_samples)

            # Extract data safely
            if hasattr(shap_values, 'data') and hasattr(shap_values, 'values'):
                tokens = shap_values.data[0] if len(shap_values.data) > 0 else []
                values = shap_values.values[0] if len(shap_values.values) > 0 else []
            else:
                return "SHAP values extraction failed", None, {}

            if len(tokens) == 0 or len(values) == 0:
                return "No tokens or values extracted from SHAP", None, {}

            # Handle multi-dimensional values
            if len(values.shape) > 1:
                # Use positive-class values (last column for both 2- and 3-class models)
                pos_values = values[:, -1] if values.shape[1] >= 2 else values[:, 0]
            else:
                pos_values = values

            # Ensure we have matching lengths
            min_len = min(len(tokens), len(pos_values))
            tokens = tokens[:min_len]
            pos_values = pos_values[:min_len]

            # Create visualization
            fig = go.Figure()
            colors = ['red' if v < 0 else 'green' for v in pos_values]

            fig.add_trace(go.Bar(
                x=list(range(len(tokens))),
                y=pos_values,
                text=tokens,
                textposition='outside',
                marker_color=colors,
                name='SHAP Values',
                hovertemplate='%{text}<br>SHAP Value: %{y:.4f}'
            ))

            fig.update_layout(
                title=f"SHAP Analysis - Token Importance (Samples: {num_samples})",
                xaxis_title="Token Index",
                yaxis_title="SHAP Value",
                height=500,
                xaxis=dict(tickmode='array', tickvals=list(range(len(tokens))), ticktext=tokens)
            )

            # Create analysis summary
            analysis_data = {
                'method': 'SHAP',
                'language': detected_lang,
                'total_tokens': len(tokens),
                'samples_used': num_samples,
                'positive_influence': sum(1 for v in pos_values if v > 0),
                'negative_influence': sum(1 for v in pos_values if v < 0),
                'most_important_tokens': [(str(tokens[i]), float(pos_values[i]))
                                          for i in np.argsort(np.abs(pos_values))[-5:]]
            }

            summary_text = f"""
**SHAP Analysis Results:**
- **Language:** {detected_lang.upper()}
- **Total Tokens:** {analysis_data['total_tokens']}
- **Samples Used:** {num_samples}
- **Positive Influence Tokens:** {analysis_data['positive_influence']}
- **Negative Influence Tokens:** {analysis_data['negative_influence']}
- **Most Important Tokens:** {', '.join([f"{token}({score:.3f})" for token, score in analysis_data['most_important_tokens']])}
- **Status:** SHAP analysis completed successfully
"""

            return summary_text, fig, analysis_data

        except Exception as e:
            logger.error(f"SHAP analysis failed: {e}")
            error_msg = f"""
**SHAP Analysis Failed:**
- **Error:** {str(e)}
- **Language:** {detected_lang.upper()}
- **Suggestion:** Try a shorter text or reduce the number of samples

**Common fixes:**
- Reduce the sample size to 50-100
- Use shorter input text (< 200 words)
- Check whether the model supports the text language
"""
            return error_msg, None, {}

    @handle_errors(default_return=("Analysis failed", None, None))
    def analyze_with_lime(self, text: str, language: str = 'auto',
                          num_samples: int = 100) -> Tuple[str, go.Figure, Dict]:
        """FIXED LIME analysis implementation - bug fix for the 'mode' parameter"""
        if not text.strip():
            return "Please enter text for analysis", None, {}

        # Detect language and get model
        if language == 'auto':
            detected_lang = self.model_manager.detect_language(text)
        else:
            detected_lang = language

        model, tokenizer = self.model_manager.get_model(detected_lang)

        try:
            # Create FIXED prediction function
            predict_fn = self.create_prediction_function(model, tokenizer, self.model_manager.device)

            # Test the prediction function first
            test_pred = predict_fn([text])
            if test_pred is None or len(test_pred) == 0:
                return "Prediction function test failed", None, {}

            # Determine class names based on model output
            num_classes = test_pred.shape[1] if len(test_pred.shape) > 1 else 2
            if num_classes == 3:
                class_names = ['Negative', 'Neutral', 'Positive']
            else:
                class_names = ['Negative', 'Positive']

            # Initialize LIME explainer - FIXED: removed the unsupported 'mode' parameter
            explainer = LimeTextExplainer(class_names=class_names)

            # Get LIME explanation
            exp = explainer.explain_instance(
                text,
                predict_fn,
                num_features=min(20, len(text.split())),  # Limit features
                num_samples=num_samples
            )

            # Extract feature importance
            lime_data = exp.as_list()

            if not lime_data:
                return "No LIME features extracted", None, {}

            # Create visualization
            words = [item[0] for item in lime_data]
            scores = [item[1] for item in lime_data]

            fig = go.Figure()
            colors = ['red' if s < 0 else 'green' for s in scores]

            fig.add_trace(go.Bar(
                y=words,
                x=scores,
                orientation='h',
                marker_color=colors,
                text=[f'{s:.3f}' for s in scores],
                textposition='auto',
                name='LIME Importance',
                hovertemplate='%{y}<br>Importance: %{x:.4f}'
            ))

            fig.update_layout(
                title=f"LIME Analysis - Feature Importance (Samples: {num_samples})",
                xaxis_title="Importance Score",
                yaxis_title="Words/Phrases",
                height=500
            )

            # Create analysis summary
            analysis_data = {
                'method': 'LIME',
                'language': detected_lang,
                'features_analyzed': len(lime_data),
                'samples_used': num_samples,
                'positive_features': sum(1 for _, score in lime_data if score > 0),
                'negative_features': sum(1 for _, score in lime_data if score < 0),
                'feature_importance': lime_data
            }

            summary_text = f"""
**LIME Analysis Results:**
- **Language:** {detected_lang.upper()}
- **Features Analyzed:** {analysis_data['features_analyzed']}
- **Classes:** {', '.join(class_names)}
- **Samples Used:** {num_samples}
- **Positive Features:** {analysis_data['positive_features']}
- **Negative Features:** {analysis_data['negative_features']}
- **Top Features:** {', '.join([f"{word}({score:.3f})" for word, score in lime_data[:5]])}
- **Status:** LIME analysis completed successfully
"""

            return summary_text, fig, analysis_data

        except Exception as e:
            logger.error(f"LIME analysis failed: {e}")
            error_msg = f"""
**LIME Analysis Failed:**
- **Error:** {str(e)}
- **Language:** {detected_lang.upper()}
- **Suggestion:** Try a shorter text or reduce the number of samples

**Bug Fix Applied:**
- ✅ Removed 'mode' parameter from LimeTextExplainer initialization
- ✅ This should resolve the "unexpected keyword argument 'mode'" error

**Common fixes:**
- Reduce the sample size to 50-100
- Use shorter input text (< 200 words)
- Check whether the model supports the text language
"""
            return error_msg, None, {}
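

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Shows how the two analysis entry points are expected to be called. It assumes
# `models.ModelManager` and `config` behave as referenced above (device handling,
# detect_language/get_model, MAX_TEXT_LENGTH); the sample text is hypothetical.
if __name__ == "__main__":
    engine = AdvancedAnalysisEngine()

    sample_text = "The service was quick and the staff were friendly."

    # SHAP token-level attribution: returns (markdown summary, plotly Figure or None, dict)
    shap_summary, shap_fig, shap_data = engine.analyze_with_shap(
        sample_text, language='auto', num_samples=100
    )
    print(shap_summary)

    # LIME word/phrase-level attribution: same return shape as the SHAP path
    lime_summary, lime_fig, lime_data = engine.analyze_with_lime(
        sample_text, language='auto', num_samples=100
    )
    print(lime_summary)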