# RobustViT / tokencut_generate_segmentation.py
# Commit 7754b29 ("init commit") by Hila.
import argparse
import torch
import os
import collections
import sys
from tqdm import tqdm
sys.path.append('./TokenCut/model')
sys.path.append('./TokenCut/unsupervised_saliency_detection')
import dino# model
import object_discovery as tokencut
import argparse
import utils
import bilateral_solver
import os
from shutil import copyfile
import PIL.Image as Image
import cv2
import numpy as np
from tqdm import tqdm
from torchvision import transforms
import metric
import matplotlib.pyplot as plt
import skimage
import torch
from tokencut_image_dataset import RobustnessDataset
# NOTE(review): not referenced anywhere else in this script; presumably a
# leftover target image width -- confirm before removing.
basewidth = 224
def mask_color_compose(org, mask, mask_color=(173, 216, 230)):
    """Tint the foreground pixels of an image with a solid colour.

    Args:
        org: Source image; anything ``np.copy`` can turn into a pixel array
            (the callers in this script pass an RGB ``PIL.Image``). Its last
            axis must have ``len(mask_color)`` channels.
        mask: Array whose values ``> 0.5`` mark foreground; its shape must
            match the leading (spatial) axes of ``org``.
        mask_color: Colour blended into the foreground. Defaults to
            ``(173, 216, 230)`` (CSS ``lightblue``). A tuple is used so the
            default cannot be shared-and-mutated across calls.

    Returns:
        A new ``PIL.Image``; ``org`` itself is not modified.
    """
    foreground = mask > 0.5
    blended = np.copy(org)
    # 30 % original pixel + 70 % overlay colour, cast back to 8-bit.
    blended[foreground] = (
        blended[foreground] * 0.3 + np.asarray(mask_color) * 0.7
    ).astype(np.uint8)
    return Image.fromarray(blended)
# Image transformation applied to all images: converts a PIL image to a float
# tensor (transforms.ToTensor) and then normalises each RGB channel with the
# mean / std below (the commonly used ImageNet statistics).
ToTensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
def get_tokencut_binary_map(img_pth, backbone, patch_size, tau, resize_size):
    """Run TokenCut on one image and return its coarse foreground partition.

    The image is loaded as RGB and resized to ``resize_size``. It is then
    passed through ``utils.resize_pil``, whose result is unpacked as the
    resized image, its width/height and the feature-grid width/height. Next it
    is normalised with ``ToTensor`` and fed to ``backbone`` on the GPU. The
    first element of the backbone output goes to ``tokencut.ncut``.

    Args:
        img_pth: Path of the image to segment.
        backbone: Feature extractor (``dino.ViTFeat`` in this script). It must
            live on the GPU, because the input tensor is moved there with
            ``.cuda()``.
        patch_size: ViT patch size in pixels.
        tau: Threshold handed to ``tokencut.ncut`` for building the graph.
        resize_size: Size passed to ``PIL.Image.resize``, i.e. ``(width, height)``.

    Returns:
        ``(bipartition, eigvec)`` as produced by ``tokencut.ncut``; the seed it
        also returns is discarded.
    """
    # Context manager so the underlying file handle is released
    # deterministically instead of whenever the Image object is collected.
    with Image.open(img_pth) as src:
        image = src.convert('RGB')
    image = image.resize(resize_size)
    image_resized, w, h, feat_w, feat_h = utils.resize_pil(image, patch_size)
    batch = ToTensor(image_resized).unsqueeze(0).cuda()
    feat = backbone(batch)[0]
    _seed, bipartition, eigvec = tokencut.ncut(
        feat, [feat_h, feat_w], [patch_size, patch_size], [h, w], tau)
    return bipartition, eigvec
# Command-line interface. NOTE: the __main__ block forces --vit-arch to
# 'base' and --patch-size to 16 after parsing, and does not read --dataset
# or --nb-vis.
parser = argparse.ArgumentParser(description='Generate Seg maps')
# Both paths are needed on every run, so argparse reports a missing one up
# front instead of the script failing only after the model has been loaded.
parser.add_argument('--img_path', metavar='path', required=True,
                    help='path to image')
