| """Training entrypoint. |
| |
| Single GPU: python framework/train.py --dataset cvc_clinicdb --arch unet ... |
| Multi-GPU : torchrun --nproc_per_node=4 framework/train.py --dataset ... --arch ... |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
|
|
| |
# Put the parent of this file's directory at the front of the import path so
# the absolute `framework.*` imports below resolve when this file is executed
# directly as a script (python / torchrun) rather than as an installed package.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)),
)
|
|
| import torch |
| import cv2 |
|
|
| |
| |
| |
# Restrict OpenCV to a single internal thread. This is done at import time,
# before the framework modules below are imported.
# NOTE(review): presumably this avoids CPU oversubscription when several
# data-loading workers each call into OpenCV -- confirm the intent.
cv2.setNumThreads(1)
|
|
| from framework.config import Config |
| from framework.engine.distributed import setup_distributed, cleanup_distributed, set_seed, print_main |
| from framework.models.registry import build_model, required_img_size |
| from framework.engine.trainer import Trainer |
|
|
|
|
def main() -> None:
    """Run one training job end to end.

    Steps, in order:

    1. Build the run configuration with ``Config.from_args()``.
    2. If ``required_img_size(cfg.arch)`` returns a truthy size that differs
       from ``cfg.img_size``, log the mismatch and overwrite ``cfg.img_size``.
    3. Call ``setup_distributed()`` and seed with the local rank it returns.
    4. Build the ``"train"`` dataset once to read ``in_channels`` and
       ``num_classes`` from it, then build the model from those values.
    5. Construct a ``Trainer`` and call ``fit()``.

    ``cleanup_distributed()`` is called in a ``finally`` clause, so it runs
    whether the steps after ``setup_distributed()`` succeed or raise. Any
    exception still propagates to the caller after cleanup.
    """
    cfg = Config.from_args()

    # A falsy result means the architecture imposes no fixed input size.
    # NOTE(review): this message is emitted before setup_distributed();
    # confirm print_main already restricts output to the main rank here.
    req = required_img_size(cfg.arch)
    if req and cfg.img_size != req:
        print_main(f"[info] arch '{cfg.arch}' requires img_size={req}; overriding {cfg.img_size}.")
        cfg.img_size = req

    local_rank = setup_distributed()
    try:
        set_seed(cfg.seed, rank=local_rank)

        # Kept as a function-local import, exactly as in the original.
        # NOTE(review): presumably deferred on purpose (e.g. until after
        # distributed setup) -- confirm before hoisting to module level.
        from framework.data.loaders import build_dataset

        # The dataset is built here only to discover the channel and class
        # counts that the model constructor needs.
        probe = build_dataset(cfg, "train")
        in_ch, n_cls = probe.in_channels, probe.num_classes
        print_main(
            f"[data] {cfg.dataset}/{cfg.protocol}: in_channels={in_ch} num_classes={n_cls} "
            f"train={len(probe)}"
        )

        model = build_model(
            cfg.arch,
            in_channels=in_ch,
            num_classes=n_cls,
            img_size=cfg.img_size,
            encoder=cfg.encoder,
            encoder_weights=cfg.encoder_weights,
            pretrained_ckpt=cfg.pretrained_ckpt,
        )
        # Parameter count is reported in millions with one decimal place.
        print_main(
            f"[model] {cfg.arch} params={sum(p.numel() for p in model.parameters())/1e6:.1f}M "
            f"amp={cfg.amp}"
        )

        trainer = Trainer(cfg, model, local_rank)
        trainer.fit()
    finally:
        # Tear down distributed state on failure as well as on success; the
        # original only reached this call when training finished normally.
        cleanup_distributed()
|
|
|
|
# Start training only when this file is executed as a script (for example via
# `python` or `torchrun`); importing the module does not run anything.
if __name__ == "__main__":
    main()
|
|