"""Gradio demo for SelfMask: unsupervised salient object detection.

Loads a pretrained SelfMask checkpoint, predicts query masks for an uploaded
image, keeps the mask with the highest objectness score, refines it with a
bilateral solver and returns it colour-mapped and blended over the (resized)
input image.
"""
import os

# Work around "OMP: Error #15" when more than one OpenMP runtime gets linked in.
# Set this before importing numpy / cv2 / torch so the variable is already in
# place whenever an OpenMP runtime initializes.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from argparse import ArgumentParser, Namespace
from typing import Dict, List, Tuple
import codecs

import yaml
import numpy as np
import cv2
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, normalize, resize
import gradio as gr

from utils import get_model
from bilateral_solver import bilateral_solver_output

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parse the CLI first so that `--help` or a bad argument fails fast, before the
# checkpoint download below.
parser = ArgumentParser("SelfMask demo")
parser.add_argument(
    "--config",
    type=str,
    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml",
)
args: Namespace = parser.parse_args()

# Merge the YAML config into the CLI namespace; YAML values override CLI ones.
with open(args.config, 'r', encoding="utf-8") as config_file:
    base_args = yaml.safe_load(config_file)
base_args.pop("dataset_name", None)  # not a model option; tolerate its absence
args = Namespace(**{**vars(args), **base_args})

state_dict: dict = torch.hub.load_state_dict_from_url(
    "https://www.robots.ox.ac.uk/~vgg/research/selfmask/shared_files/selfmask_nq20.pt",
    map_location=device,
)

model = get_model(arch="maskformer", configs=args).to(device)
model.load_state_dict(state_dict)
model.eval()

# Preprocessing: torchvision `resize` matches the shorter edge to `size` while
# capping the longer edge at `max_size`; then the standard ImageNet mean/std
# normalisation is applied.
size: int = 384
max_size: int = 512
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)


@torch.no_grad()
def main(image: Image.Image) -> np.ndarray:
    """Predict a saliency mask for `image` and return it overlaid on the image.

    Args:
        image: input PIL image (the Gradio input is configured with type="pil").

    Returns:
        uint8 RGB array (H x W x 3) of the resized input: the bilateral-solver
        refined mask, coloured with the viridis colormap, blended 50/50 with
        the resized image.

    Raises:
        ValueError: if the model output has no "objectness" scores, which are
            required to pick the best query mask.
    """
    # Force 3 channels: normalisation uses 3-channel mean/std and the final
    # blend needs an RGB array that matches the colour-mapped mask.
    pil_image: Image.Image = resize(image.convert("RGB"), size=size, max_size=max_size)
    image_tensor: torch.Tensor = normalize(
        to_tensor(pil_image), mean=list(mean), std=list(std)
    )  # 3 x H x W

    dict_outputs = model(image_tensor[None].to(device))
    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
    batch_objectness = dict_outputs.get("objectness", None)  # [0, 1]
    if batch_objectness is None:
        raise ValueError("model output has no 'objectness' scores; cannot select a mask")

    if batch_pred_masks.dim() == 5:
        # Outputs are stacked per decoder layer; keep only the last layer.
        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
        batch_pred_masks = batch_pred_masks[:, -1, ...]
        # b x n_layers x n_queries x 1 -> b x n_queries x 1
        batch_objectness = batch_objectness[:, -1, ...]

    # Resize the prediction to the input resolution.
    # note: upsampling by 4 and cutting the padded region allows for a better result
    H, W = image_tensor.shape[-2:]
    batch_pred_masks = F.interpolate(
        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
    )[..., :H, :W]

    # The batch holds exactly one image (`image_tensor[None]` above); keep the
    # query mask with the highest objectness score.
    objectness: torch.Tensor = batch_objectness[0].squeeze(dim=-1)  # n_queries
    pred_mask: torch.Tensor = batch_pred_masks[0][objectness.argmax()]  # H x W

    binary_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255
    refined_mask, _ = bilateral_solver_output(img=pil_image, target=binary_mask)  # float64
    refined_mask = np.clip(refined_mask, 0, 255).astype(np.uint8)

    # OpenCV colormaps are produced in BGR order; convert to RGB before
    # blending with the RGB PIL image.
    attn_map = cv2.cvtColor(
        cv2.applyColorMap(refined_mask, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB
    )
    return cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)


with codecs.open("description.html", 'r', "utf-8") as description_file:
    _description: str = description_file.read()

demo = gr.Interface(
    fn=main,
    inputs=gr.inputs.Image(type="pil", source="upload", tool="editor"),
    outputs=gr.outputs.Image(type="numpy", label="saliency map"),
    examples=[f"resources/{fname}.jpg" for fname in [
        "0053",
        "0236",
        "0239",
        "0403",
        "0412",
        "ILSVRC2012_test_00005309",
        "ILSVRC2012_test_00012622",
        "ILSVRC2012_test_00022698",
        "ILSVRC2012_test_00040725",
        "ILSVRC2012_test_00075738",
        "ILSVRC2012_test_00080683",
        "ILSVRC2012_test_00085874",
        "im052",
        "sun_ainjbonxmervsvpv",
        "sun_alfntqzssslakmss",
        "sun_amnrcxhisjfrliwa",
        "sun_bvyxpvkouzlfwwod",
    ]],
    examples_per_page=20,
    description=_description,
    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
    allow_flagging="never",
    analytics_enabled=False,
)

# Pass share=True to launch() to get a public link.
demo.launch()