Spaces:
Runtime error
Runtime error
import os
import subprocess
import sys

import torch

# detectron2 publishes pre-built wheels per (CUDA tag, torch major.minor) pair.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])

# torch.__version__ usually looks like "1.10.0+cu113". When the "+<tag>" local
# suffix is absent, splitting on "+" would return the whole version string and
# produce a bogus wheel URL, so fall back to the CUDA version torch reports
# (or "cpu" when torch was built without CUDA).
_, _plus, _local_tag = torch.__version__.partition("+")
if _plus:
    CUDA_VERSION = _local_tag
elif torch.version.cuda:
    CUDA_VERSION = "cu" + torch.version.cuda.replace(".", "")
else:
    CUDA_VERSION = "cpu"
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)


def _pip_install(*args):
    """Run ``pip install <args>`` with the running interpreter and warn on failure.

    A failed install is reported on stderr instead of being silently ignored,
    so a later ImportError can be traced back to its cause. The script keeps
    going either way, as the original best-effort install did.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *args],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"warning: 'pip install {' '.join(args)}' exited with code {result.returncode}",
            file=sys.stderr,
        )


# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
_pip_install(
    "detectron2",
    "-f",
    f"https://dl.fbaipublicfiles.com/detectron2/wheels/{CUDA_VERSION}/torch{TORCH_VERSION}/index.html",
)
_pip_install("git+https://github.com/cocodataset/panopticapi.git")
# Imports | |
import gradio as gr | |
import detectron2 | |
from detectron2.utils.logger import setup_logger | |
import numpy as np | |
import cv2 | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms.functional as TF | |
from torchvision import datasets, transforms | |
from einops import rearrange | |
from PIL import Image | |
import imutils | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import ImageGrid | |
from tqdm import tqdm | |
import random | |
from functools import partial | |
import time | |
# import some common detectron2 utilities | |
from detectron2 import model_zoo | |
from detectron2.engine import DefaultPredictor | |
from detectron2.config import get_cfg | |
from detectron2.utils.visualizer import Visualizer, ColorMode | |
from detectron2.data import MetadataCatalog | |
from detectron2.projects.deeplab import add_deeplab_config | |
coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic") | |
# Import Mask2Former | |
from mask2former import add_maskformer2_config | |
# DPT dependencies for depth pseudo labeling | |
from dpt.models import DPTDepthModel | |
from multimae.input_adapters import PatchedInputAdapter, SemSegInputAdapter | |
from multimae.output_adapters import SpatialOutputAdapter | |
from multimae.multimae import pretrain_multimae_base | |
from utils.data_constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
# The demo only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'device: {device}')

# Initialize COCO Mask2Former
# The predictor is pinned to the CPU regardless of `device`.
cfg = get_cfg()
cfg.MODEL.DEVICE = 'cpu'
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file(
    "mask2former/configs/coco/panoptic-segmentation/swin/maskformer2_swin_small_bs16_50ep.yaml"
)
cfg.MODEL.WEIGHTS = (
    'https://dl.fbaipublicfiles.com/maskformer/mask2former/coco/panoptic/'
    'maskformer2_swin_small_bs16_50ep/model_final_a407fd.pkl'
)
# Enable all three test-time heads.
cfg.MODEL.MASK_FORMER.TEST.SEMANTIC_ON = True
cfg.MODEL.MASK_FORMER.TEST.INSTANCE_ON = True
cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = True
semseg_model = DefaultPredictor(cfg)
def predict_semseg(img):
    """Return the per-pixel class index map predicted by ``semseg_model``.

    ``img`` is permuted from channel-first to channel-last, converted to a
    NumPy array and scaled by 255 before being handed to the predictor. The
    predictor's ``'sem_seg'`` output is reduced with ``argmax`` over dim 0.
    """
    hwc_image = img.permute(1, 2, 0).numpy()
    outputs = semseg_model(255 * hwc_image)
    return outputs['sem_seg'].argmax(0)
def plot_semseg(img, semseg, ax):
    """Draw ``semseg`` over ``img`` with a detectron2 ``Visualizer`` and show it on ``ax``.

    ``img`` is permuted to channel-last before visualisation; ``semseg`` is
    moved to the CPU before drawing.
    """
    # NOTE(review): unlike plot_semseg_gt, the image is passed to Visualizer
    # without being scaled by 255 -- confirm the value range callers provide.
    hwc_image = img.permute(1, 2, 0)
    visualizer = Visualizer(
        hwc_image,
        coco_metadata,
        scale=1.2,
        instance_mode=ColorMode.IMAGE_BW,
    )
    overlay = visualizer.draw_sem_seg(semseg.cpu())
    semantic_result = overlay.get_image()
    ax.imshow(semantic_result)
# Initialize Omnidata depth model
# Fetch the checkpoint archive, unpack it flat into ./pretrained_models, then
# delete the archive. Exit codes of the shell commands are not inspected.
for _command in (
    "wget https://drive.switch.ch/index.php/s/RFfTZwyKROKKx0l/download",
    "unzip -j download -d pretrained_models",
    "rm download",
):
    os.system(_command)

# Load the weights on the CPU first; the model is moved to `device` afterwards.
omnidata_ckpt = torch.load(
    './pretrained_models/omnidata_rgb2depth_dpt_hybrid.pth',
    map_location='cpu',
)
depth_model = DPTDepthModel()
depth_model.load_state_dict(omnidata_ckpt)
depth_model = depth_model.to(device)
depth_model = depth_model.eval()
def predict_depth(img):
    """Run ``depth_model`` on a single image and return its output.

    A leading batch dimension is added and the values are shifted and scaled
    with ``(x - 0.5) / 0.5`` before the tensor is moved to ``device``.
    """
    batched = img.unsqueeze(0)
    depth_model_input = (batched - 0.5) / 0.5
    on_device = depth_model_input.to(device)
    return depth_model(on_device)
# MultiMAE model setup
# Per-domain factories for the input and output adapters; the remaining
# constructor arguments are supplied when the adapters are instantiated below.
DOMAIN_CONF = dict(
    rgb=dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=3, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=3, stride_level=1),
    ),
    depth=dict(
        input_adapter=partial(PatchedInputAdapter, num_channels=1, stride_level=1),
        output_adapter=partial(SpatialOutputAdapter, num_channels=1, stride_level=1),
    ),
    semseg=dict(
        input_adapter=partial(
            SemSegInputAdapter,
            num_classes=133,
            dim_class_emb=64,
            interpolate_class_emb=False,
            stride_level=4,
        ),
        output_adapter=partial(SpatialOutputAdapter, num_channels=133, stride_level=4),
    ),
)
DOMAINS = ['rgb', 'depth', 'semseg']

# All input adapters are built first, then all output adapters, keeping the
# module construction order of the original script.
input_adapters = {}
for _domain, _factories in DOMAIN_CONF.items():
    input_adapters[_domain] = _factories['input_adapter'](patch_size_full=16)

output_adapters = {}
for _domain, _factories in DOMAIN_CONF.items():
    output_adapters[_domain] = _factories['output_adapter'](
        patch_size_full=16,
        dim_tokens=256,
        use_task_queries=True,
        depth=2,
        context_tasks=DOMAINS,
        task=_domain,
    )

multimae = pretrain_multimae_base(
    input_adapters=input_adapters,
    output_adapters=output_adapters,
)

# Download (or reuse the cached) pre-trained weights and load them
# non-strictly, so mismatching keys do not raise.
CKPT_URL = 'https://github.com/EPFL-VILAB/MultiMAE/releases/download/pretrained-weights/multimae-b_98_rgb+-depth-semseg_1600e_multivit-afff3f8c.pth'
ckpt = torch.hub.load_state_dict_from_url(CKPT_URL, map_location='cpu')
multimae.load_state_dict(ckpt['model'], strict=False)
multimae = multimae.to(device)
multimae = multimae.eval()
# Plotting | |
def get_masked_image(img, mask, image_size=224, patch_size=16, mask_value=0.0):
    """Return a copy of ``img`` with the masked patches filled with ``mask_value``.

    The image batch is split into ``patch_size`` x ``patch_size`` patches laid
    out as a token sequence, every token whose ``mask`` entry is non-zero is
    overwritten with ``mask_value``, and the tokens are folded back into the
    ``b c h w`` layout. All work happens on detached CPU tensors.
    """
    patches_per_side = image_size // patch_size
    axes = dict(ph=patch_size, pw=patch_size, nh=patches_per_side, nw=patches_per_side)

    tokens = rearrange(
        img.detach().cpu(),
        'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)',
        **axes,
    )
    hidden = mask.detach().cpu() != 0
    tokens[hidden] = mask_value
    return rearrange(
        tokens,
        'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
        **axes,
    )
def denormalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    """Undo a per-channel ``(x - mean) / std`` normalisation on a clone of ``img``.

    The inverse is expressed as another normalisation with mean ``-mean / std``
    and std ``1 / std``; the input tensor itself is left untouched.
    """
    inverse_mean = [-m / s for m, s in zip(mean, std)]
    inverse_std = [1 / s for s in std]
    return TF.normalize(img.clone(), mean=inverse_mean, std=inverse_std)
def plot_semseg_gt(input_dict, ax=None, image_size=224):
    """Render the semantic map in ``input_dict`` over its RGB image.

    The RGB entry is denormalised and scaled by 255; the ``'semseg'`` entry is
    resized to ``image_size`` with nearest-neighbour interpolation and drawn
    with a detectron2 ``Visualizer``. When ``ax`` is given the rendering is
    shown on it and nothing is returned; otherwise the rendered image is
    returned.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")

    rgb = denormalize(input_dict['rgb'].detach().cpu())
    img_viz = 255 * rgb[0].permute(1, 2, 0)

    labels = input_dict['semseg'].unsqueeze(0).cpu().float()
    semseg = F.interpolate(labels, size=image_size, mode='nearest').long()[0, 0]

    visualizer = Visualizer(img_viz, metadata, instance_mode=ColorMode.IMAGE, scale=1)
    visualizer.draw_sem_seg(semseg)
    rendered = visualizer.get_output().get_image()

    if ax is None:
        return rendered
    ax.imshow(rendered)
