# vision-diffmask / app.py
# Orpheous1
# no_grad
# 0422277
import sys
sys.path.insert(0, './code')
from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
import gradio as gr
import numpy as np
import torch
# Hugging Face hub identifiers of the two ViT classifiers served by the demo.
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"


def _load_vit(name):
    """Fetch a ViT classifier from the hub and put it in eval mode."""
    classifier = ViTForImageClassification.from_pretrained(name)
    classifier.eval()
    return classifier


def _load_feature_extractor(name):
    """Fetch the ViT feature extractor for ``name`` and wrap it in ``UnNest``.

    NOTE(review): ``return_tensors="pt"`` is passed to ``from_pretrained``
    exactly as before; whether it has any effect there (rather than at call
    time) depends on the installed transformers version -- confirm.
    """
    extractor = ViTFeatureExtractor.from_pretrained(name, return_tensors="pt")
    return UnNest(extractor)


def _load_diffmask(checkpoint_path, vision_transformer):
    """Restore a DiffMask net from ``checkpoint_path``, attach its ViT, set eval mode."""
    interpreter = ImageInterpretationNet.load_from_checkpoint(checkpoint_path)
    interpreter.set_vision_transformer(vision_transformer)
    interpreter.eval()
    return interpreter


# Vision Transformers
vit = _load_vit(hf_model)
vit_imagenet = _load_vit(hf_model_imagenet)

# Feature extractors (one per classifier)
feature_extractor = _load_feature_extractor(hf_model)
feature_extractor_imagenet = _load_feature_extractor(hf_model_imagenet)

# Vision DiffMask interpretation networks bound to their classifiers
diffmask = _load_diffmask('checkpoints/diffmask.ckpt', vit)
diffmask_imagenet = _load_diffmask('checkpoints/diffmask_imagenet.ckpt', vit_imagenet)
# Define mask plotting functions
def draw_mask(image, mask):
    """Render the smoothed ``mask`` over ``image`` and return a NumPy array.

    The tensor produced by ``draw_mask_on_image`` has its dimensions reordered
    with ``permute(1, 2, 0)`` (presumably CHW -> HWC for display -- confirm
    against ``utils.plot``) and is clipped to the range [0, 1].
    """
    overlay = draw_mask_on_image(image, smoothen(mask))
    reordered = overlay.permute(1, 2, 0)
    clipped = reordered.clip(0, 1)
    return clipped.numpy()
def draw_heatmap(image, mask):
    """Render the smoothed ``mask`` as a heatmap over ``image`` as a NumPy array.

    The tensor produced by ``draw_heatmap_on_image`` has its dimensions
    reordered with ``permute(1, 2, 0)`` (presumably CHW -> HWC for display --
    confirm against ``utils.plot``) and is clipped to the range [0, 1].
    """
    overlay = draw_heatmap_on_image(image, smoothen(mask))
    reordered = overlay.permute(1, 2, 0)
    clipped = reordered.clip(0, 1)
    return clipped.numpy()
# Define callable method for the demo
@torch.no_grad()
def get_mask(image, model_name: str):
    """Run Vision DiffMask on ``image`` and build the demo's three outputs.

    Args:
        image: the uploaded picture as a channel-last NumPy array (the code
            permutes dims (2, 0, 1)), or ``None`` when nothing is uploaded.
            Assumed to hold 0-255 values since it is divided by 255 --
            confirm against the Gradio input type.
        model_name: the dropdown choice, ``'DiffMask-CIFAR-10'`` or
            ``'DiffMask-ImageNet'``; may be ``None``/empty before the user
            has picked one (the interface runs with ``live=True``).

    Returns:
        A tuple ``(visualisation, orig_probs, mask_probs)``:
        ``visualisation`` is the masked image and the heatmap stacked
        side by side, and the two dicts map class name -> softmax
        probability for the original and the masked forward pass.
        Returns ``(None, None, None)`` when the image or model name is
        missing.

    Raises:
        ValueError: if ``model_name`` is a non-empty, unrecognised name.
    """
    # Fixed seed so repeated runs on the same input produce the same mask.
    torch.manual_seed(seed=0)
    if image is None or not model_name:
        return None, None, None

    # Each DiffMask model is paired with the feature extractor of its own ViT.
    # (Previously the CIFAR-10 extractor was used for both models, and an
    # unknown name crashed with UnboundLocalError.)
    if model_name == 'DiffMask-CIFAR-10':
        diffmask_model, extractor = diffmask, feature_extractor
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model, extractor = diffmask_imagenet, feature_extractor_imagenet
    else:
        raise ValueError(f"Unknown model name: {model_name!r}")

    id2label = diffmask_model.model.config.id2label
    n_classes = len(id2label)

    def _to_label_dict(probs):
        """Map each class index in ``range(n_classes)`` to (name, probability)."""
        values = probs.tolist()  # one tensor->Python conversion instead of N .item() calls
        return {id2label[i]: values[i] for i in range(n_classes)}

    # Prepare image and pass it through Vision DiffMask.
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)  # add a leading batch dimension
    dm_out = diffmask_model.get_mask(dm_image)

    # Get the mask of the single batch item and visualise it.
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Softmax the logits of the original and the masked forward passes.
    probs_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
    probs_mask = dm_out["logits"][0].detach().softmax(dim=-1)
    orig_probs = _to_label_dict(probs_orig)
    mask_probs = _to_label_dict(probs_mask)
    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs
# Build the Gradio interface around get_mask and start the server.
demo_inputs = [
    gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
    gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-ImageNet", "DiffMask-CIFAR-10"]),
]
demo_outputs = [
    gr.outputs.Image(label="Output"),
    gr.outputs.Label(label="Original Prediction", num_top_classes=5),
    gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
]
demo_examples = [
    ["dogcat.jpeg", "DiffMask-ImageNet"],
    ["elephant-zebra.jpg", "DiffMask-ImageNet"],
    ["finch.jpeg", "DiffMask-ImageNet"],
]
demo = gr.Interface(
    get_mask,
    inputs=demo_inputs,
    outputs=demo_outputs,
    examples=demo_examples,
    title="Vision DiffMask Demo",
    live=True,  # re-run get_mask whenever an input changes
)
demo.launch()