#!/usr/bin/env python

import os

try:
    import detectron2
except ImportError:
    # detectron2 is not on PyPI; install it from source on first run, then retry the import.
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
    import detectron2

import logging

logging.disable(logging.CRITICAL)  # comment out to enable verbose logging
#########################################################

import pathlib
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

import config
import utils as ut
from eval_utils import MaskMerger
from mask_former_trainer import setup, Trainer


def load_model_cfg(dataset=None):
    """Build the MaskFormer model on CPU and load the best checkpoint for evaluation."""
    args = SimpleNamespace(config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml',
                           opts=["MODEL.DEVICE", "cpu", "GWM.DATASET", dataset],
                           wandb_sweep_mode=False,
                           resume_path='checkpoints/checkpoint_best.pth',
                           eval_only=True)
    cfg = setup(args)
    random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model, random_state=random_state, save_dir=None)
    checkpointer.resume_or_load('checkpoints/checkpoint_best.pth', resume=False)
    model.eval()

    return model, cfg


def edgeness(masks):
    """Score each mask by how strongly it covers the four 2-pixel image borders;
    the highest-scoring mask is taken to be the background."""
    em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device)
    lm = em.clone()
    lm[..., :2] = 1.
    rm = em.clone()
    rm[..., -2:] = 1.
    tm = em.clone()
    tm[..., :2, :] = 1.
    bm = em.clone()
    bm[..., -2:, :] = 1.
    one = torch.tensor(1., dtype=masks.dtype, device=masks.device)
    # Fraction of each border strip covered by the mask, saturated to 1 above 0.3.
    l = (masks * lm).flatten(-2).sum(-1) / lm.sum()
    l = torch.where(l > 0.3, one, l)
    r = (masks * rm).flatten(-2).sum(-1) / rm.sum()
    r = torch.where(r > 0.3, one, r)
    t = (masks * tm).flatten(-2).sum(-1) / tm.sum()
    t = torch.where(t > 0.3, one, t)
    b = (masks * bm).flatten(-2).sum(-1) / bm.sum()
    b = torch.where(b > 0.3, one, b)
    return l + r + t + b


def expand2sizedivisible(pil_img, background_color, size_divisibility):
    """Pad a PIL image so both sides are divisible by size_divisibility, centring the original."""
    width, height = pil_img.size
    if width % size_divisibility == 0 and height % size_divisibility == 0:
        return pil_img
    result = Image.new(pil_img.mode,
                       (width + (size_divisibility - width % size_divisibility) % size_divisibility,
                        height + (size_divisibility - height % size_divisibility) % size_divisibility),
                       background_color)
    result.paste(pil_img, (((size_divisibility - width % size_divisibility) % size_divisibility) // 2,
                           ((size_divisibility - height % size_divisibility) % size_divisibility) // 2))
    return result


def cropfromsizedivisible(img, size_divisibility, orig_size):
    """Crop the centred original region back out of an image padded by expand2sizedivisible."""
    height, width = img.shape[:2]
    owidth, oheight = orig_size
    return img[(height - oheight) // 2:oheight + (height - oheight) // 2,
               (width - owidth) // 2:owidth + (width - owidth) // 2]


def evaluate_image(image_path):
    binary_threshold = 0.5
    metadata = MetadataCatalog.get("__unused")

    # Resize to at most 384x384 and pad so the model's size-divisibility constraint holds.
    image_pil = Image.open(image_path).convert('RGB')
    image_pil.thumbnail((384, 384))
    osize = image_pil.size
    if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0:
        image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY)

    image = np.asarray(image_pil)
    image_pt = torch.from_numpy(np.array(image)).permute(2, 0, 1)  # copy so the tensor owns writable memory

    with torch.no_grad():
        sample = [{'rgb': image_pt}]
        preds = model.forward_base(sample, keys=['rgb'], get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)

        K = masks_raw.shape[1]
        if K > 2:
            # Merge the K predicted segments down to two (foreground/background) with DINO features.
            masks_softmaxed = torch.softmax(masks_raw, dim=1)
            masks = merger(sample, masks_softmaxed)['cos']
        else:
            masks = masks_raw.softmax(1)

    masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear')  # 1 K h w
    bg = edgeness(masks_raw)[0].argmax().item()  # the mask touching the borders most is background
    masks = masks_raw[0] > binary_threshold

    frame_visualizer = Visualizer(image, metadata)
    out = frame_visualizer.overlay_instances(
        masks=masks[[int(bg == 0)]],  # keep only the foreground mask
        alpha=0.3,
        assigned_colors=[(1, 0, 1)]
    ).get_image()
    return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize)


paths = sorted(pathlib.Path('samples').glob('*.jpg'))
css = ".component-1 {height: 256px !important;}"

model, cfg = load_model_cfg("DAVIS")
merger = MaskMerger(cfg, model, merger_model="dino_vitb8")

demo = gr.Interface(
    fn=evaluate_image,
    inputs=gr.Image(label='Image', type='filepath'),
    outputs=gr.Image(label='Annotated Image', type='numpy'),
    # The interface takes a single image input, so each example is a one-element list.
    examples=[[path.as_posix()] for path in paths],
    title="Guess What Moves",
    description="#### Unsupervised image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)",
    css=css)

demo.queue().launch()
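
# Minimal headless sketch (no Gradio UI): assuming a sample image exists at
# 'samples/example.jpg' (hypothetical path), the annotated overlay can be
# produced and saved directly:
#
#     out = evaluate_image('samples/example.jpg')
#     Image.fromarray(out).save('annotated.png')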