| | """
|
| | Training Script for TransMIL + Query2Label Hybrid Model
|
| |
|
| | Supports:
|
| | - End-to-end training with ResNet-50 backbone
|
| | - Mixed precision training (AMP) for memory efficiency
|
| | - Gradient accumulation for larger effective batch size
|
| | - Gradient checkpointing for ResNet-50
|
| | - AsymmetricLoss for multi-label imbalance
|
| | - Multi-label evaluation metrics (mAP, per-class AP, F1)
|
| | """
|
| |
|
| | import sys
|
| |
|
| |
|
| |
|
| | import os
|
| | import argparse
|
| | import yaml
|
| | from pathlib import Path
|
| | from datetime import datetime
|
| | import json
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.optim as optim
|
| | from torch.cuda.amp import autocast, GradScaler
|
| | from torch.utils.tensorboard import SummaryWriter
|
| | import numpy as np
|
| | from tqdm import tqdm
|
| | from sklearn.metrics import average_precision_score, f1_score
|
| |
|
| |
|
| | from models.transmil_q2l import TransMIL_Query2Label_E2E
|
| | from thyroid_dataset import create_dataloaders
|
| |
|
| |
|
# AsymmetricLoss lives in the project's `models.aslloss` module (presumably a
# copy of query2labels' lib/models/aslloss.py). When it cannot be imported the
# name is set to None and train() falls back to BCEWithLogitsLoss.
try:
    from models.aslloss import AsymmetricLossOptimized
except ImportError:
    print("Warning: Could not import AsymmetricLoss.")
    AsymmetricLossOptimized = None
|
| |
|
| |
|
| |
|
| |
|
def compute_multilabel_metrics(preds, targets, threshold=0.5):
    """
    Compute multi-label classification metrics.

    Args:
        preds: [N, num_class] array-like of predicted probabilities.
        targets: [N, num_class] array-like of binary (0/1) labels.
        threshold: Probability at or above which a prediction counts as
            positive when computing the F1 scores.

    Returns:
        dict with keys:
            'mAP': mean AP over classes that have at least one positive
                target; NaN if no class has a positive target.
            'per_class_AP': list with one AP per class (NaN for classes
                without any positive target).
            'F1_micro', 'F1_macro', 'F1_samples': F1 scores at ``threshold``.
        All numbers are plain Python floats, so the dict can be logged,
        pickled into checkpoints, or JSON-serialized without numpy types.
    """
    preds = np.asarray(preds)
    targets = np.asarray(targets)

    # AP is undefined for a class with no positive targets; record NaN so
    # such classes are excluded from the mean instead of counting as 0.
    per_class_ap = []
    for i in range(targets.shape[1]):
        if targets[:, i].sum() > 0:
            per_class_ap.append(float(average_precision_score(targets[:, i], preds[:, i])))
        else:
            per_class_ap.append(float('nan'))

    # np.nanmean warns ("Mean of empty slice") when every entry is NaN, so
    # average the defined APs explicitly and return NaN when there are none.
    valid_aps = [ap for ap in per_class_ap if not np.isnan(ap)]
    metrics = {
        'mAP': float(np.mean(valid_aps)) if valid_aps else float('nan'),
        'per_class_AP': per_class_ap,
    }

    preds_binary = (preds >= threshold).astype(int)
    for average in ('micro', 'macro', 'samples'):
        metrics[f'F1_{average}'] = float(
            f1_score(targets, preds_binary, average=average, zero_division=0)
        )

    return metrics
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def train_epoch(model, dataloader, criterion, optimizer, scaler, device, config, epoch):
    """
    Train for one epoch with gradient accumulation and optional mixed precision.

    Gradients are accumulated over ``gradient_accumulation_steps`` batches
    before each optimizer step. When the number of batches is not a multiple
    of that value, the trailing partial group is still applied at the end of
    the epoch instead of being discarded by the next epoch's ``zero_grad()``.

    Args:
        model: TransMIL_Query2Label_E2E model, called as
            ``model(images, num_instances_per_case)``.
        dataloader: Training dataloader yielding dicts with 'images',
            'labels' and 'num_instances_per_case'.
        criterion: Loss module called as ``criterion(logits, labels)``.
        optimizer: Optimizer to step (AdamW in train()).
        scaler: GradScaler for AMP; only used when
            ``config['training']['use_amp']`` is true (may be None otherwise).
        device: torch.device the image/label tensors are moved to.
        config: Config dict; reads ``training.gradient_accumulation_steps``
            and ``training.use_amp``.
        epoch: Current epoch number (only used for the progress-bar label).

    Returns:
        Mean per-batch training loss (not divided by the accumulation
        factor), or 0.0 if the dataloader yields no batches.
    """
    model.train()

    accumulation_steps = config['training']['gradient_accumulation_steps']
    use_amp = config['training']['use_amp']
    num_batches = len(dataloader)

    total_loss = 0.0
    batches_seen = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    optimizer.zero_grad()

    for i, batch in enumerate(pbar):
        images = batch['images'].to(device)
        labels = batch['labels'].to(device)
        num_instances_per_case = batch['num_instances_per_case']

        if use_amp:
            with autocast():
                logits = model(images, num_instances_per_case)
                loss = criterion(logits, labels)
        else:
            logits = model(images, num_instances_per_case)
            loss = criterion(logits, labels)

        # Divide so the gradients accumulated over a full group equal the
        # gradient of the group's mean loss.
        loss = loss / accumulation_steps
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step after every full group AND after the final batch, so a
        # trailing partial group still produces an update.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # One .item() (host-device sync) per batch, reused for both uses.
        batch_loss = loss.item() * accumulation_steps
        total_loss += batch_loss
        batches_seen += 1
        pbar.set_postfix({'loss': batch_loss})

    return total_loss / batches_seen if batches_seen else 0.0
