"""
Evaluation and Analysis Tools for Legal-BERT
"""
import torch
import numpy as np
import json
from typing import Dict, List, Any, Optional
from collections import defaultdict
# Try to import visualization libraries
try:
import matplotlib.pyplot as plt
import seaborn as sns
VISUALIZATION_AVAILABLE = True
except ImportError:
VISUALIZATION_AVAILABLE = False
print("⚠️ Warning: matplotlib/seaborn not available. Visualizations will be skipped.")
# Import hierarchical risk analysis
try:
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer
HIERARCHICAL_AVAILABLE = True
except ImportError:
HIERARCHICAL_AVAILABLE = False
print("⚠️ Warning: hierarchical_risk module not available.")
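
# Typical end-to-end usage (illustrative sketch only; `model`, `tokenizer`,
# `risk_discovery`, and `test_loader` are assumed to be built by the other
# modules in this repo and are not defined here):
#
#   evaluator = LegalBertEvaluator(model, tokenizer, risk_discovery)
#   results = evaluator.evaluate_model(test_loader)
#   print(evaluator.generate_report())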
class LegalBertEvaluator:
"""
Comprehensive evaluation for Legal-BERT with discovered risk patterns
"""
def __init__(self, model, tokenizer, risk_discovery):
self.model = model
self.tokenizer = tokenizer
self.risk_discovery = risk_discovery
self.evaluation_results = {}
def evaluate_model(self, test_loader, save_results: bool = True) -> Dict[str, Any]:
"""Comprehensive model evaluation"""
print("🔍 Starting comprehensive evaluation...")
# Collect predictions
all_predictions = []
all_true_labels = []
all_severity_preds = []
all_severity_true = []
all_importance_preds = []
all_importance_true = []
all_confidences = []
self.model.eval()
        device = next(self.model.parameters()).device
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
# Get predictions using the correct method
outputs = self.model.forward_single_clause(input_ids, attention_mask)
# Calculate predictions and confidences from logits
risk_probs = torch.softmax(outputs['calibrated_logits'], dim=-1)
predicted_risk_ids = torch.argmax(risk_probs, dim=-1)
confidences = torch.max(risk_probs, dim=-1)[0]
# Store results
all_predictions.extend(predicted_risk_ids.cpu().numpy())
all_true_labels.extend(batch['risk_label'].numpy())
all_severity_preds.extend(outputs['severity_score'].cpu().numpy())
all_severity_true.extend(batch['severity_score'].numpy())
all_importance_preds.extend(outputs['importance_score'].cpu().numpy())
all_importance_true.extend(batch['importance_score'].numpy())
all_confidences.extend(confidences.cpu().numpy())
# Calculate metrics
results = {
'classification_metrics': self._calculate_classification_metrics(
all_true_labels, all_predictions, all_confidences
),
'regression_metrics': self._calculate_regression_metrics(
all_severity_true, all_severity_preds,
all_importance_true, all_importance_preds
),
'risk_pattern_analysis': self._analyze_risk_patterns(
all_true_labels, all_predictions
)
}
self.evaluation_results = results
if save_results:
self.save_evaluation_results(results)
print("✅ Evaluation complete!")
return results
def _calculate_classification_metrics(self, true_labels: List[int],
predictions: List[int],
confidences: List[float]) -> Dict[str, Any]:
"""Calculate classification metrics"""
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
accuracy = accuracy_score(true_labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )
        # Per-class metrics
        precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
            true_labels, predictions, average=None, zero_division=0
        )
# Confusion matrix
cm = confusion_matrix(true_labels, predictions)
# Confidence analysis
avg_confidence = np.mean(confidences)
confidence_std = np.std(confidences)
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1_score': f1,
'precision_per_class': precision_per_class.tolist(),
'recall_per_class': recall_per_class.tolist(),
'f1_per_class': f1_per_class.tolist(),
'confusion_matrix': cm.tolist(),
'avg_confidence': avg_confidence,
'confidence_std': confidence_std
}
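    # Worked example of the 'weighted' average used above (illustrative numbers):
    # with per-class F1 scores [0.8, 0.5] and class supports [8, 2], the weighted
    # F1 is (0.8 * 8 + 0.5 * 2) / 10 = 0.74, i.e. each class contributes in
    # proportion to how often it appears in the test set.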
def _calculate_regression_metrics(self, severity_true: List[float], severity_pred: List[float],
importance_true: List[float], importance_pred: List[float]) -> Dict[str, Any]:
"""Calculate regression metrics"""
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
# Severity metrics
severity_mse = mean_squared_error(severity_true, severity_pred)
severity_mae = mean_absolute_error(severity_true, severity_pred)
severity_r2 = r2_score(severity_true, severity_pred)
# Importance metrics
importance_mse = mean_squared_error(importance_true, importance_pred)
importance_mae = mean_absolute_error(importance_true, importance_pred)
importance_r2 = r2_score(importance_true, importance_pred)
return {
'severity': {
'mse': severity_mse,
'mae': severity_mae,
'r2_score': severity_r2
},
'importance': {
'mse': importance_mse,
'mae': importance_mae,
'r2_score': importance_r2
}
}
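    # Note on R²: r2_score computes 1 - SS_res / SS_tot, so 1.0 means perfect
    # prediction, 0.0 means no better than predicting the mean severity or
    # importance, and negative values mean worse than the mean baseline.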
def _analyze_risk_patterns(self, true_labels: List[int], predictions: List[int]) -> Dict[str, Any]:
"""Analyze discovered risk patterns"""
discovered_patterns = self.risk_discovery.discovered_patterns
pattern_names = list(discovered_patterns.keys())
# Pattern distribution
true_distribution = defaultdict(int)
pred_distribution = defaultdict(int)
for label in true_labels:
true_distribution[pattern_names[label]] += 1
for pred in predictions:
pred_distribution[pattern_names[pred]] += 1
# Pattern-specific performance
pattern_performance = {}
for i, pattern_name in enumerate(pattern_names):
pattern_true = [1 if label == i else 0 for label in true_labels]
pattern_pred = [1 if pred == i else 0 for pred in predictions]
if sum(pattern_true) > 0: # Avoid division by zero
precision = sum([1 for t, p in zip(pattern_true, pattern_pred) if t == 1 and p == 1]) / max(sum(pattern_pred), 1)
recall = sum([1 for t, p in zip(pattern_true, pattern_pred) if t == 1 and p == 1]) / sum(pattern_true)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
pattern_performance[pattern_name] = {
'precision': precision,
'recall': recall,
'f1_score': f1,
'support': sum(pattern_true)
}
return {
'true_distribution': dict(true_distribution),
'predicted_distribution': dict(pred_distribution),
'pattern_performance': pattern_performance,
'discovered_patterns_info': discovered_patterns
}
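    # Worked example of the one-vs-rest arithmetic above (illustrative labels):
    # true = [0, 1, 1, 2], pred = [0, 1, 2, 2]; for pattern index 1 this gives
    # one true positive, precision = 1/1 = 1.0, recall = 1/2 = 0.5, and
    # F1 = 2 * 1.0 * 0.5 / 1.5 ≈ 0.667, with support = 2.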
def generate_report(self) -> str:
"""Generate comprehensive evaluation report"""
if not self.evaluation_results:
raise ValueError("Must run evaluation first")
results = self.evaluation_results
report = []
report.append("=" * 80)
report.append("🏛️ LEGAL-BERT EVALUATION REPORT")
report.append("=" * 80)
# Classification Performance
report.append("\n📊 RISK CLASSIFICATION PERFORMANCE")
report.append("-" * 50)
clf_metrics = results['classification_metrics']
report.append(f"Accuracy: {clf_metrics['accuracy']:.4f}")
report.append(f"Precision: {clf_metrics['precision']:.4f}")
report.append(f"Recall: {clf_metrics['recall']:.4f}")
report.append(f"F1-Score: {clf_metrics['f1_score']:.4f}")
report.append(f"Average Confidence: {clf_metrics['avg_confidence']:.4f}")
# Regression Performance
report.append("\n📈 REGRESSION PERFORMANCE")
report.append("-" * 50)
reg_metrics = results['regression_metrics']
report.append("Severity Prediction:")
report.append(f" MSE: {reg_metrics['severity']['mse']:.4f}")
report.append(f" MAE: {reg_metrics['severity']['mae']:.4f}")
report.append(f" R²: {reg_metrics['severity']['r2_score']:.4f}")
report.append("Importance Prediction:")
report.append(f" MSE: {reg_metrics['importance']['mse']:.4f}")
report.append(f" MAE: {reg_metrics['importance']['mae']:.4f}")
report.append(f" R²: {reg_metrics['importance']['r2_score']:.4f}")
# Risk Pattern Analysis
report.append("\n🔍 DISCOVERED RISK PATTERNS")
report.append("-" * 50)
pattern_analysis = results['risk_pattern_analysis']
report.append("Pattern Distribution (True vs Predicted):")
for pattern, count in pattern_analysis['true_distribution'].items():
pred_count = pattern_analysis['predicted_distribution'].get(pattern, 0)
            report.append(f" {pattern}: {count} vs {pred_count}")
report.append("\nPattern-Specific Performance:")
for pattern, metrics in pattern_analysis['pattern_performance'].items():
report.append(f" {pattern}:")
report.append(f" Precision: {metrics['precision']:.4f}")
report.append(f" Recall: {metrics['recall']:.4f}")
report.append(f" F1-Score: {metrics['f1_score']:.4f}")
report.append(f" Support: {metrics['support']}")
# Discovered Patterns Info
report.append("\n🎯 DISCOVERED PATTERN DETAILS")
report.append("-" * 50)
for pattern_name, details in pattern_analysis['discovered_patterns_info'].items():
report.append(f"\n{pattern_name}:")
# Handle different pattern structures (LDA vs K-Means)
if 'clause_count' in details:
report.append(f" Clauses: {details['clause_count']}")
if 'avg_risk_intensity' in details:
report.append(f" Risk Intensity: {details['avg_risk_intensity']:.3f}")
if 'avg_legal_complexity' in details:
report.append(f" Legal Complexity: {details['avg_legal_complexity']:.3f}")
# Handle both 'key_terms' and 'top_words' (LDA uses top_words)
if 'key_terms' in details:
report.append(f" Key Terms: {', '.join(details['key_terms'][:5])}")
elif 'top_words' in details:
report.append(f" Top Words: {', '.join(details['top_words'][:5])}")
# Show topic distribution if available (LDA-specific)
if 'topic_distribution' in details:
report.append(f" Topic Distribution: {details['topic_distribution']:.3f}")
report.append("\n" + "=" * 80)
return "\n".join(report)
    def plot_confusion_matrix(self, save_path: Optional[str] = None):
"""Plot confusion matrix"""
if not VISUALIZATION_AVAILABLE:
print("⚠️ Visualization libraries not available. Skipping plot.")
return
if not self.evaluation_results:
raise ValueError("Must run evaluation first")
cm = np.array(self.evaluation_results['classification_metrics']['confusion_matrix'])
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix - Risk Classification')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"💾 Confusion matrix saved to: {save_path}")
else:
plt.show()
plt.close()
    def plot_risk_distribution(self, save_path: Optional[str] = None):
"""Plot risk pattern distribution"""
if not VISUALIZATION_AVAILABLE:
print("⚠️ Visualization libraries not available. Skipping plot.")
return
if not self.evaluation_results:
raise ValueError("Must run evaluation first")
pattern_analysis = self.evaluation_results['risk_pattern_analysis']
patterns = list(pattern_analysis['true_distribution'].keys())
true_counts = [pattern_analysis['true_distribution'][p] for p in patterns]
pred_counts = [pattern_analysis['predicted_distribution'].get(p, 0) for p in patterns]
x = np.arange(len(patterns))
width = 0.35
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(x - width/2, true_counts, width, label='True', alpha=0.8)
ax.bar(x + width/2, pred_counts, width, label='Predicted', alpha=0.8)
ax.set_xlabel('Risk Patterns')
ax.set_ylabel('Count')
ax.set_title('Risk Pattern Distribution - True vs Predicted')
ax.set_xticks(x)
ax.set_xticklabels(patterns, rotation=45, ha='right')
ax.legend()
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"💾 Risk distribution plot saved to: {save_path}")
else:
plt.show()
plt.close()
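    # Plotting usage sketch (assumes evaluate_model has already been run):
    #
    #   evaluator.plot_confusion_matrix(save_path="confusion_matrix.png")
    #   evaluator.plot_risk_distribution(save_path="risk_distribution.png")
    #
    # Omitting save_path shows the figures interactively via plt.show() instead.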
def save_evaluation_results(self, results: Dict[str, Any]):
"""Save evaluation results to file"""
# Convert numpy arrays to lists for JSON serialization
json_results = self._convert_for_json(results)
with open('evaluation_results.json', 'w') as f:
json.dump(json_results, f, indent=2)
# Save report
report = self.generate_report()
with open('evaluation_report.txt', 'w') as f:
f.write(report)
print("💾 Evaluation results saved:")
print(" - evaluation_results.json")
print(" - evaluation_report.txt")
def _convert_for_json(self, obj):
"""Convert numpy arrays to lists for JSON serialization"""
if isinstance(obj, dict):
return {key: self._convert_for_json(value) for key, value in obj.items()}
elif isinstance(obj, list):
return [self._convert_for_json(item) for item in obj]
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
else:
return obj
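    # Example of the conversion performed above (illustrative input):
    #   {'cm': np.array([[3, 1], [0, 4]]), 'acc': np.float64(0.875)}
    # becomes the JSON-serializable
    #   {'cm': [[3, 1], [0, 4]], 'acc': 0.875}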
def analyze_attention_patterns(self, test_clauses: List[str],
max_samples: int = 10) -> Dict[str, Any]:
"""
Analyze attention patterns for clause importance interpretation.
Args:
test_clauses: List of clause texts to analyze
max_samples: Maximum number of samples to analyze
Returns:
Dictionary containing attention analysis results
"""
print(f"🔍 Analyzing attention patterns for {min(len(test_clauses), max_samples)} samples...")
self.model.eval()
attention_results = []
with torch.no_grad():
for idx, clause in enumerate(test_clauses[:max_samples]):
# Tokenize
tokens = self.tokenizer.tokenize_clauses([clause])
input_ids = tokens['input_ids'].to(self.model.config.device)
attention_mask = tokens['attention_mask'].to(self.model.config.device)
# Get attention analysis
analysis = self.model.analyze_attention(input_ids, attention_mask, self.tokenizer)
# Get prediction
prediction = self.model.predict_risk_pattern(input_ids, attention_mask)
result = {
'clause_index': idx,
'clause_preview': clause[:100] + '...' if len(clause) > 100 else clause,
'predicted_risk': int(prediction['predicted_risk_id'][0]),
'severity': float(prediction['severity_score'][0]),
'importance': float(prediction['importance_score'][0]),
'top_tokens': analysis.get('top_tokens', []),
'top_token_scores': analysis.get('top_token_scores', np.array([])).tolist()
}
attention_results.append(result)
print(f"✅ Attention analysis complete for {len(attention_results)} clauses")
return {
'num_analyzed': len(attention_results),
'clause_analyses': attention_results
}
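    # Attention-analysis usage sketch (illustrative; tokenize_clauses,
    # analyze_attention, and predict_risk_pattern are expected to be provided by
    # the repo's tokenizer/model classes, as called above):
    #
    #   attn = evaluator.analyze_attention_patterns(
    #       ["The supplier shall indemnify the buyer against all claims..."],
    #       max_samples=5,
    #   )
    #   for item in attn['clause_analyses']:
    #       print(item['clause_preview'], item['top_tokens'][:3])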
def evaluate_hierarchical_risk(self, test_loader,
contract_ids: List[int]) -> Dict[str, Any]:
"""
Evaluate hierarchical risk aggregation (clause → contract level).
Args:
test_loader: DataLoader with test clauses
contract_ids: List of contract IDs for each clause in test set
Returns:
Contract-level risk assessment results
"""
if not HIERARCHICAL_AVAILABLE:
print("⚠️ Hierarchical risk analysis not available")
return {'error': 'hierarchical_risk module not found'}
print("📊 Performing hierarchical risk evaluation (clause → contract level)...")
# Collect clause-level predictions grouped by contract
contract_predictions = defaultdict(list)
self.model.eval()
clause_idx = 0
with torch.no_grad():
for batch in test_loader:
input_ids = batch['input_ids'].to(self.model.config.device)
attention_mask = batch['attention_mask'].to(self.model.config.device)
# Get predictions
predictions = self.model.predict_risk_pattern(input_ids, attention_mask)
# Group by contract
batch_size = input_ids.size(0)
for i in range(batch_size):
contract_id = contract_ids[clause_idx]
clause_pred = {
'predicted_risk_id': int(predictions['predicted_risk_id'][i]),
'confidence': float(predictions['confidence'][i]),
'severity_score': float(predictions['severity_score'][i]),
'importance_score': float(predictions['importance_score'][i])
}
contract_predictions[contract_id].append(clause_pred)
clause_idx += 1
# Aggregate to contract level
aggregator = HierarchicalRiskAggregator()
contract_results = {}
for contract_id, clause_preds in contract_predictions.items():
contract_risk = aggregator.aggregate_contract_risk(
clause_preds,
method='weighted_mean'
)
contract_results[contract_id] = contract_risk
print(f"✅ Analyzed {len(contract_results)} contracts")
# Summary statistics
contract_severities = [r['contract_severity'] for r in contract_results.values()]
contract_importances = [r['contract_importance'] for r in contract_results.values()]
summary = {
'num_contracts': len(contract_results),
'contract_results': contract_results,
'summary_statistics': {
'avg_contract_severity': float(np.mean(contract_severities)),
'std_contract_severity': float(np.std(contract_severities)),
'max_contract_severity': float(np.max(contract_severities)),
'min_contract_severity': float(np.min(contract_severities)),
'avg_contract_importance': float(np.mean(contract_importances)),
'high_risk_contracts': sum(1 for s in contract_severities if s >= 7.0)
}
}
return summary
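    # Hierarchical-risk usage sketch (illustrative): contract_ids must list, in
    # test_loader order, the contract each clause belongs to, e.g. three clauses
    # from contract 0 followed by two from contract 1 would be [0, 0, 0, 1, 1]:
    #
    #   summary = evaluator.evaluate_hierarchical_risk(test_loader, contract_ids)
    #   print(summary['summary_statistics']['high_risk_contracts'])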
def analyze_risk_dependencies(self, test_loader,
contract_ids: List[int],
num_risk_types: int = 7) -> Dict[str, Any]:
"""
Analyze dependencies and interactions between risk types.
Args:
test_loader: DataLoader with test clauses
contract_ids: List of contract IDs for each clause
num_risk_types: Number of risk categories
Returns:
Risk dependency analysis including co-occurrence and correlations
"""
if not HIERARCHICAL_AVAILABLE:
print("⚠️ Risk dependency analysis not available")
return {'error': 'hierarchical_risk module not found'}
print("🔗 Analyzing risk dependencies and interactions...")
# Collect predictions grouped by contract
contract_predictions = defaultdict(list)
self.model.eval()
clause_idx = 0
with torch.no_grad():
for batch in test_loader:
input_ids = batch['input_ids'].to(self.model.config.device)
attention_mask = batch['attention_mask'].to(self.model.config.device)
predictions = self.model.predict_risk_pattern(input_ids, attention_mask)
batch_size = input_ids.size(0)
for i in range(batch_size):
contract_id = contract_ids[clause_idx]
clause_pred = {
'predicted_risk_id': int(predictions['predicted_risk_id'][i]),
'confidence': float(predictions['confidence'][i]),
'severity_score': float(predictions['severity_score'][i]),
'importance_score': float(predictions['importance_score'][i])
}
contract_predictions[contract_id].append(clause_pred)
clause_idx += 1
# Analyze dependencies
dependency_analyzer = RiskDependencyAnalyzer()
# Compute correlation across contracts
contract_pred_lists = list(contract_predictions.values())
correlation_matrix = dependency_analyzer.compute_risk_correlation(
contract_pred_lists,
num_risk_types
)
# Analyze amplification effects
all_clause_preds = [pred for preds in contract_pred_lists for pred in preds]
amplification = dependency_analyzer.analyze_risk_amplification(all_clause_preds)
# Find common risk chains
all_chains = []
for clause_preds in contract_pred_lists:
chains = dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
all_chains.extend(chains)
# Count most common chains
from collections import Counter
chain_counts = Counter([tuple(chain) for chain in all_chains])
most_common_chains = chain_counts.most_common(10)
        print("✅ Risk dependency analysis complete")
return {
'correlation_matrix': correlation_matrix.tolist(),
'risk_amplification': amplification,
'common_risk_chains': [
{'chain': list(chain), 'count': count}
for chain, count in most_common_chains
],
'total_chains_found': len(all_chains)
}
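    # Dependency-analysis usage sketch (illustrative): the correlation matrix is
    # num_risk_types x num_risk_types, and each common risk chain is a sequence
    # of predicted risk IDs found in sliding windows of clauses (see
    # find_risk_chains above):
    #
    #   deps = evaluator.analyze_risk_dependencies(test_loader, contract_ids)
    #   for entry in deps['common_risk_chains'][:3]:
    #       print(entry['chain'], entry['count'])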
# Fallback metric implementations for environments without scikit-learn.
# torch is imported unconditionally at the top of this module and matplotlib/seaborn
# are already guarded by VISUALIZATION_AVAILABLE, so only sklearn is handled here.
try:
    from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                                 confusion_matrix, mean_squared_error,
                                 mean_absolute_error, r2_score)
except ImportError:
    print("⚠️ Warning: scikit-learn not available. Using mock metric implementations.")

    def accuracy_score(y_true, y_pred):
        return sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)

    def precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0):
        # Placeholder: per-class arrays when average is None, else a single triple.
        if average is None:
            n_classes = len(set(y_true) | set(y_pred))
            fill = np.full(n_classes, 0.5)
            return fill, fill, fill, np.zeros(n_classes, dtype=int)
        return 0.5, 0.5, 0.5, None

    def confusion_matrix(y_true, y_pred):
        labels = sorted(set(y_true) | set(y_pred))
        index = {label: i for i, label in enumerate(labels)}
        cm = np.zeros((len(labels), len(labels)), dtype=int)
        for t, p in zip(y_true, y_pred):
            cm[index[t], index[p]] += 1
        return cm

    def mean_squared_error(y_true, y_pred):
        return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

    def mean_absolute_error(y_true, y_pred):
        return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

    def r2_score(y_true, y_pred):
        # R² = 1 - SS_res / SS_tot rather than a constant placeholder.
        mean_true = sum(y_true) / len(y_true)
        ss_tot = sum((t - mean_true) ** 2 for t in y_true)
        ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
        return 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
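
# Minimal smoke test (illustrative only): exercises the JSON-conversion helper,
# which needs no trained model, tokenizer, or risk-discovery object.
if __name__ == "__main__":
    _demo = LegalBertEvaluator(model=None, tokenizer=None, risk_discovery=None)
    _sample = {
        'confusion_matrix': np.array([[3, 1], [0, 4]]),
        'accuracy': np.float64(0.875),
        'per_class_support': [np.int64(4), np.int64(4)],
    }
    print(json.dumps(_demo._convert_for_json(_sample), indent=2))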