parser.add_argument('--out_dir', type=str, required=True, help='output directory')
parser.add_argument('--vit-arch', type=str, default='base', choices=['base', 'small'], help='which architecture')
parser.add_argument('--vit-feat', type=str, default='k', choices=['k', 'q', 'v', 'kqv'], help='which features')
parser.add_argument('--patch-size', type=int, default=16, choices=[16, 8], help='patch size')
parser.add_argument('--tau', type=float, default=0.2, help='Tau for tresholding graph')
# Bilateral-solver sigmas, forwarded to bilateral_solver.bilateral_solver_output.
parser.add_argument('--sigma-spatial', type=float, default=16, help='sigma spatial in the bilateral solver')
parser.add_argument('--sigma-luma', type=float, default=16, help='sigma luma in the bilateral solver')
parser.add_argument('--sigma-chroma', type=float, default=8, help='sigma chroma in the bilateral solver')
parser.add_argument('--dataset', type=str, default=None, choices=['ECSSD', 'DUTS', 'DUT', None], help='which dataset?')
# 100 is listed explicitly: argparse does not validate defaults against
# `choices`, so without it the default value could not be passed on the CLI.
parser.add_argument('--nb-vis', type=int, default=100, choices=[1, 100, 200], help='nb of visualization')

# Record pairing an image name with a tag; not used elsewhere in this script.
ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))
if __name__ == '__main__':
    args = parser.parse_args()

    # Only the DINO ViT-B/16 checkpoint is wired up: the URL and feature
    # dimension are hard-coded for it, so --vit-arch and --patch-size are
    # forced to match instead of being taken from the command line.
    url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    feat_dim = 768
    args.patch_size = 16
    args.vit_arch = 'base'

    backbone = dino.ViTFeat(url, feat_dim, args.vit_arch, args.vit_feat, args.patch_size)
    msg = 'Load {} pre-trained feature...'.format(args.vit_arch)
    print(msg)
    backbone.eval()
    backbone.cuda()

    # Create the output directory up front so the saves below cannot fail on
    # a missing folder. An empty --out_dir means "current directory" and
    # needs nothing created.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    with torch.no_grad():
        img_pth = args.img_path
        # Load once; the same RGB image provides the size and the overlays.
        with Image.open(img_pth) as src:
            img = src.convert('RGB')

        # Coarse TokenCut partition (the eigenvector is not needed here).
        bipartition, _ = get_tokencut_binary_map(img_pth, backbone, args.patch_size, args.tau, img.size)
        # Refine with the bilateral solver; only the binarised output is kept.
        _, binary_solver = bilateral_solver.bilateral_solver_output(
            img_pth, bipartition,
            sigma_spatial=args.sigma_spatial,
            sigma_luma=args.sigma_luma,
            sigma_chroma=args.sigma_chroma,
            resize_size=img.size)

        mask1 = torch.from_numpy(bipartition).cuda()
        mask2 = torch.from_numpy(binary_solver).cuda()
        if metric.IoU(mask1, mask2) < 0.5:
            # NOTE(review): if binary_solver is a boolean / 0-1 mask (assumed,
            # not visible here), "* -1" yields values in {0, -1}. Then
            # mask_color_compose (threshold > 0.5) draws nothing and the saved
            # _tokencut_bfs.pt has no positive pixels, so the refined mask is
            # effectively discarded rather than inverted. Kept as-is; confirm
            # whether ~binary_solver / 1 - binary_solver was intended.
            binary_solver = binary_solver * -1

        # Output files are named after the input:
        #   <stem>_tokencut<ext>, <stem>_tokencut_bfs<ext>,
        #   <stem>_tokencut.pt,   <stem>_tokencut_bfs.pt
        # Splitting off the extension keeps the four paths distinct for any
        # input extension, so outputs never overwrite each other or a
        # same-named input in out_dir. Ordinary "name.JPEG" inputs get the same
        # names as the old str.replace('.JPEG', ...) scheme.
        img_name = os.path.basename(img_pth)
        stem, ext = os.path.splitext(img_name)
        out_lost = os.path.join(args.out_dir, '{}_tokencut{}'.format(stem, ext))
        out_bfs = os.path.join(args.out_dir, '{}_tokencut_bfs{}'.format(stem, ext))

        mask_color_compose(img, bipartition).save(out_lost)
        mask_color_compose(img, binary_solver).save(out_bfs)
        torch.save(bipartition, os.path.join(args.out_dir, stem + '_tokencut.pt'))
        torch.save(binary_solver, os.path.join(args.out_dir, stem + '_tokencut_bfs.pt'))