def plot_semseg_gt_masked(input_dict, mask, ax=None, mask_value=1.0, image_size=224):
    """Render the semantic map of ``input_dict`` with masked patches blanked out.

    The rendering from ``plot_semseg_gt`` is converted to a ``1 c h w`` float
    tensor divided by 255, masked in 16x16 patches with ``get_masked_image``
    and converted back to channel-last. It is shown on ``ax`` when one is
    given; otherwise the masked image tensor is returned.
    """
    rendered = plot_semseg_gt(input_dict, image_size=image_size)
    chw_batch = torch.LongTensor(rendered).permute(2, 0, 1).unsqueeze(0)
    masked_batch = get_masked_image(
        chw_batch.float() / 255.0,
        mask,
        image_size=image_size,
        patch_size=16,
        mask_value=mask_value,
    )
    masked_img = masked_batch[0].permute(1, 2, 0)

    if ax is None:
        return masked_img
    ax.imshow(masked_img)
def get_pred_with_input(gt, pred, mask, image_size=224, patch_size=16):
    """Combine ``pred`` with ``gt`` so that unmasked patches show the ground truth.

    Both tensors are split into ``patch_size`` x ``patch_size`` patch tokens on
    the CPU. Tokens whose ``mask`` entry equals 0 are copied from ``gt`` into
    the prediction, and the result is folded back into ``b c h w`` layout.
    """
    patches_per_side = image_size // patch_size
    axes = dict(ph=patch_size, pw=patch_size, nh=patches_per_side, nw=patches_per_side)
    to_tokens = 'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)'

    gt_token = rearrange(gt.detach().cpu(), to_tokens, **axes)
    pred_token = rearrange(pred.detach().cpu(), to_tokens, **axes)

    visible = mask.detach().cpu() == 0
    pred_token[visible] = gt_token[visible]

    return rearrange(
        pred_token,
        'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
        **axes,
    )
