import sys
sys.path.insert(0, './code')
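# The project packages imported below (datamodules, models, utils) live under ./code,
# which is why that directory is prepended to the import path here.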
from datamodules.transformations import UnNest
from models.interpretation import ImageInterpretationNet
from transformers import ViTFeatureExtractor, ViTForImageClassification
from utils.plot import smoothen, draw_mask_on_image, draw_heatmap_on_image
import gradio as gr
import numpy as np
import torch
# Load Vision Transformer
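# Two backbones are exposed in the demo: a ViT-Base fine-tuned on CIFAR-10 and the
# standard google/vit-base-patch16-224 checkpoint fine-tuned on ImageNet-1k.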
hf_model = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
hf_model_imagenet = "google/vit-base-patch16-224"
vit = ViTForImageClassification.from_pretrained(hf_model)
vit_imagenet = ViTForImageClassification.from_pretrained(hf_model_imagenet)
vit.eval()
vit_imagenet.eval()
# Load Feature Extractor
feature_extractor = ViTFeatureExtractor.from_pretrained(hf_model, return_tensors="pt")
feature_extractor_imagenet = ViTFeatureExtractor.from_pretrained(hf_model_imagenet, return_tensors="pt")
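# Assumption: UnNest (datamodules/transformations.py in this repo) unwraps the extractor's
# BatchFeature so that calling it on an image tensor returns plain pixel values,
# which is the form get_mask() below feeds to the model.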
feature_extractor = UnNest(feature_extractor)
feature_extractor_imagenet = UnNest(feature_extractor_imagenet)
# Load Vision DiffMask
diffmask = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask.ckpt')
diffmask.set_vision_transformer(vit)
diffmask_imagenet = ImageInterpretationNet.load_from_checkpoint('checkpoints/diffmask_imagenet.ckpt')
diffmask_imagenet.set_vision_transformer(vit_imagenet)
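# eval() puts both interpretation networks into inference mode (e.g. disables dropout);
# gradient tracking is disabled separately via @torch.no_grad() on the demo callback.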
diffmask.eval()
diffmask_imagenet.eval()
# Define mask plotting functions
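# Both helpers take a (3, H, W) image tensor in [0, 1] plus the mask returned by DiffMask
# and produce an (H, W, 3) numpy array that Gradio can display. Assumption: smoothen()
# (utils/plot.py) upsamples/blurs the patch-level mask to pixel resolution before drawing.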
def draw_mask(image, mask):
    return draw_mask_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()


def draw_heatmap(image, mask):
    return draw_heatmap_on_image(image, smoothen(mask))\
        .permute(1, 2, 0)\
        .clip(0, 1)\
        .numpy()
# Define callable method for the demo
@torch.no_grad()
def get_mask(image, model_name: str):
    torch.manual_seed(0)

    if image is None:
        return None, None, None

    # Select the DiffMask model (and its matching feature extractor) for the chosen dataset
    if model_name == 'DiffMask-CIFAR-10':
        diffmask_model, extractor = diffmask, feature_extractor
    elif model_name == 'DiffMask-ImageNet':
        diffmask_model, extractor = diffmask_imagenet, feature_extractor_imagenet
    else:
        return None, None, None

    # Helper function to convert class index to name
    def idx2cname(idx):
        return diffmask_model.model.config.id2label[idx]

    # Prepare image and pass through Vision DiffMask
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
    dm_image = extractor(image).unsqueeze(0)
    dm_out = diffmask_model.get_mask(dm_image)

    # Get mask and apply on image
    mask = dm_out["mask"][0].detach()
    masked_img = draw_mask(image, mask)
    heatmap = draw_heatmap(image, mask)

    # Get logits and map to predictions with class names
    n_classes = len(diffmask_model.model.config.id2label)
    logits_orig = dm_out["logits_orig"][0].detach().softmax(dim=-1)
    logits_mask = dm_out["logits"][0].detach().softmax(dim=-1)
    orig_probs = {idx2cname(i): logits_orig[i].item() for i in range(n_classes)}
    mask_probs = {idx2cname(i): logits_mask[i].item() for i in range(n_classes)}

    return np.hstack((masked_img, heatmap)), orig_probs, mask_probs
# Launch demo interface
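# Note: the gr.inputs / gr.outputs namespaces come from pre-4.0 Gradio (removed in 4.x),
# so this file presumably runs against an older pinned gradio release.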
gr.Interface(
    get_mask,
    inputs=[
        gr.inputs.Image(label="Input", shape=(224, 224), source="upload", type="numpy"),
        gr.inputs.Dropdown(label="Model Name", choices=["DiffMask-ImageNet", "DiffMask-CIFAR-10"]),
    ],
    outputs=[
        gr.outputs.Image(label="Output"),
        gr.outputs.Label(label="Original Prediction", num_top_classes=5),
        gr.outputs.Label(label="Masked Prediction", num_top_classes=5),
    ],
    examples=[
        ["dogcat.jpeg", "DiffMask-ImageNet"],
        ["elephant-zebra.jpg", "DiffMask-ImageNet"],
        ["finch.jpeg", "DiffMask-ImageNet"],
    ],
    title="Vision DiffMask Demo",
    live=True,
).launch()
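# Usage note (assuming checkpoints/ and the example images sit next to this file):
# `python app.py` starts the Gradio server, reachable at http://localhost:7860 by default.
# With live=True, get_mask is re-run whenever the uploaded image or model choice changes.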