import gradio as gr
from PIL import Image
from torchvision import transforms

from gradcam import do_gradcam
from lrp import do_lrp, do_partial_lrp
from rollout import do_rollout
from tiba import do_tiba

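# Preprocessing pipeline used by all methods: resize to 256, center-crop to
# 224x224, convert to a tensor in [0, 1], then normalize to [-1, 1].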
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
)
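# Map each dropdown choice to its corresponding visualization function.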
METHOD_MAP = {
    "tiba": do_tiba,
    "gradcam": do_gradcam,
    "lrp": do_lrp,
    "partial_lrp": do_partial_lrp,
    "rollout": do_rollout,
}


def generate_viz(image, method, class_index=None):
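    """Run the selected explainability method on the input image and return the rendered visualization as an RGB PIL image."""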
    print(f"Image: {image.size}")
    print(f"Method: {method}")
    print(f"Class: {class_index}")
    viz_method = METHOD_MAP[method]
    viz = viz_method(TRANSFORM, image, class_index=class_index)
    viz.savefig("visualization.png")
    return Image.open("visualization.png").convert("RGB")


title = "Compare different methods of explaining ViTs 🤖"
article = "Different methods for explaining Vision Transformers, as explored by Chefer et al. in [Transformer Interpretability Beyond Attention Visualization](https://arxiv.org/abs/2012.09838), a novel method to visualize classifications by Transformer-based networks."

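# Gradio interface: image upload, method dropdown, and optional class index in;
# the rendered visualization image out.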
iface = gr.Interface(
    generate_viz,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            list(METHOD_MAP.keys()),
            label="Method",
            info="Explainability method to investigate.",
        ),
        gr.Number(label="Class Index", info="Class index to inspect."),
    ],
    outputs=gr.Image(),
    title=title,
    article=article,
    allow_flagging="never",
    cache_examples=True,
    examples=[
        ["Transformer-Explainability/samples/catdog.png", "tiba", None],
        ["Transformer-Explainability/samples/catdog.png", "rollout", 243],
        ["Transformer-Explainability/samples/el2.png", "tiba", None],
        # ["Transformer-Explainability/samples/el2.png", "gradcam", 340],
        ["Transformer-Explainability/samples/dogbird.png", "lrp", 161],
    ],
)
iface.launch(debug=True)