File size: 1,396 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Evaluation entry point (single process).

Usage:
  python framework/test.py --dataset cvc_clinicdb --arch unet --exp_name myrun --seed 0
"""
from __future__ import annotations

import os
import sys

# Prepend the repository root (the parent of this script's directory) to sys.path so
# the `framework.*` imports below resolve when run as `python framework/test.py`.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import cv2

# Single-threaded OpenCV per process (parallelism via num_workers); avoids the
# nproc-sized cv2 thread-pool oversubscription that starves the GPU at high res.
# Runs at import time, before the framework modules below are imported.
cv2.setNumThreads(1)

from framework.config import Config
from framework.models.registry import build_model, required_img_size
from framework.engine.evaluator import evaluate
from framework.data.loaders import build_dataset


def main():
    """Evaluate one architecture on a dataset's test split in a single process.

    Steps:
      * read the run configuration from the command line (``Config.from_args``);
      * if ``required_img_size`` reports a size for ``cfg.arch`` that differs
        from ``cfg.img_size``, overwrite ``cfg.img_size`` with it;
      * use CUDA device 0 when CUDA is available, otherwise the CPU;
      * build the ``"test"`` split to read its ``in_channels`` / ``num_classes``
        and construct the model with ``encoder_weights="none"`` and an empty
        ``pretrained_ckpt`` (the weights are expected to come from a checkpoint);
      * call ``evaluate`` with ``cfg.resume`` as the checkpoint path.

    NOTE(review): how ``evaluate`` behaves when ``cfg.resume`` is empty is not
    visible from this file. Confirm it locates a checkpoint rather than scoring
    an untrained model.
    """
    cfg = Config.from_args()

    # Some architectures only accept one input resolution; honour it over the CLI.
    forced_size = required_img_size(cfg.arch)
    if forced_size and cfg.img_size != forced_size:
        cfg.img_size = forced_size

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(0)
    device = torch.device("cuda" if has_cuda else "cpu")

    # The split is only consulted here for its channel/class counts.
    test_split = build_dataset(cfg, "test")
    model = build_model(
        cfg.arch,
        in_channels=test_split.in_channels,
        num_classes=test_split.num_classes,
        img_size=cfg.img_size,
        encoder=cfg.encoder,
        encoder_weights="none",  # weights come from checkpoint
        pretrained_ckpt="",
    )
    evaluate(cfg, model, device, ckpt_path=cfg.resume)


# Run evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()