# selfmask / app.py
# noelshin — "Update app.py" (commit 200320e)
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import codecs
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_model
from bilateral_solver import bilateral_solver_output
import os
# NOTE(review): presumably set to tolerate duplicate OpenMP runtimes being
# loaded in one process (an Intel OpenMP setting) — confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK="True")

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained SelfMask checkpoint; torch.hub downloads it (reusing its local
# cache when present) and maps every tensor onto `device`.
_CHECKPOINT_URL = "https://www.robots.ox.ac.uk/~vgg/research/selfmask/shared_files/selfmask_nq20.pt"
state_dict: dict = torch.hub.load_state_dict_from_url(_CHECKPOINT_URL, map_location=device)
# Command-line options. Only --config is defined here; the YAML file it names
# is merged into the same namespace (minus "dataset_name") and passed to
# get_model() as its configs.
parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
)
args: Namespace = parser.parse_args()

# Read the config with a context manager so the handle is closed; an empty
# YAML file loads as None, so fall back to an empty mapping.
with open(args.config, "r", encoding="utf-8") as config_file:
    base_args: dict = yaml.safe_load(config_file) or {}

# "dataset_name" is not forwarded to the model; tolerate configs without it.
base_args.pop("dataset_name", None)

# YAML values override same-named command-line values.
args = Namespace(**{**vars(args), **base_args})

# Build the model on `device`, load the downloaded weights, switch to eval mode.
model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()
# Per-channel RGB normalization statistics applied to the input tensor in
# main() (the widely used ImageNet mean/std values).
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

# Bounds handed to torchvision's resize() in main(): `size` targets the
# shorter image side, `max_size` caps the longer side.
size: int = 384
max_size: int = 512
@torch.no_grad()
def main(image: Image.Image):
    """Predict the salient object in ``image`` and overlay it on the input.

    The image is converted to RGB and resized with torchvision's
    ``resize(size=size, max_size=max_size)``. It is normalized with
    ``mean``/``std`` and passed through the SelfMask model. The query mask
    with the highest objectness score is upsampled, binarized at 0.5, and
    refined with the bilateral solver. It is then colour-mapped with VIRIDIS
    and blended 50/50 with the resized input.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            the user submits without uploading an image.

    Returns:
        An ``H x W x 3`` uint8 RGB overlay at the resized resolution, or
        ``None`` when no image was given.

    Raises:
        ValueError: if the model output has no ``"objectness"`` entry, which
            is required to choose a query mask.
    """
    if image is None:
        return None

    # The normalization (3-element mean/std) and the RGB overlay below need
    # exactly three channels; convert grayscale / RGBA / palette uploads.
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)
    image_tensor: torch.Tensor = normalize(
        to_tensor(pil_image), mean=list(mean), std=list(std)
    )  # 3 x H x W
    dict_outputs = model(image_tensor[None].to(device))

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness = dict_outputs.get("objectness", None)  # [0, 1]
    if batch_objectness is None:
        raise ValueError("model output has no 'objectness' scores; cannot select a mask")

    if batch_pred_masks.dim() == 5:
        # Keep only the last decoder layer's output:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]
        # b x n_layers x n_queries x 1 -> b x n_queries x 1
        batch_objectness = batch_objectness[:, -1, ...]

    # Resize the prediction to the input resolution. Per the original note,
    # upsampling by 4 and cutting the padded region gives a better result.
    H, W = image_tensor.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # The batch holds exactly one image, so index it directly.
    pred_masks = batch_pred_masks[0]  # n_queries x H x W
    objectness = batch_objectness[0].squeeze(dim=-1)  # n_queries x 1 -> n_queries
    best_query = torch.argmax(objectness)

    pred_mask: np.ndarray = (pred_masks[best_query] > 0.5).cpu().numpy().astype(np.uint8) * 255
    pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
    pred_mask_bi = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

    # applyColorMap returns BGR; convert so it matches the RGB PIL image.
    attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
    # To show the refined mask alone instead of the overlay, return pred_mask_bi.
    return cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
# HTML description shown in the Gradio interface; read with a context manager
# so the file handle is closed once the text is loaded.
with open("description.html", "r", encoding="utf-8") as description_file:
    description_html: str = description_file.read()

# Gradio UI: upload a PIL image, run main(), display the returned numpy overlay.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="saliency map"),
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod",
    ]],
    examples_per_page=20,
    description=description_html,
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)

# Pass share=True to launch() to expose a public link.
demo.launch()