FEAT(transparent_ai): refine hypothesis generation by removing unused spectral data parameter (68f2a01)

		| """ | |
| Transparent AI Reasoning Engine for POLYMEROS | |
| Provides explainable predictions with uncertainty quantification and hypothesis generation | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from typing import Dict, List, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| import warnings | |
| try: | |
| import shap | |
| SHAP_AVAILABLE = True | |
| except ImportError: | |
| SHAP_AVAILABLE = False | |
| warnings.warn("SHAP not available. Install with: pip install shap") | |


@dataclass
class PredictionExplanation:
    """Comprehensive explanation for a model prediction"""

    prediction: int
    confidence: float
    confidence_level: str
    probabilities: np.ndarray
    feature_importance: Dict[str, float]
    reasoning_chain: List[str]
    uncertainty_sources: List[str]
    similar_cases: List[Dict[str, Any]]
    confidence_intervals: Dict[str, Tuple[float, float]]


@dataclass
class Hypothesis:
    """AI-generated scientific hypothesis"""

    statement: str
    confidence: float
    supporting_evidence: List[str]
    testable_predictions: List[str]
    suggested_experiments: List[str]
    related_literature: List[str]


class UncertaintyEstimator:
    """Bayesian uncertainty estimation for model predictions"""

    def __init__(self, model, n_samples: int = 100):
        self.model = model
        self.n_samples = n_samples
        self.epistemic_uncertainty = None
        self.aleatoric_uncertainty = None

    def estimate_uncertainty(self, x: torch.Tensor) -> Dict[str, float]:
        """Estimate prediction uncertainty using Monte Carlo dropout"""
        self.model.train()  # Enable dropout
        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                pred = F.softmax(self.model(x), dim=1)
                predictions.append(pred.cpu().numpy())
        predictions = np.array(predictions)

        # Calculate uncertainties
        mean_pred = np.mean(predictions, axis=0)
        epistemic = np.var(predictions, axis=0)  # Model uncertainty
        aleatoric = np.mean(predictions * (1 - predictions), axis=0)  # Data uncertainty
        total_uncertainty = epistemic + aleatoric
        return {
            "epistemic": float(np.mean(epistemic)),
            "aleatoric": float(np.mean(aleatoric)),
            "total": float(np.mean(total_uncertainty)),
            "prediction_variance": float(np.var(mean_pred)),
        }

    def confidence_intervals(
        self, x: torch.Tensor, confidence_level: float = 0.95
    ) -> Dict[str, Tuple[float, float]]:
        """Calculate confidence intervals for predictions"""
        self.model.train()
        predictions = []
        with torch.no_grad():
            for _ in range(self.n_samples):
                pred = F.softmax(self.model(x), dim=1)
                predictions.append(pred.cpu().numpy().flatten())
        predictions = np.array(predictions)

        alpha = 1 - confidence_level
        lower_percentile = (alpha / 2) * 100
        upper_percentile = (1 - alpha / 2) * 100
        intervals = {}
        for i in range(predictions.shape[1]):
            lower = np.percentile(predictions[:, i], lower_percentile)
            upper = np.percentile(predictions[:, i], upper_percentile)
            intervals[f"class_{i}"] = (lower, upper)
        return intervals
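

# Usage sketch (illustrative addition, not part of the original module): a
# minimal check of UncertaintyEstimator on a toy dropout classifier. The
# 500-feature input and two-class output are assumptions chosen to match the
# rest of this file; substitute the real spectral model in practice.
def _demo_uncertainty() -> None:
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(500, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 2)
    )
    estimator = UncertaintyEstimator(model, n_samples=50)
    x = torch.randn(1, 500)
    print(estimator.estimate_uncertainty(x))  # epistemic / aleatoric / total
    print(estimator.confidence_intervals(x))  # per-class 95% intervals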


class FeatureImportanceAnalyzer:
    """Advanced feature importance analysis for spectral data"""

    def __init__(self, model):
        self.model = model
        self.shap_explainer = None
        if SHAP_AVAILABLE:
            try:
                # Initialize SHAP explainer with an all-zeros baseline
                # (assumes a 500-point input spectrum)
                self.shap_explainer = shap.DeepExplainer(  # type: ignore
                    model, torch.zeros(1, 500)
                )
            except (ValueError, RuntimeError) as e:
                warnings.warn(f"Could not initialize SHAP explainer: {e}")

    def analyze_feature_importance(
        self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
    ) -> Dict[str, Any]:
        """Comprehensive feature importance analysis"""
        importance_data = {}

        # SHAP analysis (if available)
        if self.shap_explainer is not None:
            try:
                shap_values = self.shap_explainer.shap_values(x)
                importance_data["shap_values"] = shap_values
                importance_data["shap_available"] = True
            except (ValueError, RuntimeError) as e:
                warnings.warn(f"SHAP analysis failed: {e}")
                importance_data["shap_available"] = False
        else:
            importance_data["shap_available"] = False

        # Gradient-based importance
        x.requires_grad_(True)
        self.model.eval()
        output = self.model(x)
        predicted_class = torch.argmax(output, dim=1)

        # Calculate gradients
        self.model.zero_grad()
        output[0, predicted_class].backward()
        if x.grad is not None:
            gradients = x.grad.detach().abs().cpu().numpy().flatten()
        else:
            raise RuntimeError(
                "Gradients were not computed. Ensure x.requires_grad_(True) is set correctly."
            )
        importance_data["gradient_importance"] = gradients

        # Integrated gradients approximation
        integrated_grads = self._integrated_gradients(x, predicted_class)
        importance_data["integrated_gradients"] = integrated_grads

        # Spectral region importance
        if wavenumbers is not None:
            region_importance = self._analyze_spectral_regions(gradients, wavenumbers)
            importance_data["spectral_regions"] = region_importance

        return importance_data

    def _integrated_gradients(
        self, x: torch.Tensor, target_class: torch.Tensor, steps: int = 50
    ) -> np.ndarray:
        """Calculate integrated gradients for feature importance"""
        baseline = torch.zeros_like(x)
        integrated_grads = np.zeros(x.shape[1])
        for i in range(steps):
            alpha = i / steps
            # Detach so the interpolated point is a leaf tensor; otherwise,
            # when x already requires grad, requires_grad_() fails and .grad
            # is never populated
            interpolated = (baseline + alpha * (x - baseline)).detach()
            interpolated.requires_grad_(True)
            output = self.model(interpolated)
            self.model.zero_grad()
            output[0, target_class].backward()
            if interpolated.grad is not None:
                grads = interpolated.grad.cpu().numpy().flatten()
                integrated_grads += grads
        # Scale the averaged gradients by the input-baseline difference
        integrated_grads = (
            integrated_grads * (x - baseline).detach().cpu().numpy().flatten() / steps
        )
        return integrated_grads

    def _analyze_spectral_regions(
        self, importance: np.ndarray, wavenumbers: np.ndarray
    ) -> Dict[str, float]:
        """Analyze importance by common spectral regions"""
        # Region bounds in cm⁻¹
        regions = {
            "fingerprint": (400, 1500),
            "ch_stretch": (2800, 3100),
            "oh_stretch": (3200, 3700),
            "carbonyl": (1600, 1800),
            "aromatic": (1450, 1650),
        }
        region_importance = {}
        for region_name, (low, high) in regions.items():
            mask = (wavenumbers >= low) & (wavenumbers <= high)
            if np.any(mask):
                region_importance[region_name] = float(np.mean(importance[mask]))
            else:
                region_importance[region_name] = 0.0
        return region_importance
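

# Usage sketch (illustrative addition, not part of the original module):
# gradient-based importance on a toy network. The 500-point spectrum and the
# 400-4000 cm⁻¹ wavenumber grid are assumptions that match the DeepExplainer
# baseline shape used in FeatureImportanceAnalyzer.__init__.
def _demo_feature_importance() -> None:
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(500, 32), nn.ReLU(), nn.Linear(32, 2))
    analyzer = FeatureImportanceAnalyzer(model)
    x = torch.randn(1, 500)
    wavenumbers = np.linspace(400, 4000, 500)
    result = analyzer.analyze_feature_importance(x, wavenumbers)
    print(result["spectral_regions"])  # mean |gradient| per spectral region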


class HypothesisGenerator:
    """AI-driven scientific hypothesis generation"""

    def __init__(self):
        self.hypothesis_templates = [
            "The spectral differences in the {region} region suggest {mechanism} as a primary degradation pathway",
            "Enhanced intensity at {wavenumber} cm⁻¹ indicates {chemical_change} in weathered samples",
            "The correlation between {feature1} and {feature2} suggests {relationship}",
            "Baseline shifts in {region} region may indicate {structural_change}",
        ]

    def generate_hypotheses(
        self, explanation: PredictionExplanation
    ) -> List[Hypothesis]:
        """Generate testable hypotheses based on model predictions and explanations"""
        hypotheses = []
        # Analyze feature importance for hypothesis generation
        important_features = self._identify_key_features(explanation.feature_importance)
        for feature_info in important_features:
            hypothesis = self._generate_single_hypothesis(feature_info, explanation)
            if hypothesis:
                hypotheses.append(hypothesis)
        return hypotheses

    def _identify_key_features(
        self, feature_importance: Dict[str, float]
    ) -> List[Dict[str, Any]]:
        """Identify key features for hypothesis generation"""
        # Sort features by importance
        sorted_features = sorted(
            feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
        )
        key_features = []
        for feature_name, importance in sorted_features[:5]:  # Top 5 features
            feature_info = {
                "name": feature_name,
                "importance": importance,
                "type": self._classify_feature_type(feature_name),
                "chemical_significance": self._get_chemical_significance(feature_name),
            }
            key_features.append(feature_info)
        return key_features

    def _classify_feature_type(self, feature_name: str) -> str:
        """Classify spectral feature type"""
        if "fingerprint" in feature_name.lower():
            return "fingerprint"
        elif "stretch" in feature_name.lower():
            return "vibrational"
        elif "carbonyl" in feature_name.lower():
            return "functional_group"
        else:
            return "general"

    def _get_chemical_significance(self, feature_name: str) -> str:
        """Get chemical significance of spectral feature"""
        significance_map = {
            "fingerprint": "molecular backbone structure",
            "ch_stretch": "aliphatic chain integrity",
            "oh_stretch": "hydrogen bonding and hydration",
            "carbonyl": "oxidative degradation products",
            "aromatic": "aromatic ring preservation",
        }
        for key, significance in significance_map.items():
            if key in feature_name.lower():
                return significance
        return "structural changes"

    def _generate_single_hypothesis(
        self, feature_info: Dict[str, Any], explanation: PredictionExplanation
    ) -> Optional[Hypothesis]:
        """Generate a single hypothesis from feature information"""
        if feature_info["importance"] < 0.1:  # Skip low-importance features
            return None
        # Create hypothesis statement
        statement = (
            f"Changes in {feature_info['name']} region indicate "
            f"{feature_info['chemical_significance']} during polymer weathering"
        )
        # Generate supporting evidence
        evidence = [
            f"Feature importance score: {feature_info['importance']:.3f}",
            f"Classification confidence: {explanation.confidence:.3f}",
            f"Chemical significance: {feature_info['chemical_significance']}",
        ]
        # Generate testable predictions
        predictions = [
            f"Controlled weathering experiments should show progressive changes in {feature_info['name']} region",
            f"Different polymer types should exhibit varying {feature_info['name']} responses to weathering",
        ]
        # Suggest experiments
        experiments = [
            f"Time-series weathering study monitoring {feature_info['name']} region",
            f"Comparative analysis across polymer types focusing on {feature_info['chemical_significance']}",
            "Cross-validation with other analytical techniques (DSC, GPC, etc.)",
        ]
        return Hypothesis(
            statement=statement,
            confidence=min(0.9, feature_info["importance"] * explanation.confidence),
            supporting_evidence=evidence,
            testable_predictions=predictions,
            suggested_experiments=experiments,
            related_literature=[],  # Could be populated with literature search
        )
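

# Usage sketch (illustrative addition, not part of the original module):
# drives HypothesisGenerator from a hand-built PredictionExplanation. The
# region scores below are invented values, chosen only so that some clear
# the 0.1 importance threshold in _generate_single_hypothesis.
def _demo_hypotheses() -> None:
    explanation = PredictionExplanation(
        prediction=1,
        confidence=0.85,
        confidence_level="HIGH",
        probabilities=np.array([0.15, 0.85]),
        feature_importance={"carbonyl": 0.42, "ch_stretch": 0.18, "aromatic": 0.05},
        reasoning_chain=[],
        uncertainty_sources=[],
        similar_cases=[],
        confidence_intervals={},
    )
    for hyp in HypothesisGenerator().generate_hypotheses(explanation):
        print(f"{hyp.statement} (confidence: {hyp.confidence:.2f})")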


class TransparentAIEngine:
    """Main transparent AI engine combining all reasoning components"""

    def __init__(self, model):
        self.model = model
        self.uncertainty_estimator = UncertaintyEstimator(model)
        self.feature_analyzer = FeatureImportanceAnalyzer(model)
        self.hypothesis_generator = HypothesisGenerator()

    def predict_with_explanation(
        self, x: torch.Tensor, wavenumbers: Optional[np.ndarray] = None
    ) -> PredictionExplanation:
        """Generate comprehensive prediction with full explanation"""
        self.model.eval()
        # Get basic prediction
        with torch.no_grad():
            logits = self.model(x)
            probabilities = F.softmax(logits, dim=1).cpu().numpy().flatten()
            prediction = int(torch.argmax(logits, dim=1).item())
            confidence = float(np.max(probabilities))

        # Determine confidence level
        if confidence >= 0.80:
            confidence_level = "HIGH"
        elif confidence >= 0.60:
            confidence_level = "MEDIUM"
        else:
            confidence_level = "LOW"

        # Get uncertainty estimation
        uncertainties = self.uncertainty_estimator.estimate_uncertainty(x)
        confidence_intervals = self.uncertainty_estimator.confidence_intervals(x)

        # Analyze feature importance
        importance_data = self.feature_analyzer.analyze_feature_importance(
            x, wavenumbers
        )

        # Create feature importance dictionary
        if wavenumbers is not None and "spectral_regions" in importance_data:
            feature_importance = importance_data["spectral_regions"]
        else:
            # Use gradient importance
            gradients = importance_data.get("gradient_importance", [])
            feature_importance = {
                f"feature_{i}": float(val) for i, val in enumerate(gradients[:10])
            }

        # Generate reasoning chain
        reasoning_chain = self._generate_reasoning_chain(
            prediction, confidence, feature_importance, uncertainties
        )

        # Identify uncertainty sources
        uncertainty_sources = self._identify_uncertainty_sources(uncertainties)

        # Create explanation object
        explanation = PredictionExplanation(
            prediction=prediction,
            confidence=confidence,
            confidence_level=confidence_level,
            probabilities=probabilities,
            feature_importance=feature_importance,
            reasoning_chain=reasoning_chain,
            uncertainty_sources=uncertainty_sources,
            similar_cases=[],  # Could be populated with case-based reasoning
            confidence_intervals=confidence_intervals,
        )
        return explanation

    def generate_hypotheses(
        self, explanation: PredictionExplanation
    ) -> List[Hypothesis]:
        """Generate scientific hypotheses based on prediction explanation"""
        return self.hypothesis_generator.generate_hypotheses(explanation)

    def _generate_reasoning_chain(
        self,
        prediction: int,
        confidence: float,
        feature_importance: Dict[str, float],
        uncertainties: Dict[str, float],
    ) -> List[str]:
        """Generate human-readable reasoning chain"""
        reasoning = []
        # Start with prediction
        class_names = ["Stable", "Weathered"]
        reasoning.append(
            f"Model predicts: {class_names[prediction]} (confidence: {confidence:.3f})"
        )
        # Add feature analysis
        top_features = sorted(
            feature_importance.items(), key=lambda x: abs(x[1]), reverse=True
        )[:3]
        for feature, importance in top_features:
            reasoning.append(
                f"Key evidence: {feature} region shows importance score {importance:.3f}"
            )
        # Add uncertainty analysis
        total_uncertainty = uncertainties.get("total", 0)
        if total_uncertainty > 0.1:
            reasoning.append(
                f"High uncertainty detected ({total_uncertainty:.3f}) - suggests ambiguous case"
            )
        # Add confidence assessment
        if confidence > 0.8:
            reasoning.append(
                "High confidence: Strong spectral signature for classification"
            )
        elif confidence > 0.6:
            reasoning.append("Medium confidence: Some ambiguity in spectral features")
        else:
            reasoning.append("Low confidence: Weak or conflicting spectral evidence")
        return reasoning

    def _identify_uncertainty_sources(
        self, uncertainties: Dict[str, float]
    ) -> List[str]:
        """Identify sources of prediction uncertainty"""
        sources = []
        epistemic = uncertainties.get("epistemic", 0)
        aleatoric = uncertainties.get("aleatoric", 0)
        if epistemic > 0.05:
            sources.append(
                "Model uncertainty: Limited training data for this type of spectrum"
            )
        if aleatoric > 0.05:
            sources.append("Data uncertainty: Noisy or degraded spectral quality")
        if uncertainties.get("prediction_variance", 0) > 0.1:
            sources.append("Prediction instability: Multiple possible interpretations")
        if not sources:
            sources.append("Low uncertainty: Clear and unambiguous classification")
        return sources
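

# End-to-end sketch (illustrative addition, not part of the original module):
# wires a toy dropout classifier into TransparentAIEngine. The 500-feature
# input and the two classes ("Stable"/"Weathered") mirror assumptions made
# elsewhere in this file; substitute the real spectral model in practice.
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(500, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 2)
            )

        def forward(self, x):
            return self.net(x)

    engine = TransparentAIEngine(_ToyNet())
    x = torch.randn(1, 500)
    wavenumbers = np.linspace(400, 4000, 500)
    explanation = engine.predict_with_explanation(x, wavenumbers)
    print("\n".join(explanation.reasoning_chain))
    for hyp in engine.generate_hypotheses(explanation):
        print("-", hyp.statement)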