|
| |
|
| |
|
@torch.no_grad()
def validate(model, dataloader, criterion, device, config):
    """
    Evaluate the model on a dataloader and compute multi-label metrics.

    The model is put in eval mode and run without gradient tracking. Logits
    are turned into probabilities with a sigmoid and scored with
    ``compute_multilabel_metrics``.

    Args:
        model: TransMIL_Query2Label_E2E model, called as
            ``model(images, num_instances_per_case)``.
        dataloader: Evaluation dataloader yielding dicts with 'images',
            'labels' and 'num_instances_per_case'.
        criterion: Loss module called as ``criterion(logits, labels)``.
        device: torch.device the image/label tensors are moved to.
        config: Config dict (accepted for interface symmetry; not read here).

    Returns:
        The dict from ``compute_multilabel_metrics`` plus 'loss', the mean
        criterion value per batch.
    """
    model.eval()

    loss_sum = 0.0
    prob_chunks = []
    label_chunks = []

    for batch in tqdm(dataloader, desc="Validating"):
        batch_images = batch['images'].to(device)
        batch_labels = batch['labels'].to(device)

        batch_logits = model(batch_images, batch['num_instances_per_case'])
        batch_loss = criterion(batch_logits, batch_labels)

        prob_chunks.append(torch.sigmoid(batch_logits).cpu().numpy())
        label_chunks.append(batch_labels.cpu().numpy())
        loss_sum += batch_loss.item()

    stacked_probs = np.concatenate(prob_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0)

    results = compute_multilabel_metrics(stacked_probs, stacked_labels)
    results['loss'] = loss_sum / len(dataloader)
    return results
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _build_scheduler(optimizer, config):
    """Return the LR scheduler selected by ``config['training']['scheduler']``, or None.

    train() steps every scheduler exactly once per epoch (train_epoch never
    steps it per batch), so OneCycleLR is sized with one step per epoch.
    Sizing it per batch without stepping it per batch would leave the
    learning rate stuck at its initial ``lr / div_factor`` for the whole run.
    """
    training_cfg = config['training']
    scheduler_type = training_cfg.get('scheduler', 'cosine')
    if scheduler_type == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=training_cfg['epochs'],
            eta_min=1e-6
        )
    if scheduler_type == 'onecycle':
        return optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=training_cfg['lr'],
            total_steps=training_cfg['epochs']
        )
    if scheduler_type not in (None, 'none'):
        print(f"Warning: unknown scheduler '{scheduler_type}'; using a constant learning rate")
    return None


