# selfmask / app.py (SelfMask Gradio demo)
# Author: noelshin · commit 35188e4 "Add application file" · 4.8 kB
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr
from utils import get_model
from bilateral_solver import bilateral_solver_output
import os
# Let duplicate OpenMP runtimes coexist instead of aborting the process
# (presumably to work around the Intel OpenMP "libiomp already initialized"
# error when torch and another native library each bundle one).
# NOTE(review): the OpenMP runtime reads this variable when it initialises; it
# is set here after torch/cv2/numpy have been imported, so it may be too late
# to take effect on some platforms — consider moving it above those imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained SelfMask weights (nq20 checkpoint). torch.hub caches the download
# locally; tensors are mapped straight onto `device`.
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
    map_location=device,
)["model"]

parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
    help="YAML file with the hyper-parameters passed to get_model().",
)
cli_args: Namespace = parser.parse_args()

# Load the YAML config; an empty file yields None from safe_load, so treat it
# as an empty mapping. The `with` block guarantees the file handle is closed.
with open(cli_args.config, "r", encoding="utf-8") as config_file:
    base_args: dict = yaml.safe_load(config_file) or {}
# `dataset_name` is removed before building the model configuration
# (presumably a training/evaluation-only setting); tolerate configs without it.
base_args.pop("dataset_name", None)

# Merge CLI arguments with the YAML config; YAML values win on key clashes.
args: Namespace = Namespace(**{**vars(cli_args), **base_args})

# Build the model, load the pretrained weights and switch to inference mode.
model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()
@torch.no_grad()
def main(
    image: Image.Image,
    size: int = 384,
    max_size: int = 512,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
) -> np.ndarray:
    """Predict the most salient object in ``image`` and overlay it on the image.

    The image is resized, normalised and fed to the module-level ``model``.
    Among the predicted query masks, the one with the highest objectness score
    is binarised at 0.5, refined with the bilateral solver, colour-mapped with
    viridis and blended 50/50 with the resized image.

    Args:
        image: Input picture. Non-RGB images (e.g. RGBA, greyscale) are
            converted to RGB first.
        size: ``size`` argument of torchvision ``resize``.
        max_size: ``max_size`` argument of torchvision ``resize``.
        mean: Per-channel mean used for normalisation.
        std: Per-channel standard deviation used for normalisation.

    Returns:
        A uint8 RGB array with the dimensions of the *resized* image, showing
        the colour-mapped saliency mask blended over the image.

    Raises:
        ValueError: If no image is given, or if the model output carries no
            ``"objectness"`` scores to rank the query masks by.
    """
    if image is None:
        raise ValueError("No input image was provided.")
    if image.mode != "RGB":
        # Normalisation uses three per-channel statistics and the overlay is
        # blended with a 3-channel colour map, so exactly 3 channels are needed.
        image = image.convert("RGB")

    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    tensor: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
    dict_outputs = model(tensor[None].to(device))  # add a batch dimension of 1

    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness = dict_outputs.get("objectness", None)  # [0, 1]
    if batch_objectness is None:
        raise ValueError("The model output has no 'objectness' scores to rank the masks by.")

    if len(batch_pred_masks.shape) == 5:
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]  # keep the last decoder layer
        # b x n_layers x n_queries x 1 -> b x n_queries x 1
        batch_objectness = batch_objectness[:, -1, ...]

    # Resize the prediction to the input resolution. Upsampling by 4 and then
    # cropping (which cuts any padded region) gives a better result.
    H, W = tensor.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # The batch holds exactly one image (see `tensor[None]` above).
    pred_masks = batch_pred_masks[0]  # n_queries x H x W
    objectness = batch_objectness[0].squeeze(dim=-1)  # n_queries x 1 -> n_queries
    # Keep the query mask with the highest objectness score.
    pred_mask = pred_masks[torch.argmax(objectness)]  # H x W
    binary_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255

    # Refine the binary mask with the project's bilateral solver, which is
    # given the image as `img` (presumably as an edge-aware guide).
    refined_mask, _ = bilateral_solver_output(img=pil_image, target=binary_mask)  # float64
    refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

    # OpenCV colour maps are BGR; convert to RGB to match `pil_image`.
    heatmap = cv2.cvtColor(cv2.applyColorMap(refined_mask, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(heatmap, 0.5, np.array(pil_image), 0.5, 0)
# Gradio front-end: one PIL image in, one overlay image out, with a gallery of
# bundled example pictures stored as resources/<name>.jpg.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil"),
    outputs="image",
    examples=[
        f"resources/{example_name}.jpg"
        for example_name in (
            "0053", "0236", "0239", "0403", "0412",
            "ILSVRC2012_test_00005309",
            "ILSVRC2012_test_00012622",
            "ILSVRC2012_test_00022698",
            "ILSVRC2012_test_00040725",
            "ILSVRC2012_test_00075738",
            "ILSVRC2012_test_00080683",
            "ILSVRC2012_test_00085874",
            "im052",
            "sun_ainjbonxmervsvpv",
            "sun_alfntqzssslakmss",
            "sun_amnrcxhisjfrliwa",
            "sun_bvyxpvkouzlfwwod",
        )
    ],
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)

# Start the web server (pass share=True for a temporary public link).
demo.launch()