File size: 5,150 Bytes
7754b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse

import torch
import os
import collections

import sys
from tqdm import tqdm

sys.path.append('./TokenCut/model')
sys.path.append('./TokenCut/unsupervised_saliency_detection')
import dino# model

import object_discovery as tokencut
import argparse
import utils
import bilateral_solver
import os

from shutil import copyfile
import PIL.Image as Image
import cv2
import numpy as np
from tqdm import tqdm

from torchvision import transforms
import metric
import matplotlib.pyplot as plt
import skimage
import torch

from tokencut_image_dataset import RobustnessDataset

basewidth = 224


def mask_color_compose(org, mask, mask_color=(173, 216, 230)):
    """Blend a flat highlight colour over the masked pixels of an image.

    Args:
        org: Source image; anything ``np.copy`` can turn into an array
            (the script entry point passes an RGB PIL image).
        mask: Array used to index the image array; entries greater than
            0.5 are treated as foreground. Presumably (H, W) -- confirm
            against the callers.
        mask_color: Colour triple painted over the foreground. The default
            is an immutable tuple so it can never be mutated and shared
            between calls; a list is still accepted.

    Returns:
        A new PIL image in which every foreground pixel is 30% of the
        original colour plus 70% of ``mask_color``. ``org`` is not
        modified because the blend is applied to a copy.
    """
    mask_fg = mask > 0.5
    rgb = np.copy(org)
    # The blend is computed in floating point and cast back to uint8 so the
    # result can be handed to PIL.
    rgb[mask_fg] = (rgb[mask_fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)

    return Image.fromarray(rgb)

# Preprocessing shared by every image handed to the backbone: convert the
# PIL image to a tensor, then normalise each channel with a fixed mean/std.
ToTensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

def get_tokencut_binary_map(img_pth, backbone, patch_size, tau, resize_size):
    """Run TokenCut on a single image file.

    Args:
        img_pth: Path of the image to open; it is converted to RGB.
        backbone: Callable feature extractor. It is invoked on a one-image
            CUDA batch and the first element of its output is used.
        patch_size: Forwarded to ``utils.resize_pil`` and, as
            ``[patch_size, patch_size]``, to ``tokencut.ncut``.
        tau: Forwarded unchanged to ``tokencut.ncut``.
        resize_size: Size passed to ``Image.resize`` before
            ``utils.resize_pil`` is applied.

    Returns:
        The ``(bipartition, eigvec)`` pair produced by ``tokencut.ncut``;
        the seed it also returns is discarded.
    """
    image = Image.open(img_pth).convert('RGB').resize(resize_size)
    resized, width, height, feat_w, feat_h = utils.resize_pil(image, patch_size)

    # Single-image batch on the GPU; only the first output entry is needed.
    batch = ToTensor(resized).unsqueeze(0).cuda()
    feat = backbone(batch)[0]

    _, bipartition, eigvec = tokencut.ncut(
        feat,
        [feat_h, feat_w],
        [patch_size, patch_size],
        [height, width],
        tau,
    )
    return bipartition, eigvec

# Command-line interface of the script.
parser = argparse.ArgumentParser(description='Generate Seg maps')

# The input image and the output directory are both consumed unconditionally
# by the script, so they are required: a missing value is reported by argparse
# instead of surfacing later as an obscure error on a None path.
parser.add_argument('--img_path', metavar='path', required=True,
                    help='path to image')

parser.add_argument('--out_dir', type=str, required=True, help='output directory')

parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which architecture')

parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')

parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')

parser.add_argument('--tau', type=float, default=0.2, help='Tau for thresholding graph')

# Parameters forwarded to the bilateral solver.
parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')

parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')

parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')


parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')

# No `choices` restriction here: the previous choices=[1, 200] did not contain
# the default (100), so passing the default value explicitly was rejected.
parser.add_argument('--nb-vis', type=int, default=100, help='nb of visualization')




# Lightweight record with two fields: an image name and a tag.
ImageItem = collections.namedtuple('ImageItem', ['image_name', 'tag'])


def derive_output_paths(out_dir, img_pth):
    """Build the output file paths for one input image.

    The suffix is inserted in front of the file extension, whatever that
    extension is. Replacing the literal '.JPEG' instead left the name
    untouched for any other extension, so both overlays and both tensor
    dumps were written to one and the same path.

    Args:
        out_dir: Directory the outputs are placed in.
        img_pth: Path of the input image; only its base name is used.

    Returns:
        Tuple ``(overlay, overlay_bfs, mask_pt, mask_bfs_pt)``: the two
        overlay image paths (same extension as the input) followed by the
        two ``.pt`` paths for the raw masks.
    """
    stem, ext = os.path.splitext(os.path.basename(img_pth))
    return (
        os.path.join(out_dir, stem + '_tokencut' + ext),
        os.path.join(out_dir, stem + '_tokencut_bfs' + ext),
        os.path.join(out_dir, stem + '_tokencut.pt'),
        os.path.join(out_dir, stem + '_tokencut_bfs.pt'),
    )


def main():
    """Segment one image with TokenCut and write overlays and raw masks.

    Reads the command-line arguments, loads the DINO backbone on the GPU,
    computes the TokenCut bipartition and its bilateral-solver refinement,
    then saves a colour overlay and a ``torch.save`` dump for each of them
    into the output directory.
    """
    args = parser.parse_args()

    # Fail before the model is loaded if a path is missing.
    if not args.img_path or not args.out_dir:
        parser.error('--img_path and --out_dir are required')

    # NOTE: only the ViT-base/16 checkpoint is wired up below, so the
    # --vit-arch and --patch-size command-line values are overridden.
    url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    feat_dim = 768
    args.patch_size = 16
    args.vit_arch = 'base'

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)
    msg = 'Load {} pre-trained feature...'.format(args.vit_arch)
    print(msg)
    backbone.eval()
    backbone.cuda()

    img_pth = args.img_path
    os.makedirs(args.out_dir, exist_ok=True)
    out_lost, out_bfs, pt_lost, pt_bfs = derive_output_paths(args.out_dir, img_pth)

    with torch.no_grad():
        # Decoded once: used for its size and as the overlay background
        # (mask_color_compose works on a copy, so it is safe to reuse).
        org = Image.open(img_pth).convert('RGB')

        bipartition, _ = get_tokencut_binary_map(img_pth, backbone, args.patch_size, args.tau, org.size)
        _, binary_solver = bilateral_solver.bilateral_solver_output(img_pth, bipartition,
                                                                    sigma_spatial=args.sigma_spatial,
                                                                    sigma_luma=args.sigma_luma,
                                                                    sigma_chroma=args.sigma_chroma,
                                                                    resize_size=org.size)
        mask1 = torch.from_numpy(bipartition).cuda()
        mask2 = torch.from_numpy(binary_solver).cuda()
        # NOTE(review): when the refined mask disagrees with the bipartition
        # it is multiplied by -1; this looks intended as an inversion, but no
        # value can then exceed the 0.5 threshold used by mask_color_compose
        # -- confirm against bilateral_solver's output type.
        if metric.IoU(mask1, mask2) < 0.5:
            binary_solver = binary_solver * -1

        # Colour overlays of the raw and the refined mask.
        mask_color_compose(org, bipartition).save(out_lost)
        mask_color_compose(org, binary_solver).save(out_bfs)

        # Raw masks for later reuse.
        torch.save(bipartition, pt_lost)
        torch.save(binary_solver, pt_bfs)


if __name__ == '__main__':
    main()