#!/usr/bin/env python
import os

try:
    import detectron2
except ImportError:
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

import logging

logging.disable(logging.CRITICAL)  # comment out to enable verbose logging
#########################################################
import pathlib
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

import config
import utils as ut
from eval_utils import MaskMerger
from mask_former_trainer import setup, Trainer
def load_model_cfg(dataset=None):
    """Build the MaskFormer model and config, restoring the best checkpoint on CPU."""
    args = SimpleNamespace(
        config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml',
        opts=["GWM.DATASET", dataset],
        wandb_sweep_mode=False,
        resume_path='checkpoints/checkpoint_best.pth',
        eval_only=True,
    )
    cfg = setup(args)
    cfg.defrost()
    cfg.MODEL.DEVICE = 'cpu'  # the demo runs on CPU
    cfg.freeze()

    random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))
    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model,
                                         random_state=random_state,
                                         save_dir=None)
    checkpoint_path = 'checkpoints/checkpoint_best.pth'
    checkpointer.resume_or_load(checkpoint_path, resume=False)
    model.eval()
    return model, cfg
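
# Note: evaluate_image() below reloads the model on every request. A common
# optimisation (not part of the original app) would be to hoist the load to
# module level and reuse it, e.g.:
#   MODEL, CFG = load_model_cfg("DAVIS")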
def edgeness(masks):
    """Score each mask by how much of the 2-pixel image border it covers.

    Per-side coverage fractions above 0.3 saturate to 1, so each mask scores
    in [0, 4]; the highest-scoring mask is treated as background.
    """
    em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device)
    lm = em.clone()
    lm[..., :2] = 1.      # left border
    rm = em.clone()
    rm[..., -2:] = 1.     # right border
    tm = em.clone()
    tm[..., :2, :] = 1.   # top border
    bm = em.clone()
    bm[..., -2:, :] = 1.  # bottom border
    one = torch.tensor(1., dtype=masks.dtype, device=masks.device)

    l = (masks * lm).flatten(-2).sum(-1) / lm.sum()
    l = torch.where(l > 0.3, one, l)
    r = (masks * rm).flatten(-2).sum(-1) / rm.sum()
    r = torch.where(r > 0.3, one, r)
    t = (masks * tm).flatten(-2).sum(-1) / tm.sum()
    t = torch.where(t > 0.3, one, t)
    b = (masks * bm).flatten(-2).sum(-1) / bm.sum()
    b = torch.where(b > 0.3, one, b)
    return l + r + t + b
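
# Sanity-check sketch for edgeness on a hypothetical toy mask (not part of the
# original app): a mask fully covering the left border saturates the left term.
#   m = torch.zeros(1, 1, 8, 8)
#   m[..., :, :2] = 1.
#   edgeness(m)  # tensor([[1.5]]): 1 (left) + 0.25 (top) + 0.25 (bottom)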
def expand2sizedivisible(pil_img, background_color, size_divisibility):
    """Pad a PIL image so both dimensions are multiples of size_divisibility."""
    width, height = pil_img.size
    if width % size_divisibility == 0 and height % size_divisibility == 0:
        return pil_img
    pad_w = (size_divisibility - width % size_divisibility) % size_divisibility
    pad_h = (size_divisibility - height % size_divisibility) % size_divisibility
    result = Image.new(pil_img.mode, (width + pad_w, height + pad_h), background_color)
    result.paste(pil_img, (pad_w // 2, pad_h // 2))  # centre the original image
    return result
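
# Example sketch (hypothetical numbers): with size_divisibility=32, a 100x75
# image is padded up to the next multiples, 128x96, with the original centred:
#   expand2sizedivisible(Image.new('RGB', (100, 75)), 0, 32).size  # (128, 96)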
def cropfromsizedivisible(img, size_divisibility, orig_size):
    """Crop out the centred padding added by expand2sizedivisible.

    orig_size is (width, height); img is a numpy array of shape (H, W, C).
    size_divisibility is unused but kept for symmetry with the call site.
    """
    height, width = img.shape[:2]
    owidth, oheight = orig_size
    top, left = (height - oheight) // 2, (width - owidth) // 2
    return img[top:top + oheight, left:left + owidth]
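
# Round-trip sketch (hypothetical numbers): cropping a padded (96, 128, 3)
# array with orig_size=(100, 75) returns shape (75, 100, 3); note orig_size
# is (width, height) while numpy shapes are (height, width).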
def evaluate_image(image_path):
    binary_threshold = 0.5
    metadata = MetadataCatalog.get("__unused")
    model, cfg = load_model_cfg("DAVIS")
    merger = MaskMerger(cfg, model, merger_model="dino_vitb8")

    image_pil = Image.open(image_path).convert('RGB')
    image_pil.thumbnail((384, 384))  # downscale in place, preserving aspect ratio
    osize = image_pil.size
    if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0:
        image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY)

    image = np.array(image_pil)
    image_pt = torch.from_numpy(image).permute(2, 0, 1)

    with torch.no_grad():
        sample = [{'rgb': image_pt}]
        preds = model.forward_base(sample, keys=['rgb'], get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
        K = masks_raw.shape[1]
        if K > 2:
            # Merge the K predicted segments down to figure/ground using DINO features.
            masks_softmaxed = torch.softmax(masks_raw, dim=1)
            masks_dict = merger(sample, masks_softmaxed)
            masks = masks_dict['cos']
        else:
            masks = masks_raw.softmax(1)
        masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear')  # (batch, K, H, W)
        bg = edgeness(masks_raw)[0].argmax().item()  # the mask covering most of the border is background
        masks = masks_raw[0] > binary_threshold

    frame_visualizer = Visualizer(image, metadata)
    out = frame_visualizer.overlay_instances(
        masks=masks[[int(bg == 0)]],  # of the two masks, pick the one that is not background
        alpha=0.3,
        assigned_colors=[(1, 0, 1)]
    ).get_image()
    return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize)
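
# Standalone usage sketch (hypothetical path; returns an HxWx3 uint8 overlay):
#   out = evaluate_image('samples/example.jpg')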
paths = sorted(pathlib.Path('samples').glob('*.jpg'))
css = ".component-1 {height: 256px !important;}"
demo = gr.Interface(
    fn=evaluate_image,
    inputs=gr.Image(label='Image', type='filepath'),
    outputs=gr.Image(label='Annotated Image', type='numpy'),
    examples=[[path.as_posix()] for path in paths],  # one value per input component
    title="Guess What Moves",
    description="#### Unsupervised image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)",
    css=css)
demo.queue().launch()
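
# On Spaces the runtime serves the app itself; to expose a local run publicly,
# Gradio also accepts demo.queue().launch(share=True).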