"""
Advanced early stopping with multi-metric support.
Prevents overfitting by tracking multiple metrics simultaneously.
"""
import json
from pathlib import Path

import torch

class MultiMetricEarlyStopping:
"""
Early stopping that considers multiple metrics with weighted scores.
Advantages over single-metric stopping:
    - Avoids stopping on a checkpoint that improves one metric at the expense of others
    - Encourages better overall model quality
    - More stable stopping decisions
Example metric weights:
{'loss': 0.2, 'accuracy': 0.4, 'bertscore': 0.3, 'f1': 0.1}
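
    Minimal usage sketch (the `evaluate` helper and metric names are
    illustrative, not part of this module):
        stopper = MultiMetricEarlyStopping(
            patience=3, metric_weights={'loss': 0.5, 'accuracy': 0.5})
        for epoch in range(num_epochs):
            metrics = evaluate(model)  # e.g. {'loss': 0.4, 'accuracy': 0.9}
            if stopper(metrics, model=model, epoch=epoch):
                break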
"""
def __init__(self, patience=5, metric_weights=None, mode='maximize',
save_dir=None, verbose=True):
"""
Args:
patience: Number of evaluations with no improvement before stopping
metric_weights: Dict of {metric_name: weight}. If None, uses 'loss' only
            mode: 'maximize' or 'minimize'. With 'minimize', lower composite
                scores count as improvements (useful for loss-only weights)
save_dir: Directory to save best model
verbose: Print progress
"""
self.patience = patience
self.counter = 0
self.best_score = None
self.best_metrics = None
self.save_dir = Path(save_dir) if save_dir else None
self.verbose = verbose
self.mode = mode
# Default metric weights if not provided
if metric_weights is None:
self.metric_weights = {'loss': 1.0}
else:
self.metric_weights = metric_weights
        # Normalize weights to sum to 1
        total_weight = sum(self.metric_weights.values())
        if total_weight <= 0:
            raise ValueError("metric_weights must sum to a positive value")
        self.metric_weights = {k: v / total_weight for k, v in self.metric_weights.items()}
self.history = []
if self.save_dir:
self.save_dir.mkdir(parents=True, exist_ok=True)
def compute_score(self, metrics):
"""
Compute weighted score from multiple metrics.
Args:
metrics: Dict of metric_name -> value
Returns:
Weighted score
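
        Worked example (mode='maximize', weights {'loss': 0.5, 'accuracy': 0.5}):
            metrics {'loss': 0.4, 'accuracy': 0.9} yields
            score = 0.5 * (-0.4) + 0.5 * 0.9 = 0.25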
"""
score = 0.0
for metric_name, weight in self.metric_weights.items():
if metric_name not in metrics:
if self.verbose:
print(f"[WARNING] Metric '{metric_name}' not found in current metrics")
continue
            metric_value = metrics[metric_name]
            if 'loss' in metric_name.lower():
                # Losses improve downward: flip the sign in 'maximize' mode so
                # that a lower loss raises the composite score.
                metric_contribution = -metric_value if self.mode == 'maximize' else metric_value
            else:
                # Gain metrics (accuracy, F1, BERTScore, ...) improve upward:
                # flip the sign in 'minimize' mode so that a higher value
                # lowers the composite score.
                metric_contribution = metric_value if self.mode == 'maximize' else -metric_value
            score += metric_contribution * weight
return score
def __call__(self, metrics, model=None, epoch=None):
"""
Check if should stop training.
Args:
metrics: Dict of metric_name -> value
model: Model to save if best
epoch: Current epoch number
Returns:
True if should stop, False otherwise
"""
score = self.compute_score(metrics)
# Store history
self.history.append({
'epoch': epoch,
'score': score,
'metrics': metrics.copy()
})
        if self.best_score is None:
            improved = True
        elif self.mode == 'maximize':
            improved = score > self.best_score
        else:
            improved = score < self.best_score
        if improved:
            self.best_score = score
            self.best_metrics = metrics.copy()
            self.counter = 0
            if model is not None and self.save_dir:
                self._save_checkpoint(model, epoch, metrics)
            if self.verbose:
                print(f"✓ Epoch {epoch}: New best score {score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"✗ Epoch {epoch}: No improvement ({self.counter}/{self.patience})")
# Check if should stop
if self.counter >= self.patience:
if self.verbose:
print(f"\n[EARLY STOPPING] Patience exceeded. Best metrics:")
for k, v in self.best_metrics.items():
if isinstance(v, float):
print(f" {k}: {v:.4f}")
return True
return False
def _save_checkpoint(self, model, epoch, metrics):
"""Save best model checkpoint."""
if self.save_dir is None:
return
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'metrics': metrics
}
save_path = self.save_dir / f"best_checkpoint_epoch_{epoch}.pt"
torch.save(checkpoint, save_path)
# Also save metrics record
metrics_path = self.save_dir / f"best_metrics_epoch_{epoch}.json"
with open(metrics_path, 'w') as f:
json.dump(metrics, f, indent=2, default=str)
if self.verbose:
print(f" 💾 Saved checkpoint to {save_path}")
def get_best_metrics(self):
"""Return best metrics found during training."""
return self.best_metrics
def get_history(self):
"""Return training history."""
return self.history
def plot_metrics(self, save_path=None):
"""
Plot metric progression during training.
Args:
save_path: Path to save figure
"""
try:
import matplotlib.pyplot as plt
except ImportError:
print("[WARNING] matplotlib not installed, cannot plot")
return
if not self.history:
print("[WARNING] No history to plot")
return
        # Fall back to the history index when no epoch number was recorded
        epochs = [h['epoch'] if h['epoch'] is not None else i
                  for i, h in enumerate(self.history)]
        scores = [h['score'] for h in self.history]
plt.figure(figsize=(10, 6))
plt.plot(epochs, scores, 'b-o', label='Composite Score')
plt.axhline(y=self.best_score, color='r', linestyle='--', label=f'Best: {self.best_score:.4f}')
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.legend()
plt.title('Early Stopping - Composite Metric Score')
plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"[INFO] Metric plot saved to {save_path}")
            plt.close()
        else:
            plt.show()
class DynamicClassWeights:
"""
    Compute class weights dynamically from the training data using inverse
    frequency, so that rarer classes receive larger weights.
"""
@staticmethod
def compute_weights(dataloader, device='cpu'):
"""
Compute class weights from data distribution.
Args:
dataloader: DataLoader to analyze
device: Device for tensor
Returns:
Tensor of class weights
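
        Worked example (hypothetical counts {0: 900, 1: 100}): total=1000 and
        num_classes=2 give raw weights 1000/(2*900) ≈ 0.556 and
        1000/(2*100) = 5.0, which normalize (to sum to num_classes) to
        0.2 and 1.8.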
"""
class_counts = {}
for batch in dataloader:
labels = batch.get('label_closed', None)
if labels is None:
continue
# Count occurrences of each class
unique_labels, counts = torch.unique(labels, return_counts=True)
for label, count in zip(unique_labels, counts):
label_idx = label.item()
if label_idx >= 0: # Ignore negative indices
class_counts[label_idx] = class_counts.get(label_idx, 0) + count.item()
        if not class_counts:
            # Fall back to uniform binary weights if no labels were found
            print("[WARNING] No labels found in dataloader; using uniform weights")
            return torch.ones(2, device=device)
# Compute inverse frequency weights
total_samples = sum(class_counts.values())
num_classes = len(class_counts)
        # Classes never observed keep weight 0 and are effectively ignored
        weights = torch.zeros(max(class_counts.keys()) + 1, device=device)
for class_idx, count in class_counts.items():
# Weight = total / (num_classes * count) - higher weight for rarer classes
weight = total_samples / (num_classes * max(count, 1))
weights[class_idx] = weight
# Normalize to sum to num_classes
weights = weights / weights.sum() * num_classes
print("[INFO] Dynamic Class Weights:")
for class_idx in sorted(class_counts.keys()):
print(f" Class {class_idx}: Weight={weights[class_idx]:.4f}, Samples={class_counts[class_idx]}")
        return weights  # already on the requested device
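

if __name__ == "__main__":
    # Smoke test with synthetic metrics (a minimal sketch, not part of the
    # training pipeline): the composite score improves for three epochs,
    # then plateaus, so stopping triggers once patience=2 is exhausted.
    stopper = MultiMetricEarlyStopping(
        patience=2, metric_weights={'loss': 0.5, 'accuracy': 0.5})
    fake_eval_results = [
        {'loss': 0.9, 'accuracy': 0.60},
        {'loss': 0.7, 'accuracy': 0.70},
        {'loss': 0.5, 'accuracy': 0.80},
        {'loss': 0.6, 'accuracy': 0.78},
        {'loss': 0.6, 'accuracy': 0.79},
    ]
    for epoch, metrics in enumerate(fake_eval_results):
        if stopper(metrics, epoch=epoch):
            break
    print("Best metrics:", stopper.get_best_metrics())

    # DynamicClassWeights accepts any iterable of batch dicts; a plain list
    # of tensors stands in for a real DataLoader here.
    fake_loader = [{'label_closed': torch.tensor([0, 0, 0, 1])}]
    print(DynamicClassWeights.compute_weights(fake_loader))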