import codecs
import os

# must be set before torch / cv2 load their OpenMP runtimes, otherwise it has no effect
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from argparse import ArgumentParser, Namespace
from typing import Tuple

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import Image
from torchvision.transforms.functional import normalize, resize, to_tensor

from bilateral_solver import bilateral_solver_output
from utils import get_model

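# select the compute device and download the pretrained SelfMask checkpoint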
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/selfmask/shared_files/selfmask_nq20.pt",
    map_location=device
)

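# the CLI only exposes the config path; the remaining hyper-parameters come from the YAML file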
parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
)

args: Namespace = parser.parse_args()
with open(args.config, 'r') as f:
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")

# merge the YAML configuration into the CLI namespace
merged_args: dict = vars(args)
merged_args.update(base_args)
args = Namespace(**merged_args)

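# build the model (arch="maskformer") and load the downloaded weights for inference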
model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()

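# preprocessing constants: resize bounds and ImageNet normalisation statistics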
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)


@torch.no_grad()
def main(image: Image.Image):
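    """Predict a salient-object mask for a PIL image and return it overlaid on the input."""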
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
    dict_outputs = model(image[None].to(device))

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]

    if len(batch_pred_masks.shape) == 5:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer

        if batch_objectness is not None:
            # b x n_layers x n_queries x 1 -> b x n_queries x 1
            batch_objectness = batch_objectness[:, -1, ...]

    # upsample predictions back to the (resized) input resolution
    # note: upsampling by 4 and then cropping the padded region gives a cleaner mask
    H, W = image.shape[-2:]  # spatial size of the normalised input tensor
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # iterate over the batch dimension (a single image in this demo)
    for batch_index, pred_masks in enumerate(batch_pred_masks):
        # n_queries x 1 -> n_queries
        objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
        ranks = torch.argsort(objectness, descending=True)  # n_queries
        pred_mask: torch.Tensor = pred_masks[ranks[0]]  # keep the highest-objectness query, H x W
    # binarise the soft mask into an 8-bit image for the bilateral solver
    pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255

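    # refine the mask edges with a fast bilateral solver guided by the RGB image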
    pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
    pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

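    # colour-map the refined mask and alpha-blend it over the resized input image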
    attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
    super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
    return super_imposed_img
    # return pred_mask_bi


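# Gradio UI wired to `main`; uses the gr.inputs / gr.outputs API from older Gradio releases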
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="saliency map"),  # "image",
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod"
    ]],
    examples_per_page=20,
    description=codecs.open("description.html", 'r', "utf-8").read(),
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False
)

demo.launch(
    # share=True
)