def plot_semseg_pred_masked(rgb, semseg_preds, semseg_gt, mask, ax=None, image_size=224):
    """Render the predicted semantic map, keeping ground truth on unmasked patches.

    The prediction is the ``argmax`` over dim 1 of ``semseg_preds``. It is
    merged with ``semseg_gt`` through ``get_pred_with_input`` at a quarter of
    ``image_size`` with 4x4 patches, resized to ``image_size`` with
    nearest-neighbour interpolation and drawn over the denormalised ``rgb``
    image. It is shown on ``ax`` when one is given; otherwise the rendered
    image is returned.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")

    rgb_cpu = denormalize(rgb.detach().cpu())
    img_viz = 255 * rgb_cpu[0].permute(1, 2, 0)

    predicted_labels = semseg_preds.argmax(1).unsqueeze(1)
    merged = get_pred_with_input(
        semseg_gt.unsqueeze(1),
        predicted_labels,
        mask,
        image_size=image_size // 4,
        patch_size=4,
    )
    semseg = F.interpolate(merged.float(), size=image_size, mode='nearest')[0, 0].long()

    visualizer = Visualizer(img_viz, metadata, instance_mode=ColorMode.IMAGE, scale=1)
    visualizer.draw_sem_seg(semseg)
    rendered = visualizer.get_output().get_image()

    if ax is None:
        return rendered
    ax.imshow(rendered)
def plot_predictions(input_dict, preds, masks, image_size=224, output_path='./output.png'):
    """Save a 3x3 overview figure of masked inputs, predictions and references.

    Rows are RGB, depth and semantic segmentation. Columns are the masked
    input, the prediction with unmasked patches taken from the input, and the
    unmasked reference.

    Args:
        input_dict: Dict with ``'rgb'``, ``'depth'`` and ``'semseg'`` tensors.
        preds: Dict of predictions with the same keys.
        masks: Dict of token masks with the same keys; entries equal to 0 mark
            patches that are kept, non-zero entries mark masked patches.
        image_size: Side length passed to the patching and plotting helpers.
        output_path: File the figure is written to. Defaults to the path the
            function always used, ``'./output.png'``.
    """
    rgb_reference = denormalize(input_dict['rgb'])

    masked_rgb = get_masked_image(
        rgb_reference,
        masks['rgb'],
        image_size=image_size,
        mask_value=1.0,
    )[0].permute(1, 2, 0).detach().cpu()
    # NaN is used as the fill value for masked depth patches.
    masked_depth = get_masked_image(
        input_dict['depth'],
        masks['depth'],
        image_size=image_size,
        mask_value=np.nan,
    )[0, 0].detach().cpu()

    pred_rgb2 = get_pred_with_input(
        rgb_reference,
        denormalize(preds['rgb']).clamp(0, 1),
        masks['rgb'],
        image_size=image_size,
    )[0].permute(1, 2, 0).detach().cpu()
    pred_depth2 = get_pred_with_input(
        input_dict['depth'],
        preds['depth'],
        masks['depth'],
        image_size=image_size,
    )[0, 0].detach().cpu()

    fig = plt.figure(figsize=(10, 10))
    try:
        grid = ImageGrid(fig, 111, nrows_ncols=(3, 3), axes_pad=0)

        grid[0].imshow(masked_rgb)
        grid[1].imshow(pred_rgb2)
        grid[2].imshow(rgb_reference[0].permute(1, 2, 0).detach().cpu())

        grid[3].imshow(masked_depth)
        grid[4].imshow(pred_depth2)
        grid[5].imshow(input_dict['depth'][0, 0].detach().cpu())

        plot_semseg_gt_masked(input_dict, masks['semseg'], grid[6], mask_value=1.0, image_size=image_size)
        plot_semseg_pred_masked(
            input_dict['rgb'], preds['semseg'], input_dict['semseg'], masks['semseg'],
            grid[7], image_size=image_size,
        )
        plot_semseg_gt(input_dict, grid[8], image_size=image_size)

        for ax in grid:
            ax.set_xticks([])
            ax.set_yticks([])

        fontsize = 16
        grid[0].set_title('Masked inputs', fontsize=fontsize)
        grid[1].set_title('MultiMAE predictions', fontsize=fontsize)
        grid[2].set_title('Original Reference', fontsize=fontsize)
        grid[0].set_ylabel('RGB', fontsize=fontsize)
        grid[3].set_ylabel('Depth', fontsize=fontsize)
        grid[6].set_ylabel('Semantic', fontsize=fontsize)

        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when plotting fails, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)
def inference(img, num_tokens, manual_mode, num_rgb, num_depth, num_semseg, seed):
    """Run the full demo pipeline on one image and return the result figure path.

    The image is cropped and resized, pseudo-labelled with depth and semantic
    segmentation, masked, passed through MultiMAE, and the result is plotted.

    Args:
        img: Path of the input image file.
        num_tokens: Percentage of all 588 (3 x 196) tokens left visible; used
            when ``manual_mode`` is false.
        manual_mode: When true, the per-modality percentages and ``seed`` are
            used to build the masks explicitly.
        num_rgb: Percentage of the 196 RGB tokens left visible (manual mode).
        num_depth: Percentage of the 196 depth tokens left visible (manual mode).
        num_semseg: Percentage of the 196 semantic tokens left visible (manual mode).
        seed: Seed for the manual-mode mask sampling; ``None`` is treated as 0.

    Returns:
        The string ``'output.png'``, the file written by ``plot_predictions``.
    """
    # Convert slider percentages into token counts.
    num_tokens = int(588 * num_tokens / 100.0)
    num_rgb = int(196 * num_rgb / 100.0)
    num_depth = int(196 * num_depth / 100.0)
    num_semseg = int(196 * num_semseg / 100.0)

    # Force three RGB channels so RGBA, grayscale or palette uploads do not
    # break the 3-channel normalisation below; the context manager closes the
    # file handle once the pixel data has been converted.
    with Image.open(img) as opened:
        im = opened.convert('RGB')

    # Center crop and resize RGB
    image_size = 224  # Train resolution
    rgb = TF.center_crop(TF.to_tensor(im), min(im.size))
    rgb = TF.resize(rgb, image_size, interpolation=TF.InterpolationMode.BICUBIC)

    # Predict depth and semseg
    depth = predict_depth(rgb)
    semseg = predict_semseg(rgb)

    # Pre-process RGB, depth and semseg to the MultiMAE input format
    input_dict = {}

    # Normalize RGB
    input_dict['rgb'] = TF.normalize(rgb, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD).unsqueeze(0)

    # Normalize depth robustly: mean and variance are estimated from the
    # central 80% of the sorted depth values.
    trunc_depth = torch.sort(depth.flatten())[0]
    trunc_depth = trunc_depth[int(0.1 * trunc_depth.shape[0]): int(0.9 * trunc_depth.shape[0])]
    depth = (depth - trunc_depth.mean()[None, None, None]) / torch.sqrt(trunc_depth.var()[None, None, None] + 1e-6)
    input_dict['depth'] = depth.unsqueeze(0)

    # Downsample semantic segmentation
    stride = 4
    semseg = TF.resize(
        semseg.unsqueeze(0),
        (semseg.shape[0] // stride, semseg.shape[1] // stride),
        interpolation=TF.InterpolationMode.NEAREST,
    )
    input_dict['semseg'] = semseg

    # To GPU
    input_dict = {k: v.to(device) for k, v in input_dict.items()}

    if not manual_mode:
        # Randomly sample masks
        torch.manual_seed(int(time.time()))  # Random mode is random
        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,  # True if forward pass should sample random masks
            num_encoded_tokens=num_tokens,
            alphas=1.0,
        )
    else:
        # Randomly sample masks using the specified number of tokens per modality
        # An empty seed field arrives as None; fall back to 0 instead of crashing.
        torch.manual_seed(int(seed) if seed is not None else 0)  # change seed to resample new mask
        # Start with every token masked (1), then mark the selected ones visible (0).
        task_masks = {domain: torch.ones(1, 196).long().to(device) for domain in DOMAINS}
        selected_rgb_idxs = torch.randperm(196)[:num_rgb]
        selected_depth_idxs = torch.randperm(196)[:num_depth]
        selected_semseg_idxs = torch.randperm(196)[:num_semseg]
        task_masks['rgb'][:, selected_rgb_idxs] = 0
        task_masks['depth'][:, selected_depth_idxs] = 0
        task_masks['semseg'][:, selected_semseg_idxs] = 0

        preds, masks = multimae.forward(
            input_dict,
            mask_inputs=True,
            task_masks=task_masks,
        )

    preds = {domain: pred.detach().cpu() for domain, pred in preds.items()}
    masks = {domain: mask.detach().cpu() for domain, mask in masks.items()}

    plot_predictions(input_dict, preds, masks)

    return 'output.png'
# Texts shown on the demo page.
title = "MultiMAE"
description = (
    "Gradio demo for MultiMAE: Multi-modal Multi-task Masked Autoencoders. "
    "Upload your own images or try one of the examples below to explore the multi-modal masked reconstruction of a pre-trained MultiMAE model. "
    "Uploaded images are pseudo labeled using a DPT trained on Omnidata depth, and a Mask2Former trained on COCO. "
    "Choose the percentage of visible tokens using the sliders below and see how MultiMAE reconstructs the modalities!"
)
article = (
    "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.01678' "
    "target='_blank'>MultiMAE: Multi-modal Multi-task Masked Autoencoders</a> | "
    "<a href='https://github.com/EPFL-VILAB/MultiMAE' target='_blank'>Github Repo</a></p>"
)

css = '.output-image{height: 713px !important}'

# Example images
_example_files = ('c9ObJdK.jpg', 'KTKgYKi.jpg', 'lWYuRI7.jpg')
for _example_file in _example_files:
    os.system(f"wget https://i.imgur.com/{_example_file}")

# Each example row follows the order of the interface inputs below.
examples = [
    [_example_file, 15, False, 15, 15, 15, 0]
    for _example_file in _example_files
]

# NOTE(review): gr.inputs / gr.outputs and launch(cache_examples=...) belong to
# an older gradio API -- confirm the gradio version pinned for this Space.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(label='RGB input image', type='filepath'),
        gr.inputs.Slider(label='Percentage of input tokens', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Checkbox(label='Manual mode: Check this to manually set the number of input tokens per modality using the sliders below', default=False),
        gr.inputs.Slider(label='Percentage of RGB input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Slider(label='Percentage of depth input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Slider(label='Percentage of semantic input tokens (for manual mode only)', default=15, step=0.1, minimum=0, maximum=100),
        gr.inputs.Number(label='Random seed: Change this to sample different masks (for manual mode only)', default=0),
    ],
    outputs=[
        gr.outputs.Image(label='MultiMAE predictions', type='file'),
    ],
    css=css,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(enable_queue=True, cache_examples=False)