File size: 3,656 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Per-architecture efficiency table: #params, FLOPs (GMac), inference throughput.

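Representative setting (in_channels=3, num_classes=2). FLOPs are counted with thop,
falling back to fvcore (whichever is installed; reported as null if neither works);
params and throughput are always measured. Intended to be run once on a GPU (A100);
falls back to CPU when CUDA is unavailable.
Output: <out_root>/<exp_name>/efficiency.{json,md} (default: results/baselines/).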

  python framework/efficiency.py --img_size 256 --out_root results --exp_name baselines
"""
from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch

from framework.models.registry import build_model, required_img_size

# Architectures benchmarked by main(), in the row order of the output table.
ARCHS = [
    "unet",
    "unetpp",
    "deeplabv3plus",
    "attention_unet",
    "transunet",
    "swinunet",
]


def count_flops_gmac(model, x):
    """Return the multiply-accumulate count of ``model(x)`` in GMac (units of 1e9).

    Backends are tried in order, thop first and then fvcore. The first one that
    succeeds wins. A backend that is not installed, or that raises while
    profiling, is skipped. If none succeeds, a warning listing every failure is
    printed and NaN is returned, so the caller can still report params and
    throughput for this model.

    Args:
        model: the ``torch.nn.Module`` to profile; it is called as ``model(x)``.
        x: the example input. Callers in this script pass a batch of one.

    Returns:
        float: GMac, or ``float("nan")`` when no backend could count.
    """
    errors = []
    # Counting only needs a forward pass; without no_grad the profiled forward
    # would also record an autograd graph, wasting memory on large models.
    with torch.no_grad():
        try:
            from thop import profile
            macs, _ = profile(model, inputs=(x,), verbose=False)
            return float(macs) / 1e9
        except Exception as e:  # missing package or profiling failure -> next backend
            errors.append(f"thop: {e!r}")
        try:
            from fvcore.nn import FlopCountAnalysis
            return float(FlopCountAnalysis(model, x).total()) / 1e9
        except Exception as e:
            errors.append(f"fvcore: {e!r}")
    print(f"[warn] FLOP count unavailable ({'; '.join(errors)})")
    return float("nan")


@torch.no_grad()
def throughput(model, x, iters=50, warmup=10):
    """Measure inference throughput of ``model`` on the batch ``x`` in images/sec.

    ``warmup`` untimed forward passes run first. Then ``iters`` forward passes
    are timed. For CUDA inputs, the input's device is synchronised before the
    clock starts and before it stops, so queued GPU work is included in the
    measured time.

    Args:
        model: callable applied as ``model(x)``; its output is discarded.
        x: input batch. ``x.size(0)`` is taken as the number of images per pass.
        iters: number of timed forward passes.
        warmup: number of untimed forward passes run beforehand.

    Returns:
        float: ``iters * x.size(0) / elapsed_seconds``.
    """
    for _ in range(warmup):
        model(x)
    if x.is_cuda:
        # Synchronise the device x lives on, not merely the current device.
        torch.cuda.synchronize(x.device)
    # perf_counter is monotonic and high-resolution. time.time() can jump with
    # wall-clock adjustments and is coarse on some platforms.
    t0 = time.perf_counter()
    for _ in range(iters):
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize(x.device)
    dt = time.perf_counter() - t0
    return iters * x.size(0) / dt          # images / sec


def encoder_for(arch):
    """Return the ``encoder`` name that main() passes to build_model for ``arch``.

    ``unet``, ``unetpp`` and ``deeplabv3plus`` map to ``resnet50``;
    ``transunet`` maps to ``R50-ViT-B_16``; every other name maps to ``resnet34``.
    """
    resnet50_archs = ("unet", "unetpp", "deeplabv3plus")
    if arch == "transunet":
        return "R50-ViT-B_16"
    return "resnet50" if arch in resnet50_archs else "resnet34"


def _round_or_none(value, ndigits):
    """Round ``value`` for the report, mapping NaN (a failed measurement) to None.

    None is written by json.dump as ``null`` rather than the non-standard ``NaN``.
    """
    return None if math.isnan(value) else round(value, ndigits)


def _markdown_table(rows):
    """Render the result rows as a Markdown table, one line per architecture."""
    lines = ["| Method | Img | Params(M) | GMac | Img/s |", "|---|---|---|---|---|"]
    lines += [
        f"| {r['arch']} | {r['img']} | {r['params_M']} | {r['gmac']} | {r['imgs_per_s']} |"
        for r in rows
    ]
    return "\n".join(lines) + "\n"


def main():
    """Benchmark every architecture in ARCHS and write the efficiency table.

    For each architecture the script does the following:
      * builds the model (in_channels=3, num_classes=2, encoder_weights="none");
      * counts its parameters;
      * counts GMac on a single random image;
      * measures throughput on a random ``--batch_size`` batch.

    Results are written to ``<out_root>/<exp_name>/efficiency.json`` and
    ``efficiency.md``, and the Markdown table is also printed. Failed FLOP or
    throughput measurements are reported as null/None instead of aborting.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img_size", type=int, default=256)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--exp_name", default="baselines")
    args = ap.parse_args()

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rows = []
    for arch in ARCHS:
        # Use the registry's size when required_img_size(arch) is truthy,
        # otherwise --img_size.
        sz = required_img_size(arch) or args.img_size
        model = build_model(arch, in_channels=3, num_classes=2, img_size=sz,
                            encoder=encoder_for(arch), encoder_weights="none").to(dev).eval()
        params_m = sum(p.numel() for p in model.parameters()) / 1e6
        x1 = torch.randn(1, 3, sz, sz, device=dev)
        gmac = count_flops_gmac(model, x1)
        xb = torch.randn(args.batch_size, 3, sz, sz, device=dev)
        try:
            ips = throughput(model, xb)
        except Exception as e:
            # Best-effort: one failing architecture (for instance running out of
            # memory) should not abort the whole table.
            ips = float("nan")
            print(f"[warn] throughput {arch}: {e}")
        rows.append({"arch": arch, "img": sz, "params_M": round(params_m, 2),
                     "gmac": _round_or_none(gmac, 2),
                     "imgs_per_s": _round_or_none(ips, 1)})
        print(f"{arch:16s} img={sz} params={params_m:.2f}M GMac={gmac:.2f} {ips:.1f} img/s")
        # Drop this model and its inputs before building the next architecture.
        del model, x1, xb
        if dev.type == "cuda":
            torch.cuda.empty_cache()

    base = os.path.join(args.out_root, args.exp_name)
    os.makedirs(base, exist_ok=True)
    with open(os.path.join(base, "efficiency.json"), "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)
    md = _markdown_table(rows)
    # Context manager guarantees the report is flushed and closed.
    with open(os.path.join(base, "efficiency.md"), "w", encoding="utf-8") as f:
        f.write(md)
    print(md)
    print(f"written {base}/efficiency.{{json,md}}")


if __name__ == "__main__":
    main()