File size: 1,890 Bytes
959a00c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import albumentations as A
import cv2
import torch
import tqdm
from albumentations.pytorch.functional import img_to_tensor
from pathlib import Path
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.utils import draw_segmentation_masks, make_grid, save_image

import utils.misc as misc
from models import get_ensemble_model
from opt import get_opt


def demo(folder_path, output_path=Path("tmp")):
    """Run the ensemble model on every ``*.jpg`` in a folder and save previews.

    For each readable image a three-panel grid is written to ``output_path``
    under the image's own file name: the RGB input, the thresholded
    ``out_map`` mask, and the mask drawn over the input.

    Args:
        folder_path: Directory searched (non-recursively) for ``*.jpg`` files.
            May be a ``str`` or a ``Path``.
        output_path: Directory that receives the grid images. It is created
            (including parents) if it does not exist yet.
    """
    folder_path = Path(folder_path)
    output_path = Path(output_path)
    # Create the destination here so the function also works when the caller
    # has not prepared the directory beforehand.
    output_path.mkdir(exist_ok=True, parents=True)

    opt = get_opt()
    model = get_ensemble_model(opt).to(opt.device)
    misc.resume_from(model, opt.resume)
    # Inference only: put dropout / batch-norm style layers in eval mode.
    # NOTE(review): assumes the returned model is a torch.nn.Module - confirm.
    model.eval()

    # Materialise and sort the listing so tqdm can show a total and the
    # processing order is reproducible between runs.
    image_paths = sorted(folder_path.glob("*.jpg"))

    with torch.no_grad():
        for image_path in tqdm.tqdm(image_paths):
            bgr_image = cv2.imread(image_path.as_posix())
            if bgr_image is None:
                # cv2.imread signals an unreadable/corrupt file with None;
                # skip it instead of crashing the whole batch in cvtColor.
                tqdm.tqdm.write(f"Skipping unreadable image: {image_path}")
                continue

            image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
            # Channel-first copy of the un-normalised image, used as the
            # canvas for draw_segmentation_masks.
            dsm_image = torch.from_numpy(image).permute(2, 0, 1)
            image_size = image.shape[:2]
            raw_image = img_to_tensor(image)
            model_input = img_to_tensor(
                image,
                normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
            )
            model_input = model_input.to(opt.device).unsqueeze(0)
            outputs = model(model_input, seg_size=image_size)
            out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()

            # Threshold once and reuse the boolean mask for both panels.
            mask = out_map > opt.mask_threshold

            overlay = draw_segmentation_masks(dsm_image, masks=mask[0, ...])
            grid_image = make_grid(
                [
                    raw_image,
                    # Kept in [0, 1] like the other panels; save_image does
                    # the scaling to 0-255 itself.
                    mask.repeat(3, 1, 1).float(),
                    overlay / 255.0,
                ],
                padding=5,
            )
            save_image(grid_image, (output_path / image_path.name).as_posix())


if __name__ == "__main__":
    # Script entry point: read images from ./demo and write grids to ./tmp.
    folder_path = Path("demo")
    output_path = Path("tmp")
    output_path.mkdir(exist_ok=True, parents=True)
    # Pass the directory explicitly so the folder created above is the one
    # demo() writes to, rather than relying on its default matching.
    demo(folder_path, output_path)