wenkai commited on
Commit
846a3aa
·
verified ·
1 Parent(s): c338275

Upload 12 files

Browse files
app/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from PIL import Image
9
+ import requests
10
+
11
+ import streamlit as st
12
+ import torch
13
+
14
+
15
+ @st.cache()
16
+ def load_demo_image():
17
+ img_url = (
18
+ "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
19
+ )
20
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
21
+ return raw_image
22
+
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ cache_root = "/export/home/.cache/lavis/"
app/calculate_coco_features.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from PIL import Image
9
+ import requests
10
+ import torch
11
+
12
+ import os
13
+
14
+ from lavis.common.registry import registry
15
+ from lavis.processors import *
16
+ from lavis.models import *
17
+ from lavis.common.utils import build_default_model
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+
22
+ def load_demo_image():
23
+ img_url = (
24
+ "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
25
+ )
26
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
27
+
28
+ return raw_image
29
+
30
+
31
+ def read_img(filepath):
32
+ raw_image = Image.open(filepath).convert("RGB")
33
+
34
+ return raw_image
35
+
36
+
37
+ # model
38
+ model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
39
+ feature_extractor = BlipFeatureExtractor(pretrained=model_url)
40
+
41
+ feature_extractor.eval()
42
+ feature_extractor = feature_extractor.to(device)
43
+
44
+ # preprocessors
45
+ vis_processor = BlipImageEvalProcessor(image_size=224)
46
+ text_processor = BlipCaptionProcessor()
47
+
48
+ # files to process
49
+ # file_root = "/export/home/.cache/lavis/coco/images/val2014"
50
+ file_root = "/export/home/.cache/lavis/coco/images/train2014"
51
+ filepaths = os.listdir(file_root)
52
+
53
+ print(len(filepaths))
54
+
55
+ caption = "dummy"
56
+
57
+ path2feat = dict()
58
+ bsz = 256
59
+
60
+ images_in_batch = []
61
+ filepaths_in_batch = []
62
+
63
+ for i, filename in enumerate(filepaths):
64
+ if i % bsz == 0 and i > 0:
65
+ images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
66
+ with torch.no_grad():
67
+ image_features = feature_extractor(
68
+ images_in_batch, caption, mode="image", normalized=True
69
+ )[:, 0]
70
+
71
+ for filepath, image_feat in zip(filepaths_in_batch, image_features):
72
+ path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
73
+
74
+ images_in_batch = []
75
+ filepaths_in_batch = []
76
+
77
+ print(len(path2feat), image_features.shape)
78
+ else:
79
+ filepath = os.path.join(file_root, filename)
80
+
81
+ image = read_img(filepath)
82
+ image = vis_processor(image).unsqueeze(0)
83
+
84
+ images_in_batch.append(image)
85
+ filepaths_in_batch.append(filepath)
86
+
87
+ torch.save(path2feat, "path2feat_coco_train2014.pth")
app/caption.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import streamlit as st
9
+ from app import device, load_demo_image
10
+ from app.utils import load_model_cache
11
+ from lavis.processors import load_processor
12
+ from PIL import Image
13
+
14
+
15
+ def app():
16
+ # ===== layout =====
17
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
18
+
19
+ sampling_method = st.sidebar.selectbox(
20
+ "Sampling method:", ["Beam search", "Nucleus sampling"]
21
+ )
22
+
23
+ st.markdown(
24
+ "<h1 style='text-align: center;'>Image Description Generation</h1>",
25
+ unsafe_allow_html=True,
26
+ )
27
+
28
+ instructions = """Try the provided image or upload your own:"""
29
+ file = st.file_uploader(instructions)
30
+
31
+ use_beam = sampling_method == "Beam search"
32
+
33
+ col1, col2 = st.columns(2)
34
+
35
+ if file:
36
+ raw_img = Image.open(file).convert("RGB")
37
+ else:
38
+ raw_img = load_demo_image()
39
+
40
+ col1.header("Image")
41
+
42
+ w, h = raw_img.size
43
+ scaling_factor = 720 / w
44
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
45
+
46
+ col1.image(resized_image, use_column_width=True)
47
+ col2.header("Description")
48
+
49
+ cap_button = st.button("Generate")
50
+
51
+ # ==== event ====
52
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
53
+
54
+ if cap_button:
55
+ if model_type.startswith("BLIP"):
56
+ blip_type = model_type.split("_")[1].lower()
57
+ model = load_model_cache(
58
+ "blip_caption",
59
+ model_type=f"{blip_type}_coco",
60
+ is_eval=True,
61
+ device=device,
62
+ )
63
+
64
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
65
+ captions = generate_caption(
66
+ model=model, image=img, use_nucleus_sampling=not use_beam
67
+ )
68
+
69
+ col2.write("\n\n".join(captions), use_column_width=True)
70
+
71
+
72
+ def generate_caption(
73
+ model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
74
+ ):
75
+ samples = {"image": image}
76
+
77
+ captions = []
78
+ if use_nucleus_sampling:
79
+ for _ in range(5):
80
+ caption = model.generate(
81
+ samples,
82
+ use_nucleus_sampling=True,
83
+ max_length=max_length,
84
+ min_length=min_length,
85
+ top_p=0.9,
86
+ )
87
+ captions.append(caption[0])
88
+ else:
89
+ caption = model.generate(
90
+ samples,
91
+ use_nucleus_sampling=False,
92
+ num_beams=num_beams,
93
+ max_length=max_length,
94
+ min_length=min_length,
95
+ )
96
+ captions.append(caption[0])
97
+
98
+ return captions
app/classification.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import plotly.graph_objects as go
9
+ import requests
10
+ import streamlit as st
11
+ import torch
12
+ from lavis.models import load_model
13
+ from lavis.processors import load_processor
14
+ from lavis.processors.blip_processors import BlipCaptionProcessor
15
+ from PIL import Image
16
+
17
+ from app import device, load_demo_image
18
+ from app.utils import load_blip_itm_model
19
+ from lavis.processors.clip_processors import ClipImageEvalProcessor
20
+
21
+
22
+ @st.cache()
23
+ def load_demo_image(img_url=None):
24
+ if not img_url:
25
+ img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg"
26
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
27
+ return raw_image
28
+
29
+
30
+ @st.cache(
31
+ hash_funcs={
32
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
33
+ .cpu()
34
+ .numpy()
35
+ },
36
+ allow_output_mutation=True,
37
+ )
38
+ def load_model_cache(model_type, device):
39
+ if model_type == "blip":
40
+ model = load_model(
41
+ "blip_feature_extractor", model_type="base", is_eval=True, device=device
42
+ )
43
+ elif model_type == "albef":
44
+ model = load_model(
45
+ "albef_feature_extractor", model_type="base", is_eval=True, device=device
46
+ )
47
+ elif model_type == "CLIP_ViT-B-32":
48
+ model = load_model(
49
+ "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
50
+ )
51
+ elif model_type == "CLIP_ViT-B-16":
52
+ model = load_model(
53
+ "clip_feature_extractor", "ViT-B-16", is_eval=True, device=device
54
+ )
55
+ elif model_type == "CLIP_ViT-L-14":
56
+ model = load_model(
57
+ "clip_feature_extractor", "ViT-L-14", is_eval=True, device=device
58
+ )
59
+
60
+ return model
61
+
62
+
63
+ def app():
64
+ model_type = st.sidebar.selectbox(
65
+ "Model:",
66
+ ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
67
+ )
68
+ score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])
69
+
70
+ # ===== layout =====
71
+ st.markdown(
72
+ "<h1 style='text-align: center;'>Zero-shot Classification</h1>",
73
+ unsafe_allow_html=True,
74
+ )
75
+
76
+ instructions = """Try the provided image or upload your own:"""
77
+ file = st.file_uploader(instructions)
78
+
79
+ st.header("Image")
80
+ if file:
81
+ raw_img = Image.open(file).convert("RGB")
82
+ else:
83
+ raw_img = load_demo_image()
84
+
85
+ st.image(raw_img) # , use_column_width=True)
86
+
87
+ col1, col2 = st.columns(2)
88
+
89
+ col1.header("Categories")
90
+
91
+ cls_0 = col1.text_input("category 1", value="merlion")
92
+ cls_1 = col1.text_input("category 2", value="sky")
93
+ cls_2 = col1.text_input("category 3", value="giraffe")
94
+ cls_3 = col1.text_input("category 4", value="fountain")
95
+ cls_4 = col1.text_input("category 5", value="marina bay")
96
+
97
+ cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
98
+ cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]
99
+
100
+ if len(cls_names) != len(set(cls_names)):
101
+ st.error("Please provide unique class names")
102
+ return
103
+
104
+ button = st.button("Submit")
105
+
106
+ col2.header("Prediction")
107
+
108
+ # ===== event =====
109
+
110
+ if button:
111
+ if model_type.startswith("BLIP"):
112
+ text_processor = BlipCaptionProcessor(prompt="A picture of ")
113
+ cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
114
+
115
+ if score_type == "Cosine":
116
+ vis_processor = load_processor("blip_image_eval").build(image_size=224)
117
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
118
+
119
+ feature_extractor = load_model_cache(model_type="blip", device=device)
120
+
121
+ sample = {"image": img, "text_input": cls_prompt}
122
+
123
+ with torch.no_grad():
124
+ image_features = feature_extractor.extract_features(
125
+ sample, mode="image"
126
+ ).image_embeds_proj[:, 0]
127
+ text_features = feature_extractor.extract_features(
128
+ sample, mode="text"
129
+ ).text_embeds_proj[:, 0]
130
+ sims = (image_features @ text_features.t())[
131
+ 0
132
+ ] / feature_extractor.temp
133
+
134
+ else:
135
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
136
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
137
+
138
+ model = load_blip_itm_model(device)
139
+
140
+ output = model(img, cls_prompt, match_head="itm")
141
+ sims = output[:, 1]
142
+
143
+ sims = torch.nn.Softmax(dim=0)(sims)
144
+ inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
145
+
146
+ elif model_type.startswith("ALBEF"):
147
+ vis_processor = load_processor("blip_image_eval").build(image_size=224)
148
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
149
+
150
+ text_processor = BlipCaptionProcessor(prompt="A picture of ")
151
+ cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]
152
+
153
+ feature_extractor = load_model_cache(model_type="albef", device=device)
154
+
155
+ sample = {"image": img, "text_input": cls_prompt}
156
+
157
+ with torch.no_grad():
158
+ image_features = feature_extractor.extract_features(
159
+ sample, mode="image"
160
+ ).image_embeds_proj[:, 0]
161
+ text_features = feature_extractor.extract_features(
162
+ sample, mode="text"
163
+ ).text_embeds_proj[:, 0]
164
+
165
+ st.write(image_features.shape)
166
+ st.write(text_features.shape)
167
+
168
+ sims = (image_features @ text_features.t())[0] / feature_extractor.temp
169
+
170
+ sims = torch.nn.Softmax(dim=0)(sims)
171
+ inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]
172
+
173
+ elif model_type.startswith("CLIP"):
174
+ if model_type == "CLIP_ViT-B-32":
175
+ model = load_model_cache(model_type="CLIP_ViT-B-32", device=device)
176
+ elif model_type == "CLIP_ViT-B-16":
177
+ model = load_model_cache(model_type="CLIP_ViT-B-16", device=device)
178
+ elif model_type == "CLIP_ViT-L-14":
179
+ model = load_model_cache(model_type="CLIP_ViT-L-14", device=device)
180
+ else:
181
+ raise ValueError(f"Unknown model type {model_type}")
182
+
183
+ if score_type == "Cosine":
184
+ # image_preprocess = ClipImageEvalProcessor(image_size=336)
185
+ image_preprocess = ClipImageEvalProcessor(image_size=224)
186
+ img = image_preprocess(raw_img).unsqueeze(0).to(device)
187
+
188
+ sample = {"image": img, "text_input": cls_names}
189
+
190
+ with torch.no_grad():
191
+ clip_features = model.extract_features(sample)
192
+
193
+ image_features = clip_features.image_embeds_proj
194
+ text_features = clip_features.text_embeds_proj
195
+
196
+ sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
197
+ inv_sims = sims.tolist()[::-1]
198
+ else:
199
+ st.warning("CLIP does not support multimodal scoring.")
200
+ return
201
+
202
+ fig = go.Figure(
203
+ go.Bar(
204
+ x=inv_sims,
205
+ y=cls_names[::-1],
206
+ text=["{:.2f}".format(s) for s in inv_sims],
207
+ orientation="h",
208
+ )
209
+ )
210
+ fig.update_traces(
211
+ textfont_size=12,
212
+ textangle=0,
213
+ textposition="outside",
214
+ cliponaxis=False,
215
+ )
216
+ col2.plotly_chart(fig, use_container_width=True)
app/dataset_browser.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import random
9
+ from collections import OrderedDict
10
+ from functools import reduce
11
+ from tkinter import N
12
+
13
+ import streamlit as st
14
+ from lavis.common.registry import registry
15
+ from lavis.datasets.builders import dataset_zoo, load_dataset
16
+ from lavis.datasets.builders.base_dataset_builder import load_dataset_config
17
+ from PIL import Image
18
+
19
+ IMAGE_LAYOUT = 3, 4
20
+ VIDEO_LAYOUT = 1, 2
21
+
22
+ PREV_STR = "Prev"
23
+ NEXT_STR = "Next"
24
+
25
+
26
+ def sample_dataset(dataset, indices):
27
+ samples = [dataset.displ_item(idx) for idx in indices]
28
+
29
+ return samples
30
+
31
+
32
+ def get_concat_v(im1, im2):
33
+ margin = 5
34
+
35
+ canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height))
36
+ canvas = Image.new("RGB", canvas_size, "White")
37
+ canvas.paste(im1, (0, 0))
38
+ canvas.paste(im2, (im1.width + margin, 0))
39
+
40
+ return canvas
41
+
42
+
43
+ def resize_img_w(raw_img, new_w=224):
44
+ if isinstance(raw_img, list):
45
+ resized_imgs = [resize_img_w(img, 196) for img in raw_img]
46
+ # concatenate images
47
+ resized_image = reduce(get_concat_v, resized_imgs)
48
+ else:
49
+ w, h = raw_img.size
50
+ scaling_factor = new_w / w
51
+ resized_image = raw_img.resize(
52
+ (int(w * scaling_factor), int(h * scaling_factor))
53
+ )
54
+
55
+ return resized_image
56
+
57
+
58
+ def get_visual_key(dataset):
59
+ if "image" in dataset[0]:
60
+ return "image"
61
+ elif "image0" in dataset[0]: # NLVR2 dataset
62
+ return "image"
63
+ elif "video" in dataset[0]:
64
+ return "video"
65
+ else:
66
+ raise ValueError("Visual key not found.")
67
+
68
+
69
+ def gather_items(samples, exclude=[]):
70
+ gathered = []
71
+
72
+ for s in samples:
73
+ ns = OrderedDict()
74
+ for k in s.keys():
75
+ if k not in exclude:
76
+ ns[k] = s[k]
77
+
78
+ gathered.append(ns)
79
+
80
+ return gathered
81
+
82
+
83
+ @st.cache(allow_output_mutation=True)
84
+ def load_dataset_cache(name):
85
+ return load_dataset(name)
86
+
87
+
88
+ def format_text(text):
89
+ md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()])
90
+
91
+ return md
92
+
93
+
94
+ def show_samples(dataset, offset=0, is_next=False):
95
+ visual_key = get_visual_key(dataset)
96
+
97
+ num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
98
+ n_samples = num_rows * num_cols
99
+
100
+ if not shuffle:
101
+ if is_next:
102
+ start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples)
103
+ else:
104
+ start = max(0, int(start_idx) + offset - n_samples)
105
+
106
+ st.session_state.last_start = start
107
+ end = min(start + n_samples, len(dataset))
108
+
109
+ indices = list(range(start, end))
110
+ else:
111
+ indices = random.sample(range(len(dataset)), n_samples)
112
+ samples = sample_dataset(dataset, indices)
113
+
114
+ visual_info = (
115
+ iter([resize_img_w(s[visual_key]) for s in samples])
116
+ if visual_key == "image"
117
+ # else iter([s[visual_key] for s in samples])
118
+ else iter([s["file"] for s in samples])
119
+ )
120
+ text_info = gather_items(samples, exclude=["image", "video"])
121
+ text_info = iter([format_text(s) for s in text_info])
122
+
123
+ st.markdown(
124
+ """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
125
+ unsafe_allow_html=True,
126
+ )
127
+ for _ in range(num_rows):
128
+ with st.container():
129
+ for col in st.columns(num_cols):
130
+ # col.text(next(text_info))
131
+ # col.caption(next(text_info))
132
+ try:
133
+ col.markdown(next(text_info))
134
+ if visual_key == "image":
135
+ col.image(next(visual_info), use_column_width=True, clamp=True)
136
+ elif visual_key == "video":
137
+ col.markdown(
138
+ "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
139
+ )
140
+ except StopIteration:
141
+ break
142
+
143
+ st.markdown(
144
+ """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
145
+ unsafe_allow_html=True,
146
+ )
147
+
148
+ st.session_state.n_display = n_samples
149
+
150
+
151
+ if __name__ == "__main__":
152
+ st.set_page_config(
153
+ page_title="LAVIS Dataset Explorer",
154
+ # layout="wide",
155
+ initial_sidebar_state="expanded",
156
+ )
157
+
158
+ dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())
159
+
160
+ function = st.sidebar.selectbox("Function:", ["Browser"], index=0)
161
+
162
+ if function == "Browser":
163
+ shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)
164
+
165
+ dataset = load_dataset_cache(dataset_name)
166
+ split = st.sidebar.selectbox("Split:", dataset.keys())
167
+
168
+ dataset_len = len(dataset[split])
169
+ st.success(
170
+ f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
171
+ )
172
+
173
+ if "last_dataset" not in st.session_state:
174
+ st.session_state.last_dataset = dataset_name
175
+ st.session_state.last_split = split
176
+
177
+ if "last_start" not in st.session_state:
178
+ st.session_state.last_start = 0
179
+
180
+ if "start_idx" not in st.session_state:
181
+ st.session_state.start_idx = 0
182
+
183
+ if "shuffle" not in st.session_state:
184
+ st.session_state.shuffle = shuffle
185
+
186
+ if "first_run" not in st.session_state:
187
+ st.session_state.first_run = True
188
+ elif (
189
+ st.session_state.last_dataset != dataset_name
190
+ or st.session_state.last_split != split
191
+ ):
192
+ st.session_state.first_run = True
193
+
194
+ st.session_state.last_dataset = dataset_name
195
+ st.session_state.last_split = split
196
+ elif st.session_state.shuffle != shuffle:
197
+ st.session_state.shuffle = shuffle
198
+ st.session_state.first_run = True
199
+
200
+ if not shuffle:
201
+ n_col, p_col = st.columns([0.05, 1])
202
+
203
+ prev_button = n_col.button(PREV_STR)
204
+ next_button = p_col.button(NEXT_STR)
205
+
206
+ else:
207
+ next_button = st.button(NEXT_STR)
208
+
209
+ if not shuffle:
210
+ start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)
211
+
212
+ if not start_idx.isdigit():
213
+ st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
214
+ else:
215
+ if int(start_idx) != st.session_state.start_idx:
216
+ st.session_state.start_idx = int(start_idx)
217
+ st.session_state.last_start = int(start_idx)
218
+
219
+ if prev_button:
220
+ show_samples(
221
+ dataset[split],
222
+ offset=st.session_state.last_start - st.session_state.start_idx,
223
+ is_next=False,
224
+ )
225
+
226
+ if next_button:
227
+ show_samples(
228
+ dataset[split],
229
+ offset=st.session_state.last_start - st.session_state.start_idx,
230
+ is_next=True,
231
+ )
232
+
233
+ if st.session_state.first_run:
234
+ st.session_state.first_run = False
235
+
236
+ show_samples(
237
+ dataset[split],
238
+ offset=st.session_state.last_start - st.session_state.start_idx,
239
+ is_next=True,
240
+ )
app/image_text_match.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
12
+ from lavis.processors import load_processor
13
+ from PIL import Image
14
+
15
+ from app import device, load_demo_image
16
+ from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
17
+
18
+
19
+ def app():
20
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
21
+
22
+ if model_type.startswith("BLIP"):
23
+ blip_type = model_type.split("_")[1]
24
+ model = load_blip_itm_model(device, model_type=blip_type)
25
+
26
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
27
+
28
+ st.markdown(
29
+ "<h1 style='text-align: center;'>Image Text Matching</h1>",
30
+ unsafe_allow_html=True,
31
+ )
32
+
33
+ values = list(range(1, 12))
34
+ default_layer_num = values.index(7)
35
+ layer_num = (
36
+ st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
37
+ )
38
+
39
+ instructions = """Try the provided image or upload your own:"""
40
+ file = st.file_uploader(instructions)
41
+
42
+ col1, col2 = st.columns(2)
43
+ col1.header("Image")
44
+ col2.header("GradCam")
45
+ if file:
46
+ raw_img = Image.open(file).convert("RGB")
47
+ else:
48
+ raw_img = load_demo_image()
49
+
50
+ w, h = raw_img.size
51
+ scaling_factor = 720 / w
52
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
53
+ col1.image(resized_image, use_column_width=True)
54
+
55
+ col3, col4 = st.columns(2)
56
+ col3.header("Text")
57
+ user_question = col3.text_input(
58
+ "Input your sentence!", "a woman sitting on the beach with a dog"
59
+ )
60
+ submit_button = col3.button("Submit")
61
+
62
+ col4.header("Matching score")
63
+
64
+ if submit_button:
65
+ tokenizer = init_bert_tokenizer()
66
+
67
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
68
+ text_processor = load_processor("blip_caption").build()
69
+
70
+ qry = text_processor(user_question)
71
+
72
+ norm_img = np.float32(resized_image) / 255
73
+
74
+ qry_tok = tokenizer(qry, return_tensors="pt").to(device)
75
+ gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
76
+
77
+ avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
78
+
79
+ col2.image(avg_gradcam, use_column_width=True, clamp=True)
80
+ # output = model(img, question)
81
+ itm_score = torch.nn.functional.softmax(output, dim=1)
82
+ new_title = (
83
+ '<p style="text-align: left; font-size: 25px;">\n{:.3f}%</p>'.format(
84
+ itm_score[0][1].item() * 100
85
+ )
86
+ )
87
+ col4.markdown(new_title, unsafe_allow_html=True)
app/main.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from app.multipage import MultiPage
9
+ from app import vqa, caption
10
+ from app import image_text_match as itm
11
+ from app import text_localization as tl
12
+ from app import multimodal_search as ms
13
+ from app import classification as cl
14
+
15
+
16
+ if __name__ == "__main__":
17
+ app = MultiPage()
18
+
19
+ app.add_page("Image Description Generation", caption.app)
20
+ app.add_page("Multimodal Search", ms.app)
21
+ app.add_page("Visual Question Answering", vqa.app)
22
+ app.add_page("Image Text Matching", itm.app)
23
+ app.add_page("Text Localization", tl.app)
24
+ app.add_page("Classification", cl.app)
25
+ app.run()
app/multimodal_search.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from app import cache_root, device
15
+ from app.utils import (
16
+ getAttMap,
17
+ init_bert_tokenizer,
18
+ load_blip_itm_model,
19
+ read_img,
20
+ resize_img,
21
+ )
22
+ from lavis.models import load_model
23
+ from lavis.processors import load_processor
24
+
25
+
26
+ @st.cache(
27
+ hash_funcs={
28
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
29
+ .cpu()
30
+ .numpy()
31
+ },
32
+ allow_output_mutation=True,
33
+ )
34
+ def load_feat():
35
+ from lavis.common.utils import download_url
36
+
37
+ dirname = os.path.join(os.path.dirname(__file__), "assets")
38
+ filename = "path2feat_coco_train2014.pth"
39
+ filepath = os.path.join(dirname, filename)
40
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"
41
+
42
+ if not os.path.exists(filepath):
43
+ download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")
44
+
45
+ path2feat = torch.load(filepath)
46
+ paths = sorted(path2feat.keys())
47
+
48
+ all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
49
+
50
+ return path2feat, paths, all_img_feats
51
+
52
+
53
+ @st.cache(
54
+ hash_funcs={
55
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
56
+ .cpu()
57
+ .numpy()
58
+ },
59
+ allow_output_mutation=True,
60
+ )
61
+ def load_feature_extractor_model(device):
62
+ model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
63
+
64
+ model = load_model(
65
+ "blip_feature_extractor", model_type="base", is_eval=True, device=device
66
+ )
67
+ model.load_from_pretrained(model_url)
68
+
69
+ return model
70
+
71
+
72
+ def app():
73
+ # === layout ===
74
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
75
+ file_root = os.path.join(cache_root, "coco/images/train2014/")
76
+
77
+ values = [12, 24, 48]
78
+ default_layer_num = values.index(24)
79
+ num_display = st.sidebar.selectbox(
80
+ "Number of images:", values, index=default_layer_num
81
+ )
82
+ show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
83
+ itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)
84
+
85
+ # st.title('Multimodal Search')
86
+ st.markdown(
87
+ "<h1 style='text-align: center;'>Multimodal Search</h1>", unsafe_allow_html=True
88
+ )
89
+
90
+ # === event ===
91
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
92
+ text_processor = load_processor("blip_caption")
93
+
94
+ user_question = st.text_input(
95
+ "Search query", "A dog running on the grass.", help="Type something to search."
96
+ )
97
+ user_question = text_processor(user_question)
98
+ feature_extractor = load_feature_extractor_model(device)
99
+
100
+ # ======= ITC =========
101
+ sample = {"text_input": user_question}
102
+
103
+ with torch.no_grad():
104
+ text_feature = feature_extractor.extract_features(
105
+ sample, mode="text"
106
+ ).text_embeds_proj[0, 0]
107
+
108
+ path2feat, paths, all_img_feats = load_feat()
109
+ all_img_feats.to(device)
110
+ all_img_feats = F.normalize(all_img_feats, dim=1)
111
+
112
+ num_cols = 4
113
+ num_rows = int(num_display / num_cols)
114
+
115
+ similarities = text_feature @ all_img_feats.T
116
+ indices = torch.argsort(similarities, descending=True)[:num_display]
117
+
118
+ top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
119
+ sorted_similarities = [similarities[idx] for idx in indices]
120
+ filenames = [os.path.join(file_root, p) for p in top_paths]
121
+
122
+ # ========= ITM and GradCam ==========
123
+ bsz = 4 # max number of images to avoid cuda oom
124
+ if model_type.startswith("BLIP"):
125
+ blip_type = model_type.split("_")[1]
126
+
127
+ itm_model = load_blip_itm_model(device, model_type=blip_type)
128
+
129
+ tokenizer = init_bert_tokenizer()
130
+ queries_batch = [user_question] * bsz
131
+ queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)
132
+
133
+ num_batches = int(num_display / bsz)
134
+
135
+ avg_gradcams = []
136
+ all_raw_images = []
137
+ itm_scores = []
138
+
139
+ for i in range(num_batches):
140
+ filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
141
+ raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
142
+ gradcam, itm_output = compute_gradcam_batch(
143
+ itm_model, images, queries_batch, queries_tok_batch
144
+ )
145
+
146
+ all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
147
+ norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]
148
+
149
+ for norm_img, grad_cam in zip(norm_imgs, gradcam):
150
+ avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
151
+ avg_gradcams.append(avg_gradcam)
152
+
153
+ with torch.no_grad():
154
+ itm_score = torch.nn.functional.softmax(itm_output, dim=1)
155
+
156
+ itm_scores.append(itm_score)
157
+
158
+ # ========= ITM re-ranking =========
159
+ itm_scores = torch.cat(itm_scores)[:, 1]
160
+ if itm_ranking:
161
+ itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)
162
+
163
+ avg_gradcams_sorted = []
164
+ all_raw_images_sorted = []
165
+ for idx in indices:
166
+ avg_gradcams_sorted.append(avg_gradcams[idx])
167
+ all_raw_images_sorted.append(all_raw_images[idx])
168
+
169
+ avg_gradcams = avg_gradcams_sorted
170
+ all_raw_images = all_raw_images_sorted
171
+
172
+ if show_gradcam:
173
+ images_to_show = iter(avg_gradcams)
174
+ else:
175
+ images_to_show = iter(all_raw_images)
176
+
177
+ for _ in range(num_rows):
178
+ with st.container():
179
+ for col in st.columns(num_cols):
180
+ col.image(next(images_to_show), use_column_width=True, clamp=True)
181
+
182
+
183
+ def read_and_process_images(image_paths, vis_processor):
184
+ raw_images = [read_img(path) for path in image_paths]
185
+ images = [vis_processor(r_img) for r_img in raw_images]
186
+ images_tensors = torch.stack(images).to(device)
187
+
188
+ return raw_images, images_tensors
189
+
190
+
191
+ def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
192
+ model.text_encoder.base_model.base_model.encoder.layer[
193
+ block_num
194
+ ].crossattention.self.save_attention = True
195
+
196
+ output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
197
+ loss = output[:, 1].sum()
198
+
199
+ model.zero_grad()
200
+ loss.backward()
201
+ with torch.no_grad():
202
+ mask = tokenized_text.attention_mask.view(
203
+ tokenized_text.attention_mask.size(0), 1, -1, 1, 1
204
+ ) # (bsz,1,token_len, 1,1)
205
+ token_length = mask.sum() - 2
206
+ token_length = token_length.cpu()
207
+ # grads and cams [bsz, num_head, seq_len, image_patch]
208
+ grads = model.text_encoder.base_model.base_model.encoder.layer[
209
+ block_num
210
+ ].crossattention.self.get_attn_gradients()
211
+ cams = model.text_encoder.base_model.base_model.encoder.layer[
212
+ block_num
213
+ ].crossattention.self.get_attention_map()
214
+
215
+ # assume using vit large with 576 num image patch
216
+ cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
217
+ grads = (
218
+ grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
219
+ * mask
220
+ )
221
+
222
+ gradcam = cams * grads
223
+ # [enc token gradcam, average gradcam across token, gradcam for individual token]
224
+ # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
225
+ gradcam = gradcam.mean(1).cpu().detach()
226
+ gradcam = (
227
+ gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
228
+ )
229
+
230
+ return gradcam, output
app/multipage.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ """
9
+ This file is the framework for generating multiple Streamlit applications
10
+ through an object oriented framework.
11
+ """
12
+
13
+ # Import necessary libraries
14
+ import streamlit as st
15
+
16
+ # Define the multipage class to manage the multiple apps in our program
17
+ class MultiPage:
18
+ """Framework for combining multiple streamlit applications."""
19
+
20
+ def __init__(self) -> None:
21
+ """Constructor class to generate a list which will store all our applications as an instance variable."""
22
+ self.pages = []
23
+
24
+ def add_page(self, title, func) -> None:
25
+ """Class Method to Add pages to the project
26
+ Args:
27
+ title ([str]): The title of page which we are adding to the list of apps
28
+
29
+ func: Python function to render this page in Streamlit
30
+ """
31
+
32
+ self.pages.append({"title": title, "function": func})
33
+
34
+ def run(self):
35
+ # Drodown to select the page to run
36
+ page = st.sidebar.selectbox(
37
+ "Navigation", self.pages, format_func=lambda page: page["title"]
38
+ )
39
+
40
+ # run the app function
41
+ page["function"]()
app/text_localization.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
13
+ from lavis.processors import load_processor
14
+ from PIL import Image
15
+
16
+ from app import device, load_demo_image
17
+ from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model
18
+
19
+
20
+ def app():
21
+ model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
22
+
23
+ values = list(range(1, 12))
24
+ default_layer_num = values.index(7)
25
+ layer_num = (
26
+ st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
27
+ )
28
+
29
+ st.markdown(
30
+ "<h1 style='text-align: center;'>Text Localization</h1>", unsafe_allow_html=True
31
+ )
32
+
33
+ vis_processor = load_processor("blip_image_eval").build(image_size=384)
34
+ text_processor = load_processor("blip_caption")
35
+
36
+ tokenizer = init_bert_tokenizer()
37
+
38
+ instructions = "Try the provided image and text or use your own ones."
39
+ file = st.file_uploader(instructions)
40
+
41
+ query = st.text_input(
42
+ "Try a different input.", "A girl playing with her dog on the beach."
43
+ )
44
+
45
+ submit_button = st.button("Submit")
46
+
47
+ col1, col2 = st.columns(2)
48
+
49
+ if file:
50
+ raw_img = Image.open(file).convert("RGB")
51
+ else:
52
+ raw_img = load_demo_image()
53
+
54
+ col1.header("Image")
55
+ w, h = raw_img.size
56
+ scaling_factor = 720 / w
57
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
58
+ col1.image(resized_image, use_column_width=True)
59
+
60
+ col2.header("GradCam")
61
+
62
+ if submit_button:
63
+ if model_type.startswith("BLIP"):
64
+ blip_type = model_type.split("_")[1]
65
+ model = load_blip_itm_model(device, model_type=blip_type)
66
+
67
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
68
+ qry = text_processor(query)
69
+
70
+ qry_tok = tokenizer(qry, return_tensors="pt").to(device)
71
+
72
+ norm_img = np.float32(resized_image) / 255
73
+
74
+ gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)
75
+
76
+ avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
77
+ col2.image(avg_gradcam, use_column_width=True, clamp=True)
78
+
79
+ num_cols = 4.0
80
+ num_tokens = len(qry_tok.input_ids[0]) - 2
81
+
82
+ num_rows = int(math.ceil(num_tokens / num_cols))
83
+
84
+ gradcam_iter = iter(gradcam[0][2:-1])
85
+ token_id_iter = iter(qry_tok.input_ids[0][1:-1])
86
+
87
+ for _ in range(num_rows):
88
+ with st.container():
89
+ for col in st.columns(int(num_cols)):
90
+ token_id = next(token_id_iter, None)
91
+ if not token_id:
92
+ break
93
+ gradcam_img = next(gradcam_iter)
94
+
95
+ word = tokenizer.decode([token_id])
96
+ gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True)
97
+
98
+ new_title = (
99
+ '<p style="text-align: center; font-size: 25px;">{}</p>'.format(
100
+ word
101
+ )
102
+ )
103
+ col.markdown(new_title, unsafe_allow_html=True)
104
+ # st.image(image, channels="BGR")
105
+ col.image(gradcam_todraw, use_column_width=True, clamp=True)
app/utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import numpy as np
9
+ import streamlit as st
10
+ import torch
11
+ from lavis.models import BlipBase, load_model
12
+ from matplotlib import pyplot as plt
13
+ from PIL import Image
14
+ from scipy.ndimage import filters
15
+ from skimage import transform as skimage_transform
16
+
17
+
18
+ def resize_img(raw_img):
19
+ w, h = raw_img.size
20
+ scaling_factor = 240 / w
21
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
22
+ return resized_image
23
+
24
+
25
+ def read_img(filepath):
26
+ raw_image = Image.open(filepath).convert("RGB")
27
+
28
+ return raw_image
29
+
30
+
31
+ @st.cache(
32
+ hash_funcs={
33
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
34
+ .cpu()
35
+ .numpy()
36
+ },
37
+ allow_output_mutation=True,
38
+ )
39
+ def load_model_cache(name, model_type, is_eval, device):
40
+ return load_model(name, model_type, is_eval, device)
41
+
42
+
43
+ @st.cache(allow_output_mutation=True)
44
+ def init_bert_tokenizer():
45
+ tokenizer = BlipBase.init_tokenizer()
46
+ return tokenizer
47
+
48
+
49
+ def getAttMap(img, attMap, blur=True, overlap=True):
50
+ attMap -= attMap.min()
51
+ if attMap.max() > 0:
52
+ attMap /= attMap.max()
53
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
54
+ if blur:
55
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
56
+ attMap -= attMap.min()
57
+ attMap /= attMap.max()
58
+ cmap = plt.get_cmap("jet")
59
+ attMapV = cmap(attMap)
60
+ attMapV = np.delete(attMapV, 3, 2)
61
+ if overlap:
62
+ attMap = (
63
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
64
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
65
+ )
66
+ return attMap
67
+
68
+
69
+ @st.cache(
70
+ hash_funcs={
71
+ torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
72
+ .cpu()
73
+ .numpy()
74
+ },
75
+ allow_output_mutation=True,
76
+ )
77
+ def load_blip_itm_model(device, model_type="base"):
78
+ model = load_model(
79
+ "blip_image_text_matching", model_type, is_eval=True, device=device
80
+ )
81
+ return model
app/vqa.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ # Copyright (c) 2022, salesforce.com, inc.
3
+ # All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import streamlit as st
9
+ from app import load_demo_image, device
10
+ from app.utils import load_model_cache
11
+ from lavis.processors import load_processor
12
+ from PIL import Image
13
+
14
+
15
+ def app():
16
+ model_type = st.sidebar.selectbox("Model:", ["BLIP"])
17
+
18
+ # ===== layout =====
19
+ st.markdown(
20
+ "<h1 style='text-align: center;'>Visual Question Answering</h1>",
21
+ unsafe_allow_html=True,
22
+ )
23
+
24
+ instructions = """Try the provided image or upload your own:"""
25
+ file = st.file_uploader(instructions)
26
+
27
+ col1, col2 = st.columns(2)
28
+
29
+ col1.header("Image")
30
+ if file:
31
+ raw_img = Image.open(file).convert("RGB")
32
+ else:
33
+ raw_img = load_demo_image()
34
+
35
+ w, h = raw_img.size
36
+ scaling_factor = 720 / w
37
+ resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
38
+
39
+ col1.image(resized_image, use_column_width=True)
40
+ col2.header("Question")
41
+
42
+ user_question = col2.text_input("Input your question!", "What are objects there?")
43
+ qa_button = st.button("Submit")
44
+
45
+ col2.header("Answer")
46
+
47
+ # ===== event =====
48
+ vis_processor = load_processor("blip_image_eval").build(image_size=480)
49
+ text_processor = load_processor("blip_question").build()
50
+
51
+ if qa_button:
52
+ if model_type.startswith("BLIP"):
53
+ model = load_model_cache(
54
+ "blip_vqa", model_type="vqav2", is_eval=True, device=device
55
+ )
56
+
57
+ img = vis_processor(raw_img).unsqueeze(0).to(device)
58
+ question = text_processor(user_question)
59
+
60
+ vqa_samples = {"image": img, "text_input": [question]}
61
+ answers = model.predict_answers(vqa_samples, inference_method="generate")
62
+
63
+ col2.write("\n".join(answers), use_column_width=True)