def _build_criterion(config):
    """Return AsymmetricLoss configured from ``config['training']``, or BCEWithLogitsLoss if unavailable."""
    if AsymmetricLossOptimized is None:
        print("Warning: Using BCEWithLogitsLoss instead of AsymmetricLoss")
        return nn.BCEWithLogitsLoss()
    training_cfg = config['training']
    return AsymmetricLossOptimized(
        gamma_neg=training_cfg['gamma_neg'],
        gamma_pos=training_cfg['gamma_pos'],
        clip=training_cfg['clip'],
        eps=1e-5
    )


def _to_plain_float(value, nan_to_none=False):
    """Convert a number (or list of numbers) to plain Python float(s).

    Numpy scalars are turned into builtin floats so checkpoints stay loadable
    with ``torch.load(weights_only=True)``. With ``nan_to_none=True``, NaN
    becomes None so ``json.dump`` writes valid JSON (``null``) instead of
    the non-standard ``NaN`` token.
    """
    if isinstance(value, (list, tuple)):
        return [_to_plain_float(v, nan_to_none) for v in value]
    value = float(value)
    if nan_to_none and np.isnan(value):
        return None
    return value


def train(config, resume_from=None):
    """
    Main training function.

    Builds dataloaders, model, optimizer, scheduler and loss from ``config``;
    trains for ``config['training']['epochs']`` epochs, validating and
    stepping the scheduler once per epoch; writes TensorBoard logs and
    checkpoints under ``config['training']['save_dir']``; finally evaluates
    the final-epoch model on the test split and writes ``test_results.json``.

    Args:
        config: Config dictionary from YAML.
        resume_from: Optional checkpoint path to resume training from. A path
            that does not exist only triggers a warning, and training starts
            from scratch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    save_dir = Path(config['training']['save_dir'])
    save_dir.mkdir(parents=True, exist_ok=True)

    log_dir = save_dir / 'logs' / datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(log_dir)

    with open(save_dir / 'config.yaml', 'w') as f:
        yaml.dump(config, f)

    print("\nCreating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(config)

    print("\nCreating model...")
    model = TransMIL_Query2Label_E2E(
        num_class=config['model']['num_class'],
        hidden_dim=config['model']['hidden_dim'],
        nheads=config['model']['nheads'],
        num_decoder_layers=config['model']['num_decoder_layers'],
        pretrained_resnet=config['model']['pretrained_resnet'],
        use_checkpointing=config['training']['gradient_checkpointing'],
        use_ppeg=config['model'].get('use_ppeg', False)
    )
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['training']['lr'],
        weight_decay=config['training']['weight_decay']
    )
    scheduler = _build_scheduler(optimizer, config)
    criterion = _build_criterion(config)
    scaler = GradScaler() if config['training']['use_amp'] else None

    start_epoch = 0
    best_map = 0.0

    if resume_from is not None and Path(resume_from).exists():
        print(f"\nResuming from {resume_from}")
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_map = checkpoint.get('best_map', 0.0)
        if scheduler is not None and checkpoint.get('scheduler_state_dict') is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Restore the AMP loss scale so training does not restart from the
        # default scale (older checkpoints may not contain it).
        if scaler is not None and checkpoint.get('scaler_state_dict') is not None:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Resumed from epoch {start_epoch}, best mAP: {best_map:.4f}")
    elif resume_from is not None:
        print(f"Warning: resume checkpoint {resume_from} not found; training from scratch")

    num_epochs = config['training']['epochs']
    save_freq = config['training']['save_freq']

    print(f"\nStarting training for {num_epochs} epochs...")
    print("="*80)

    try:
        for epoch in range(start_epoch, num_epochs):
            train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, config, epoch)
            val_metrics = validate(model, val_loader, criterion, device, config)

            # Every scheduler built by _build_scheduler is epoch-granular.
            if scheduler is not None:
                scheduler.step()

            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_metrics['loss'], epoch)
            writer.add_scalar('Metrics/mAP', val_metrics['mAP'], epoch)
            writer.add_scalar('Metrics/F1_micro', val_metrics['F1_micro'], epoch)
            writer.add_scalar('Metrics/F1_macro', val_metrics['F1_macro'], epoch)
            writer.add_scalar('LR', current_lr, epoch)

            print(f"\nEpoch {epoch}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_metrics['loss']:.4f}")
            print(f"  mAP: {val_metrics['mAP']:.4f}")
            print(f"  F1 (micro): {val_metrics['F1_micro']:.4f}")
            print(f"  F1 (macro): {val_metrics['F1_macro']:.4f}")
            print(f"  LR: {current_lr:.6f}")

            # A NaN mAP compares False, so it never counts as a new best.
            is_best = val_metrics['mAP'] > best_map
            if is_best:
                best_map = float(val_metrics['mAP'])

            is_periodic = (epoch + 1) % save_freq == 0
            if is_periodic or is_best:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                    'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                    'train_loss': train_loss,
                    # Plain floats keep numpy scalar types out of the pickle.
                    'val_metrics': {k: _to_plain_float(v) for k, v in val_metrics.items()},
                    'best_map': float(best_map),
                    'config': config
                }

                torch.save(checkpoint, save_dir / 'checkpoint_latest.pth')

                if is_best:
                    torch.save(checkpoint, save_dir / 'checkpoint_best.pth')
                    print(f"  ✓ Saved best model (mAP: {best_map:.4f})")

                if is_periodic:
                    torch.save(checkpoint, save_dir / f'checkpoint_epoch_{epoch}.pth')
    finally:
        # Flush and close TensorBoard logs even if training raises.
        writer.close()

    print("\n" + "="*80)
    print(f"Training completed! Best mAP: {best_map:.4f}")
    print(f"Checkpoints saved to: {save_dir}")

    print("\nEvaluating on test set...")
    test_metrics = validate(model, test_loader, criterion, device, config)
    print(f"\nTest Results:")
    print(f"  mAP: {test_metrics['mAP']:.4f}")
    print(f"  F1 (micro): {test_metrics['F1_micro']:.4f}")
    print(f"  F1 (macro): {test_metrics['F1_macro']:.4f}")

    # NaN (e.g. AP of a class with no positive test samples) is written as
    # null so the file is valid JSON.
    with open(save_dir / 'test_results.json', 'w') as f:
        json.dump({k: _to_plain_float(v, nan_to_none=True)
                   for k, v in test_metrics.items()}, f, indent=2)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def main():
    """
    Command-line entry point: parse arguments, load the YAML config, train.

    Exits with a usage error (status 2) in two cases:
    - ``--resume`` names a file that does not exist. Without this check,
      train() would silently start from scratch and overwrite the
      checkpoints in save_dir.
    - The config file does not contain a YAML mapping.
    """
    parser = argparse.ArgumentParser(description='Train TransMIL + Query2Label Hybrid Model')
    parser.add_argument('--config', type=str, default='hybrid_model/config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    args = parser.parse_args()

    # Fail fast before any checkpoint in save_dir can be overwritten.
    if args.resume is not None and not Path(args.resume).is_file():
        parser.error(f"--resume checkpoint not found: {args.resume}")

    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file and a scalar/list for other
    # content, both of which would crash later inside train().
    if not isinstance(config, dict):
        parser.error(f"config file {args.config} does not contain a YAML mapping")

    print("="*80)
    print("TransMIL + Query2Label Hybrid Model Training")
    print("="*80)
    print(f"\nConfig: {args.config}")
    if args.resume:
        print(f"Resume from: {args.resume}")

    train(config, resume_from=args.resume)
|
| |
|
| |
|
# Launch the training CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
| |
|