WalidBouss committed
Commit 93a1776 · verified · 1 Parent(s): bd005d0

Initial commit :tada:

Files changed (1)
app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import requests
+ import numpy as np
+ import cv2
+ from PIL import Image
+
+ import torch
+ import torch.nn.functional as F
+ import open_clip
+
+ import gradio as gr
+
+ from legrad import LeWrapper, LePreprocess
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # falls back to CPU when no GPU is available
+ layer_index = -2  # unused in this demo
+ image_size = 448
+ # ---------- Init CLIP Model ----------
+ model_name = 'ViT-B-16'
+ pretrained = 'laion2b_s34b_b88k'
+ patch_size = 16  # unused in this demo
+
+ model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained, device=device)
+ tokenizer = open_clip.get_tokenizer(model_name)
+
+ # ---------- Apply LeGrad's wrappers ----------
+ model = LeWrapper(model)  # adds the compute_legrad() method used below
+ preprocess = LePreprocess(preprocess=preprocess, image_size=image_size)  # resizes inputs to image_size
+
+
+ # ---------- Function to load image from URL ----------
+ def change_to_url(url):
+     img_pil = Image.open(requests.get(url, stream=True).raw).convert('RGB')
+     return img_pil
+
+
+ # ---------- Function to embed the text query ----------
+ def _get_text_embedding(model, tokenizer, classes: list, device):
+     prompts = [f'a photo of a {cls}.' for cls in classes]
+     tokenized_prompts = tokenizer(prompts).to(device)
+     text_embedding = model.encode_text(tokenized_prompts)
+     text_embedding = F.normalize(text_embedding, dim=-1)
+     return text_embedding.unsqueeze(0)
+
+ # ---------- Function to convert logits to heatmaps ----------
+ def logits_to_heatmaps(logits, image_cv):
+     logits = logits[0, 0].detach().cpu().numpy()
+     logits = (logits * 255).astype('uint8')
+     heat_map = cv2.applyColorMap(logits, cv2.COLORMAP_JET)
+     viz = 0.4 * image_cv + 0.6 * heat_map  # blend the heatmap over the image
+     viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
+     return viz
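+ # A minimal sanity check for logits_to_heatmaps (assumption: compute_legrad returns a
+ # (1, 1, H, W) map with values in [0, 1], as the [0, 0] indexing and *255 scaling suggest):
+ #     dummy_logits = torch.rand(1, 1, image_size, image_size)
+ #     dummy_image = np.zeros((image_size, image_size, 3), dtype=np.uint8)
+ #     logits_to_heatmaps(dummy_logits, dummy_image)  # -> (H, W, 3) uint8 RGB overlay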
+
+
+ # ---------- Main visualization function ----------
+ def viz_func(url, image, text_query):  # url is unused here; it only feeds the image loader below
+     image_torch = preprocess(image).unsqueeze(0).to(device)
+     text_emb = _get_text_embedding(model, tokenizer, classes=[text_query], device=device)
+
+     # ------- Get LeGrad output -------
+     logits_legrad = model.compute_legrad(image=image_torch, text_embedding=text_emb)
+     # ------- Get heatmaps -------
+     image_cv = cv2.cvtColor(np.array(image.resize((image_size, image_size))), cv2.COLOR_RGB2BGR)
+
+     viz_legrad = logits_to_heatmaps(logits=logits_legrad, image_cv=image_cv)
+     return viz_legrad
+
+
+
+ with gr.Blocks(css="#gradio-app-title { text-align: center; }") as demo:
+     gr.Markdown(
+         """
+         # **LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity**
+         ### This demo showcases the LeGrad method by visualizing the regions of an image that matter most for a given text query.
+         The model used is OpenCLIP-ViT-B-16 (weights: `laion2b_s34b_b88k`).
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown('# Select An Image')
+             selected_image = gr.Image(type="pil", interactive=True, label='')
+             gr.Markdown('## Paste the URL of the image')
+             url_query = gr.Textbox(label="")
+             gr.Markdown('# Create your own query')
+             text_query = gr.Textbox(label='')
+             run_button = gr.Button(icon='https://cdn-icons-png.flaticon.com/512/3348/3348036.png')
+
+             # Load the pasted URL into the image component whenever the textbox changes
+             url_query.change(fn=change_to_url, inputs=url_query, outputs=selected_image)
+             gr.Markdown('## LeGrad Explanation')
+             le_grad_output = gr.Image(label='LeGrad')
+
+             run_button.click(fn=viz_func,
+                              inputs=[url_query, selected_image, text_query],
+                              outputs=[le_grad_output])
+
+         with gr.Column():
+             gr.Markdown('# Select a Premade Example')
+             gr.Examples(
+                 examples=[
+                     ["gradio_app/assets/cats_remote_control.jpeg", "cat"],
+                     ["gradio_app/assets/cats_remote_control.jpeg", "remote control"],
+                     ["gradio_app/assets/la_baguette.webp", "la baguette"],
+                     ["gradio_app/assets/la_baguette.webp", "beret"],
+                     ["gradio_app/assets/pokemons.jpeg", "Pikachu"],
+                     ["gradio_app/assets/pokemons.jpeg", "Bulbasaur"],
+                     ["gradio_app/assets/pokemons.jpeg", "Charmander"],
+                     ["gradio_app/assets/pokemons.jpeg", "Pokemons"],
+                 ],
+                 inputs=[selected_image, text_query],
+                 label=''
+             )
+
+ demo.queue()
+ demo.launch()
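
For a quick check without launching the UI, the pieces above can be called directly. A minimal sketch, assuming `legrad` and its dependencies are installed; the URL and query are placeholders, and `viz_func` ignores its first argument:

    img = change_to_url("https://example.com/cat.jpg")  # placeholder URL
    overlay = viz_func(None, img, "cat")                # uint8 RGB overlay
    Image.fromarray(overlay).save("legrad_cat.png")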