#!/usr/bin/env python
import os

# Install detectron2 at startup if it is not already available.
try:
    import detectron2
except ImportError:
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

import logging

logging.disable(logging.CRITICAL)  # comment out to enable verbose logging
#########################################################
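# Gradio demo for the unsupervised image-segmentation mode of Guess What Moves
# (GWM): a MaskFormer model is restored from a local checkpoint, its predicted
# segments are merged into foreground and background, and the foreground mask
# is overlaid on the input image.
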
import pathlib
from types import SimpleNamespace

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer

import config
import utils as ut
from eval_utils import MaskMerger
from mask_former_trainer import setup, Trainer

def load_model_cfg(dataset=None):
    """Build the GWM MaskFormer model and config, then load the best checkpoint (CPU only)."""
    checkpoint_path = 'checkpoints/checkpoint_best.pth'
    args = SimpleNamespace(config_file='configs/maskformer/maskformer_R50_bs16_160k_dino.yaml',
                           opts=["MODEL.DEVICE", "cpu", "GWM.DATASET", dataset],
                           wandb_sweep_mode=False,
                           resume_path=checkpoint_path,
                           eval_only=True)
    cfg = setup(args)
    random_state = ut.random_state.PytorchRNGState(seed=cfg.SEED).to(torch.device(cfg.MODEL.DEVICE))

    model = Trainer.build_model(cfg)
    checkpointer = DetectionCheckpointer(model,
                                         random_state=random_state,
                                         save_dir=None)
    checkpointer.resume_or_load(checkpoint_path, resume=False)
    model.eval()
    return model, cfg
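
# The helpers below select the background mask via a border heuristic and
# pad/crop images to satisfy the model's size-divisibility constraint.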
def edgeness(masks):
    """Score each mask by its mean coverage of 2-pixel strips along the four image borders."""
    em = torch.zeros(1, masks.shape[-2], masks.shape[-1], device=masks.device)
    lm = em.clone()  # left border strip
    lm[..., :2] = 1.
    rm = em.clone()  # right border strip
    rm[..., -2:] = 1.
    tm = em.clone()  # top border strip
    tm[..., :2, :] = 1.
    bm = em.clone()  # bottom border strip
    bm[..., -2:, :] = 1.
    one = torch.tensor(1., dtype=masks.dtype, device=masks.device)

    # Per-border coverage, clamped to 1 once it exceeds 0.3.
    l = (masks * lm).flatten(-2).sum(-1) / lm.sum()
    l = torch.where(l > 0.3, one, l)
    r = (masks * rm).flatten(-2).sum(-1) / rm.sum()
    r = torch.where(r > 0.3, one, r)
    t = (masks * tm).flatten(-2).sum(-1) / tm.sum()
    t = torch.where(t > 0.3, one, t)
    b = (masks * bm).flatten(-2).sum(-1) / bm.sum()
    b = torch.where(b > 0.3, one, b)
    return l + r + t + b
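
# For example, a full-frame background mask covers all four border strips and
# scores about 4.0, while a centred object that never reaches the border
# scores about 0, so the argmax over edgeness picks the background.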


def expand2sizedivisible(pil_img, background_color, size_divisibility):
    """Pad an image symmetrically so both sides are multiples of size_divisibility."""
    width, height = pil_img.size
    if width % size_divisibility == 0 and height % size_divisibility == 0:
        return pil_img
    pad_w = (size_divisibility - width % size_divisibility) % size_divisibility
    pad_h = (size_divisibility - height % size_divisibility) % size_divisibility
    result = Image.new(pil_img.mode, (width + pad_w, height + pad_h), background_color)
    result.paste(pil_img, (pad_w // 2, pad_h // 2))
    return result


def cropfromsizedivisible(img, size_divisibility, orig_size):
    """Centre-crop a padded image array back to its original (width, height) size."""
    # size_divisibility is kept for symmetry with expand2sizedivisible; the
    # crop only needs the original size.
    height, width = img.shape[:2]
    owidth, oheight = orig_size
    return img[(height - oheight) // 2:oheight + (height - oheight) // 2,
               (width - owidth) // 2:owidth + (width - owidth) // 2]


def evaluate_image(image_path):
    binary_threshold = 0.5
    metadata = MetadataCatalog.get("__unused")

    # Load the image, cap it at 384px on the long side, and pad it so the
    # model's size-divisibility constraint is met.
    image_pil = Image.open(image_path).convert('RGB')
    image_pil.thumbnail((384, 384))
    osize = image_pil.size
    if cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY > 0:
        image_pil = expand2sizedivisible(image_pil, 0, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY)

    image = np.asarray(image_pil)
    image_pt = torch.from_numpy(np.array(image)).permute(2, 0, 1)

    with torch.no_grad():
        sample = [{'rgb': image_pt}]
        preds = model.forward_base(sample, keys=['rgb'], get_eval=True)
        masks_raw = torch.stack([x['sem_seg'] for x in preds], 0)
        K = masks_raw.shape[1]
        if K > 2:
            # Merge the K predicted segments down to two (foreground/background).
            masks_softmaxed = torch.softmax(masks_raw, dim=1)
            masks_dict = merger(sample, masks_softmaxed)
            masks = masks_dict['cos']
        else:
            masks = masks_raw.softmax(1)
        masks_raw = F.interpolate(masks, size=(image_pt.shape[-2], image_pt.shape[-1]), mode='bilinear')  # t s 1 h w
        # The mask with the highest edgeness score is treated as the background.
        bg = edgeness(masks_raw)[0].argmax().item()
        masks = masks_raw[0] > binary_threshold

    # Overlay the foreground mask (the one that is not the background) in magenta.
    frame_visualizer = Visualizer(image, metadata)
    out = frame_visualizer.overlay_instances(
        masks=masks[[int(bg == 0)]],
        alpha=0.3,
        assigned_colors=[(1, 0, 1)]
    ).get_image()
    return cropfromsizedivisible(out, cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY, osize)
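

# Build the demo: the model, config and mask merger are created once at
# startup and shared across requests.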
paths = sorted(pathlib.Path('samples').glob('*.jpg'))
css = ".component-1 {height: 256px !important;}"

model, cfg = load_model_cfg("DAVIS")
merger = MaskMerger(cfg, model, merger_model="dino_vitb8")

demo = gr.Interface(
    fn=evaluate_image,
    inputs=gr.Image(label='Image', type='filepath'),
    outputs=gr.Image(label='Annotated Image', type='numpy'),
    examples=[[path.as_posix()] for path in paths],  # one value per input component
    title="Guess What Moves",
    description="#### Unsupervised image segmentation mode of [Guess What Moves](https://www.robots.ox.ac.uk/~vgg/research/gwm/)",
    css=css)
demo.queue().launch()