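"""Gradio demo (Hugging Face Space) for LeGrad, an explainability method for
Vision Transformers.

The app loads an OpenCLIP ViT-B-16 model, wraps it with LeGrad, and renders a
heatmap of the image regions most relevant to a free-form text query.
"""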
import requests
import numpy as np
import cv2
from PIL import Image

import torch
import torch.nn.functional as F
import open_clip

import gradio as gr
import spaces

from legrad import LeWrapper, LePreprocess


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # falls back to CPU without a GPU
layer_index = -2  # note: defined for reference but not passed to LeWrapper below
image_size = 448

# ---------- Init CLIP Model ----------
model_name = 'ViT-B-16'
pretrained = 'laion2b_s34b_b88k'
patch_size = 16  # ViT-B-16 splits the image into 16x16 patches

model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, device=device)
tokenizer = open_clip.get_tokenizer(model_name)

# ---------- Apply LeGrad's wrappers ----------
model = LeWrapper(model)  # equips the CLIP model with compute_legrad()
preprocess = LePreprocess(preprocess=preprocess, image_size=image_size)  # adapt transforms to 448x448 inputs


# ---------- Function to load an image from a URL ----------
def change_to_url(url):
    img_pil = Image.open(requests.get(url, stream=True, timeout=10).raw).convert('RGB')
    return img_pil


def _get_text_embedding(model, tokenizer, classes: list, device):
    # Build one prompt per class and encode it with CLIP's text encoder.
    prompts = [f'a photo of a {cls}.' for cls in classes]

    tokenized_prompts = tokenizer(prompts).to(device)

    text_embedding = model.encode_text(tokenized_prompts)
    text_embedding = F.normalize(text_embedding, dim=-1)
    return text_embedding.unsqueeze(0)  # add a batch dimension: [1, num_classes, dim]
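
# Shape sketch: for classes=['cat', 'remote control'], the returned tensor has
# shape [1, 2, 512]; 512 is the text-embedding width of CLIP ViT-B-16.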

# ---------- Function to convert logits to heatmaps ----------
def logits_to_heatmaps(logits, image_cv):
    logits = logits[0, 0].detach().cpu().numpy()
    logits = (logits * 255).astype('uint8')  # map [0, 1] relevance to the 8-bit range
    heat_map = cv2.applyColorMap(logits, cv2.COLORMAP_JET)
    viz = 0.4 * image_cv + 0.6 * heat_map  # blend the original image with the heatmap
    viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
    return viz
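
# Note (assumption): compute_legrad is expected to return relevance maps already
# scaled to [0, 1]; the uint8 cast above would wrap values outside that range.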


# ---------- Main visualization function ----------
@spaces.GPU
def viz_func(url, image, text_query):
    # `url` is unused here; it only feeds the image loader via the change event below.
    image_torch = preprocess(image).unsqueeze(0).to(device)
    text_emb = _get_text_embedding(model, tokenizer, classes=[text_query], device=device)

    # ------- Get LeGrad output -------
    logits_legrad = model.compute_legrad(image=image_torch, text_embedding=text_emb)
    # ------- Get heatmaps -------
    image_cv = cv2.cvtColor(np.array(image.resize((image_size, image_size))), cv2.COLOR_RGB2BGR)

    viz_legrad = logits_to_heatmaps(logits=logits_legrad, image_cv=image_cv)
    return viz_legrad
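
# Programmatic usage (sketch, bypassing the UI; uses one of the bundled assets):
#   img = Image.open('assets/cats_remote_control.jpeg').convert('RGB')
#   heatmap = viz_func(url=None, image=img, text_query='cat')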


with gr.Blocks(css="#gradio-app-title { text-align: center; }") as demo:
    gr.Markdown(
        """
        # **LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity**
        ### This demo showcases the LeGrad method: it highlights the regions of an image that are most relevant to a given text query.
        The model used is OpenCLIP ViT-B-16 (weights: `laion2b_s34b_b88k`).
        """
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown('# Select An Image')
            selected_image = gr.Image(type="pil", interactive=True, label='')
            gr.Markdown('## Paste the URL of the selected image')
            url_query = gr.Textbox(label="")
            gr.Markdown('# Create Your Own Query')
            text_query = gr.Textbox(label='')
            run_button = gr.Button(icon='https://cdn-icons-png.flaticon.com/512/3348/3348036.png')

            url_query.change(fn=change_to_url, inputs=url_query, outputs=selected_image)
            gr.Markdown('## LeGrad Explanation')
            le_grad_output = gr.Image(label='LeGrad')

            run_button.click(fn=viz_func,
                inputs=[url_query, selected_image, text_query],
                outputs=[le_grad_output])

        with gr.Column():
            gr.Markdown('# Select a Premade Example')
            gr.Examples(
                examples=[
                    ["assets/cats_remote_control.jpeg", "cat"],
                    ["assets/cats_remote_control.jpeg", "remote control"],
                    ["assets/la_baguette.webp", "la baguette"],
                    ["assets/la_baguette.webp", "beret"],
                    ["assets/pokemons.jpeg", "Pikachu"],
                    ["assets/pokemons.jpeg", "Bulbasaur"],
                    ["assets/pokemons.jpeg", "Charmander"],
                    ["assets/pokemons.jpeg", "Pokemons"],
                ],
                inputs=[selected_image, text_query],
                outputs=[le_grad_output],
                fn=viz_func,
                label=''
            )

demo.queue()
demo.launch()
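
# To run locally (assumes the legrad package and the assets/ folder are available):
#   python app.py
# Gradio serves the app at http://127.0.0.1:7860 by default.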