"""Gradio demo for SelfMask: unsupervised salient object detection.

Loads the released SelfMask checkpoint (20 queries), builds the MaskFormer-style
model from a YAML config, and serves an interactive demo that overlays the
predicted salient-object mask on the input image.
"""
from argparse import ArgumentParser, Namespace
from typing import Tuple
import os

import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr

from utils import get_model
from bilateral_solver import bilateral_solver_output

# Work around "duplicate OpenMP runtime" aborts when torch and opencv each
# bundle their own libomp (common on macOS/Windows).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained SelfMask weights (n_queries=20) released with the paper.
state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
    map_location=device,
)["model"]

parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
)
cli_args: Namespace = parser.parse_args()

# Merge the YAML config into the CLI namespace. NOTE: YAML values override
# CLI ones (this mirrors the original update order).
with open(cli_args.config, "r") as f:  # `with` so the handle is closed (was leaked before)
    base_args: dict = yaml.safe_load(f)
base_args.pop("dataset_name")  # dataset choice is irrelevant for the demo
merged: dict = vars(cli_args)
merged.update(base_args)
args: Namespace = Namespace(**merged)

model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()


@torch.no_grad()
def main(
    image: Image.Image,
    size: int = 384,
    max_size: int = 512,
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
) -> np.ndarray:
    """Predict a salient-object mask for ``image`` and blend it over the input.

    Args:
        image: input PIL image (any size; resized internally).
        size: shorter-side target for resizing before inference.
        max_size: cap on the longer side after resizing.
        mean: per-channel normalization mean (ImageNet statistics).
        std: per-channel normalization std (ImageNet statistics).

    Returns:
        RGB uint8 array: a viridis-colored mask alpha-blended (50/50) with the
        resized input image.
    """
    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
    # 3 x H x W, ImageNet-normalized (renamed so the PIL input isn't shadowed)
    img_t: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))

    dict_outputs = model(img_t[None].to(device))
    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # values in [0, 1]
    # May be absent depending on the model head — guarded below (the original
    # indexed it unconditionally and would crash on None).
    batch_objectness = dict_outputs.get("objectness", None)  # values in [0, 1]

    if batch_pred_masks.dim() == 5:
        # b x n_layers x n_queries x h x w -> keep only the last decoder layer
        batch_pred_masks = batch_pred_masks[:, -1, ...]
        if batch_objectness is not None:
            # b x n_layers x n_queries x 1 -> b x n_queries x 1
            batch_objectness = batch_objectness[:, -1, ...]

    # Resize predictions back to the (resized) input resolution.
    # note: upsampling by 4 and cutting the padded region gives a better result
    # than a direct resize to (H, W).
    H, W = img_t.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # Iterate over the batch dimension (batch size is 1 here; we return from
    # the first iteration, matching the original control flow).
    for batch_index, pred_masks in enumerate(batch_pred_masks):
        if batch_objectness is not None:
            # n_queries x 1 -> n_queries; pick the query the model is most confident in
            objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
            best_query = torch.argsort(objectness, descending=True)[0]
        else:
            best_query = 0  # no objectness head: fall back to the first query
        # H x W binary mask, thresholded at 0.5 and scaled to {0, 255}
        pred_mask: np.ndarray = (
            (pred_masks[best_query] > 0.5).cpu().numpy().astype(np.uint8) * 255
        )

        # Edge-aware refinement; the solver returns float64, so clip back to uint8.
        pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)
        pred_mask_bi = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)

        # Colorize the mask (OpenCV emits BGR -> convert to RGB) and blend 50/50.
        attn_map = cv2.cvtColor(
            cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB
        )
        return cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)


# NOTE(review): `gr.inputs.Image` is the legacy (pre-Gradio-3.x) namespace,
# deprecated/removed in newer releases — migrate to `gr.Image(type="pil")`
# when upgrading. Kept as-is to match the pinned dependency.
demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil"),
    outputs="image",
    examples=[
        f"resources/{fname}.jpg"
        for fname in [
            "0053",
            "0236",
            "0239",
            "0403",
            "0412",
            "ILSVRC2012_test_00005309",
            "ILSVRC2012_test_00012622",
            "ILSVRC2012_test_00022698",
            "ILSVRC2012_test_00040725",
            "ILSVRC2012_test_00075738",
            "ILSVRC2012_test_00080683",
            "ILSVRC2012_test_00085874",
            "im052",
            "sun_ainjbonxmervsvpv",
            "sun_alfntqzssslakmss",
            "sun_amnrcxhisjfrliwa",
            "sun_bvyxpvkouzlfwwod",
        ]
    ],
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(
        # share=True
    )