"""DiaFoot.AI v2 — Data Composition Ablation.

The key experiment: show whether adding non-DFU wound images and healthy
images to the training data improves DFU segmentation.

Each invocation trains one segmentation model, for the selected variant:
    (a) dfu_only:   train only on DFU images
    (b) dfu_nondfu: train on DFU + non-DFU wound images (current best)
    (c) all:        train on all three classes (DFU, non-DFU and healthy)

Usage:
    python scripts/run_ablation.py --variant dfu_only --device cuda --epochs 50
    python scripts/run_ablation.py --variant dfu_nondfu --device cuda --epochs 50
    python scripts/run_ablation.py --variant all --device cuda --epochs 50
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from scripts.train import build_dataloaders |
| from src.models.unetpp import build_unetpp |
| from src.training.losses import DiceCELoss |
| from src.training.schedulers import CosineAnnealingWithWarmup |
| from src.training.trainer import TrainConfig, Trainer |
|
|
# Ablation variants, listed in the order they are offered on the command line.
# Each row is:
#   (variant name,
#    class filter handed to build_dataloaders as ``filter_classes``,
#    checkpoint output directory,
#    human-readable description used in the log line).
# The "all" variant uses ``None`` as its filter; presumably this means
# "no class filtering" — confirm against build_dataloaders.
_ABLATION_VARIANTS = (
    (
        "dfu_only",
        ["dfu"],
        "checkpoints/ablation_dfu_only",
        "DFU images only (no negatives)",
    ),
    (
        "dfu_nondfu",
        ["dfu", "non_dfu"],
        "checkpoints/ablation_dfu_nondfu",
        "DFU + non-DFU wounds (current approach)",
    ),
    (
        "all",
        None,
        "checkpoints/ablation_all",
        "All classes including healthy",
    ),
)

# Variant name -> {"classes", "checkpoint_dir", "description"}.
ABLATION_CONFIGS = {
    name: {
        "classes": class_filter,
        "checkpoint_dir": ckpt_dir,
        "description": desc,
    }
    for name, class_filter, ckpt_dir, desc in _ABLATION_VARIANTS
}
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Run one variant of the data composition ablation.

    Parses command-line options and seeds torch's RNG. It then builds a UNet++
    (``efficientnet-b4`` encoder with ``imagenet`` weights, ``scse`` decoder
    attention, 1 output class). Next it builds train/val loaders restricted to
    the variant's classes. Finally it trains with DiceCE loss, AdamW
    (lr=1e-4, weight_decay=1e-2) and cosine annealing with a 5-epoch warmup.
    Checkpoints go to the variant's ``checkpoint_dir``.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal command-line behaviour.

    Raises:
        SystemExit: Via argparse, on unknown/missing arguments or invalid
            numeric options.
    """
    parser = argparse.ArgumentParser(description="Data Composition Ablation")
    parser.add_argument(
        "--variant",
        type=str,
        required=True,
        choices=list(ABLATION_CONFIGS.keys()),
    )
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args(argv)

    # Fail fast on nonsensical values instead of deep inside training.
    if args.epochs < 1:
        parser.error("--epochs must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if args.num_workers < 0:
        parser.error("--num-workers must be >= 0")

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger("ablation")

    config = ABLATION_CONFIGS[args.variant]
    logger.info("Ablation: %s — %s", args.variant, config["description"])

    # Seed BEFORE building the model and loaders. Previously the seed was set
    # after both. Any randomly initialised weights created by build_unetpp
    # (presumably the decoder/head, since only encoder weights are requested
    # from ImageNet) and any RNG draws inside build_dataloaders were therefore
    # unseeded. Each variant started from a different, unreproducible point,
    # which confounds a comparison that should differ only in training data.
    torch.manual_seed(args.seed)

    model = build_unetpp(
        encoder_name="efficientnet-b4",
        encoder_weights="imagenet",
        classes=1,
        decoder_attention_type="scse",
    )

    train_loader, val_loader = build_dataloaders(
        args.splits_dir,
        args.batch_size,
        args.num_workers,
        filter_classes=config["classes"],
    )
    # NOTE(review): if filter_classes also filters the validation split, the
    # val/loss values are computed on different sets per variant and are not
    # directly comparable across variants — confirm in build_dataloaders.
    logger.info(
        "Data: %d train, %d val batches",
        len(train_loader),
        len(val_loader),
    )

    loss_fn = DiceCELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=1e-2,
    )
    scheduler = CosineAnnealingWithWarmup(
        optimizer,
        warmup_epochs=5,
        max_epochs=args.epochs,
    )

    trainer_config = TrainConfig(
        epochs=args.epochs,
        precision="bf16-mixed",
        compile_model=False,
        gradient_clip=1.0,
        checkpoint_dir=config["checkpoint_dir"],
        monitor_metric="val/loss",
        monitor_mode="min",
        device=args.device,
        early_stopping_patience=15,
    )

    trainer = Trainer(model=model, config=trainer_config)
    trainer.fit(train_loader, val_loader, loss_fn, optimizer, scheduler)
    logger.info("Ablation %s complete.", args.variant)


if __name__ == "__main__":
    main()
|
|