MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
2.2 kB
"""Training entrypoint.
Single GPU: python framework/train.py --dataset cvc_clinicdb --arch unet ...
Multi-GPU : torchrun --nproc_per_node=4 framework/train.py --dataset ... --arch ...
"""
from __future__ import annotations
import os
import sys
# Allow `python framework/train.py` to work without installing the package:
# put the repo root (parent of this file's directory) first on sys.path.
# This MUST run before the `framework.*` imports below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
import cv2
# Keep OpenCV single-threaded; data-loading parallelism comes from DataLoader num_workers.
# Without this, cv2 spawns an nproc-sized thread pool (~384 threads on the original
# host) per worker, whose per-op dispatch overhead starves the GPU at high
# resolution (768) -> ~4x slower epochs.
# This runs at import time of this script, so it applies to the launching process;
# NOTE(review): workers presumably pick it up via fork inheritance or by re-importing
# this module under spawn — confirm for the multiprocessing start method in use.
cv2.setNumThreads(1)
from framework.config import Config
from framework.engine.distributed import setup_distributed, cleanup_distributed, set_seed, print_main
from framework.models.registry import build_model, required_img_size
from framework.engine.trainer import Trainer
def main():
    """Run one training job end to end.

    Parses the CLI config, initialises (optional) distributed state, probes
    the training split for channel/class counts, builds the model, and hands
    everything to ``Trainer.fit()``. Distributed teardown via
    ``cleanup_distributed()`` always runs once setup has succeeded, even if
    training raises.
    """
    cfg = Config.from_args()

    local_rank = setup_distributed()
    try:
        set_seed(cfg.seed, rank=local_rank)

        # Some backbones require a fixed input size. Done after distributed
        # setup so this print_main call runs in the same state as the other
        # log lines below. NOTE(review): presumably print_main consults the
        # process rank — confirm in framework.engine.distributed.
        req = required_img_size(cfg.arch)
        if req and cfg.img_size != req:
            print_main(f"[info] arch '{cfg.arch}' requires img_size={req}; overriding {cfg.img_size}.")
            cfg.img_size = req

        # Peek at the training split to get in/out channels before building the model.
        # (Local import kept as in the original entrypoint.)
        from framework.data.loaders import build_dataset
        probe = build_dataset(cfg, "train")
        in_ch, n_cls = probe.in_channels, probe.num_classes
        print_main(f"[data] {cfg.dataset}/{cfg.protocol}: in_channels={in_ch} num_classes={n_cls} "
                   f"train={len(probe)}")
        # The probe is only needed for its metadata; drop this reference so it
        # is not kept alive for the entire training run.
        del probe

        model = build_model(cfg.arch, in_channels=in_ch, num_classes=n_cls,
                            img_size=cfg.img_size, encoder=cfg.encoder,
                            encoder_weights=cfg.encoder_weights,
                            pretrained_ckpt=cfg.pretrained_ckpt)
        print_main(f"[model] {cfg.arch} params={sum(p.numel() for p in model.parameters())/1e6:.1f}M "
                   f"amp={cfg.amp}")

        trainer = Trainer(cfg, model, local_rank)
        trainer.fit()
    finally:
        # Release distributed resources even when anything above raises.
        cleanup_distributed()


if __name__ == "__main__":
    main()