Upload 12 files
- app/__init__.py +26 -0
- app/calculate_coco_features.py +87 -0
- app/caption.py +98 -0
- app/classification.py +216 -0
- app/dataset_browser.py +240 -0
- app/image_text_match.py +87 -0
- app/main.py +25 -0
- app/multimodal_search.py +230 -0
- app/multipage.py +41 -0
- app/text_localization.py +105 -0
- app/utils.py +81 -0
- app/vqa.py +63 -0
app/__init__.py
ADDED
@@ -0,0 +1,26 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from PIL import Image
import requests

import streamlit as st
import torch


@st.cache()
def load_demo_image():
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cache_root = "/export/home/.cache/lavis/"

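The package init exposes the shared `device`, `cache_root`, and the cached `load_demo_image()` helper that every page module imports. A minimal usage sketch, assuming the `lavis` package and this `app` package are importable:

# Minimal sketch: the page modules below import these shared objects.
from app import cache_root, device, load_demo_image

raw_image = load_demo_image()          # demo PIL image fetched over HTTP
print(device, cache_root, raw_image.size)
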
app/calculate_coco_features.py
ADDED
@@ -0,0 +1,87 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from PIL import Image
import requests
import torch

import os

from lavis.common.registry import registry
from lavis.processors import *
from lavis.models import *
from lavis.common.utils import build_default_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_demo_image():
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    return raw_image


def read_img(filepath):
    raw_image = Image.open(filepath).convert("RGB")

    return raw_image


# model
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
feature_extractor = BlipFeatureExtractor(pretrained=model_url)

feature_extractor.eval()
feature_extractor = feature_extractor.to(device)

# preprocessors
vis_processor = BlipImageEvalProcessor(image_size=224)
text_processor = BlipCaptionProcessor()

# files to process
# file_root = "/export/home/.cache/lavis/coco/images/val2014"
file_root = "/export/home/.cache/lavis/coco/images/train2014"
filepaths = os.listdir(file_root)

print(len(filepaths))

caption = "dummy"

path2feat = dict()
bsz = 256

images_in_batch = []
filepaths_in_batch = []

for i, filename in enumerate(filepaths):
    if i % bsz == 0 and i > 0:
        images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
        with torch.no_grad():
            image_features = feature_extractor(
                images_in_batch, caption, mode="image", normalized=True
            )[:, 0]

        for filepath, image_feat in zip(filepaths_in_batch, image_features):
            path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()

        images_in_batch = []
        filepaths_in_batch = []

        print(len(path2feat), image_features.shape)
    else:
        filepath = os.path.join(file_root, filename)

        image = read_img(filepath)
        image = vis_processor(image).unsqueeze(0)

        images_in_batch.append(image)
        filepaths_in_batch.append(filepath)

torch.save(path2feat, "path2feat_coco_train2014.pth")

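This script precomputes normalized BLIP image features for COCO train2014 and saves them as a filename-to-feature dictionary, which the multimodal search page later downloads and loads. A hedged sketch of how the saved artifact might be consumed, assuming the .pth file produced above is present in the working directory (the text-feature ranking lines are indicative only and mirror what app/multimodal_search.py does):

import torch
import torch.nn.functional as F

# Reload the dict saved by the script above: {COCO filename: 1-D feature tensor}.
path2feat = torch.load("path2feat_coco_train2014.pth", map_location="cpu")
paths = sorted(path2feat.keys())
img_feats = F.normalize(torch.stack([path2feat[p] for p in paths], dim=0), dim=1)
print(len(paths), img_feats.shape)

# A normalized BLIP text feature of the same width could then rank images:
#   sims = text_feat @ img_feats.T
#   top12 = [paths[i] for i in torch.argsort(sims, descending=True)[:12]]
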
app/caption.py
ADDED
@@ -0,0 +1,98 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import streamlit as st
from app import device, load_demo_image
from app.utils import load_model_cache
from lavis.processors import load_processor
from PIL import Image


def app():
    # ===== layout =====
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])

    sampling_method = st.sidebar.selectbox(
        "Sampling method:", ["Beam search", "Nucleus sampling"]
    )

    st.markdown(
        "<h1 style='text-align: center;'>Image Description Generation</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    use_beam = sampling_method == "Beam search"

    col1, col2 = st.columns(2)

    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    col1.header("Image")

    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))

    col1.image(resized_image, use_column_width=True)
    col2.header("Description")

    cap_button = st.button("Generate")

    # ==== event ====
    vis_processor = load_processor("blip_image_eval").build(image_size=384)

    if cap_button:
        if model_type.startswith("BLIP"):
            blip_type = model_type.split("_")[1].lower()
            model = load_model_cache(
                "blip_caption",
                model_type=f"{blip_type}_coco",
                is_eval=True,
                device=device,
            )

        img = vis_processor(raw_img).unsqueeze(0).to(device)
        captions = generate_caption(
            model=model, image=img, use_nucleus_sampling=not use_beam
        )

        col2.write("\n\n".join(captions), use_column_width=True)


def generate_caption(
    model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
):
    samples = {"image": image}

    captions = []
    if use_nucleus_sampling:
        for _ in range(5):
            caption = model.generate(
                samples,
                use_nucleus_sampling=True,
                max_length=max_length,
                min_length=min_length,
                top_p=0.9,
            )
            captions.append(caption[0])
    else:
        caption = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
        )
        captions.append(caption[0])

    return captions

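`generate_caption` wraps `model.generate`: beam search returns a single caption, while nucleus sampling draws five. A sketch of calling it outside the Streamlit layout, assuming the `base_coco` BLIP captioning checkpoint is available through `load_model` as it is to the page above:

# Sketch: caption the demo image programmatically, reusing the helpers above.
from app import device, load_demo_image
from app.caption import generate_caption
from lavis.models import load_model
from lavis.processors import load_processor

model = load_model("blip_caption", model_type="base_coco", is_eval=True, device=device)
vis_processor = load_processor("blip_image_eval").build(image_size=384)

image = vis_processor(load_demo_image()).unsqueeze(0).to(device)
print(generate_caption(model=model, image=image))   # beam search by default
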
app/classification.py
ADDED
@@ -0,0 +1,216 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import plotly.graph_objects as go
import requests
import streamlit as st
import torch
from lavis.models import load_model
from lavis.processors import load_processor
from lavis.processors.blip_processors import BlipCaptionProcessor
from PIL import Image

from app import device, load_demo_image
from app.utils import load_blip_itm_model
from lavis.processors.clip_processors import ClipImageEvalProcessor


@st.cache()
def load_demo_image(img_url=None):
    if not img_url:
        img_url = "https://img.atlasobscura.com/yDJ86L8Ou6aIjBsxnlAy5f164w1rjTgcHZcx2yUs4mo/rt:fit/w:1200/q:81/sm:1/scp:1/ar:1/aHR0cHM6Ly9hdGxh/cy1kZXYuczMuYW1h/em9uYXdzLmNvbS91/cGxvYWRzL3BsYWNl/X2ltYWdlcy85MDll/MDRjOS00NTJjLTQx/NzQtYTY4MS02NmQw/MzI2YWIzNjk1ZGVk/MGZhMTJiMTM5MmZi/NGFfUmVhcl92aWV3/X29mX3RoZV9NZXJs/aW9uX3N0YXR1ZV9h/dF9NZXJsaW9uX1Bh/cmssX1NpbmdhcG9y/ZSxfd2l0aF9NYXJp/bmFfQmF5X1NhbmRz/X2luX3RoZV9kaXN0/YW5jZV8tXzIwMTQw/MzA3LmpwZw.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_model_cache(model_type, device):
    if model_type == "blip":
        model = load_model(
            "blip_feature_extractor", model_type="base", is_eval=True, device=device
        )
    elif model_type == "albef":
        model = load_model(
            "albef_feature_extractor", model_type="base", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-B-32":
        model = load_model(
            "clip_feature_extractor", "ViT-B-32", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-B-16":
        model = load_model(
            "clip_feature_extractor", "ViT-B-16", is_eval=True, device=device
        )
    elif model_type == "CLIP_ViT-L-14":
        model = load_model(
            "clip_feature_extractor", "ViT-L-14", is_eval=True, device=device
        )

    return model


def app():
    model_type = st.sidebar.selectbox(
        "Model:",
        ["ALBEF", "BLIP_Base", "CLIP_ViT-B-32", "CLIP_ViT-B-16", "CLIP_ViT-L-14"],
    )
    score_type = st.sidebar.selectbox("Score type:", ["Cosine", "Multimodal"])

    # ===== layout =====
    st.markdown(
        "<h1 style='text-align: center;'>Zero-shot Classification</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    st.header("Image")
    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    st.image(raw_img)  # , use_column_width=True)

    col1, col2 = st.columns(2)

    col1.header("Categories")

    cls_0 = col1.text_input("category 1", value="merlion")
    cls_1 = col1.text_input("category 2", value="sky")
    cls_2 = col1.text_input("category 3", value="giraffe")
    cls_3 = col1.text_input("category 4", value="fountain")
    cls_4 = col1.text_input("category 5", value="marina bay")

    cls_names = [cls_0, cls_1, cls_2, cls_3, cls_4]
    cls_names = [cls_nm for cls_nm in cls_names if len(cls_nm) > 0]

    if len(cls_names) != len(set(cls_names)):
        st.error("Please provide unique class names")
        return

    button = st.button("Submit")

    col2.header("Prediction")

    # ===== event =====

    if button:
        if model_type.startswith("BLIP"):
            text_processor = BlipCaptionProcessor(prompt="A picture of ")
            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]

            if score_type == "Cosine":
                vis_processor = load_processor("blip_image_eval").build(image_size=224)
                img = vis_processor(raw_img).unsqueeze(0).to(device)

                feature_extractor = load_model_cache(model_type="blip", device=device)

                sample = {"image": img, "text_input": cls_prompt}

                with torch.no_grad():
                    image_features = feature_extractor.extract_features(
                        sample, mode="image"
                    ).image_embeds_proj[:, 0]
                    text_features = feature_extractor.extract_features(
                        sample, mode="text"
                    ).text_embeds_proj[:, 0]
                    sims = (image_features @ text_features.t())[
                        0
                    ] / feature_extractor.temp

            else:
                vis_processor = load_processor("blip_image_eval").build(image_size=384)
                img = vis_processor(raw_img).unsqueeze(0).to(device)

                model = load_blip_itm_model(device)

                output = model(img, cls_prompt, match_head="itm")
                sims = output[:, 1]

            sims = torch.nn.Softmax(dim=0)(sims)
            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]

        elif model_type.startswith("ALBEF"):
            vis_processor = load_processor("blip_image_eval").build(image_size=224)
            img = vis_processor(raw_img).unsqueeze(0).to(device)

            text_processor = BlipCaptionProcessor(prompt="A picture of ")
            cls_prompt = [text_processor(cls_nm) for cls_nm in cls_names]

            feature_extractor = load_model_cache(model_type="albef", device=device)

            sample = {"image": img, "text_input": cls_prompt}

            with torch.no_grad():
                image_features = feature_extractor.extract_features(
                    sample, mode="image"
                ).image_embeds_proj[:, 0]
                text_features = feature_extractor.extract_features(
                    sample, mode="text"
                ).text_embeds_proj[:, 0]

                st.write(image_features.shape)
                st.write(text_features.shape)

                sims = (image_features @ text_features.t())[0] / feature_extractor.temp

            sims = torch.nn.Softmax(dim=0)(sims)
            inv_sims = [sim * 100 for sim in sims.tolist()[::-1]]

        elif model_type.startswith("CLIP"):
            if model_type == "CLIP_ViT-B-32":
                model = load_model_cache(model_type="CLIP_ViT-B-32", device=device)
            elif model_type == "CLIP_ViT-B-16":
                model = load_model_cache(model_type="CLIP_ViT-B-16", device=device)
            elif model_type == "CLIP_ViT-L-14":
                model = load_model_cache(model_type="CLIP_ViT-L-14", device=device)
            else:
                raise ValueError(f"Unknown model type {model_type}")

            if score_type == "Cosine":
                # image_preprocess = ClipImageEvalProcessor(image_size=336)
                image_preprocess = ClipImageEvalProcessor(image_size=224)
                img = image_preprocess(raw_img).unsqueeze(0).to(device)

                sample = {"image": img, "text_input": cls_names}

                with torch.no_grad():
                    clip_features = model.extract_features(sample)

                image_features = clip_features.image_embeds_proj
                text_features = clip_features.text_embeds_proj

                sims = (100.0 * image_features @ text_features.T)[0].softmax(dim=-1)
                inv_sims = sims.tolist()[::-1]
            else:
                st.warning("CLIP does not support multimodal scoring.")
                return

        fig = go.Figure(
            go.Bar(
                x=inv_sims,
                y=cls_names[::-1],
                text=["{:.2f}".format(s) for s in inv_sims],
                orientation="h",
            )
        )
        fig.update_traces(
            textfont_size=12,
            textangle=0,
            textposition="outside",
            cliponaxis=False,
        )
        col2.plotly_chart(fig, use_container_width=True)

app/dataset_browser.py
ADDED
@@ -0,0 +1,240 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import random
from collections import OrderedDict
from functools import reduce
from tkinter import N

import streamlit as st
from lavis.common.registry import registry
from lavis.datasets.builders import dataset_zoo, load_dataset
from lavis.datasets.builders.base_dataset_builder import load_dataset_config
from PIL import Image

IMAGE_LAYOUT = 3, 4
VIDEO_LAYOUT = 1, 2

PREV_STR = "Prev"
NEXT_STR = "Next"


def sample_dataset(dataset, indices):
    samples = [dataset.displ_item(idx) for idx in indices]

    return samples


def get_concat_v(im1, im2):
    margin = 5

    canvas_size = (im1.width + im2.width + margin, max(im1.height, im2.height))
    canvas = Image.new("RGB", canvas_size, "White")
    canvas.paste(im1, (0, 0))
    canvas.paste(im2, (im1.width + margin, 0))

    return canvas


def resize_img_w(raw_img, new_w=224):
    if isinstance(raw_img, list):
        resized_imgs = [resize_img_w(img, 196) for img in raw_img]
        # concatenate images
        resized_image = reduce(get_concat_v, resized_imgs)
    else:
        w, h = raw_img.size
        scaling_factor = new_w / w
        resized_image = raw_img.resize(
            (int(w * scaling_factor), int(h * scaling_factor))
        )

    return resized_image


def get_visual_key(dataset):
    if "image" in dataset[0]:
        return "image"
    elif "image0" in dataset[0]:  # NLVR2 dataset
        return "image"
    elif "video" in dataset[0]:
        return "video"
    else:
        raise ValueError("Visual key not found.")


def gather_items(samples, exclude=[]):
    gathered = []

    for s in samples:
        ns = OrderedDict()
        for k in s.keys():
            if k not in exclude:
                ns[k] = s[k]

        gathered.append(ns)

    return gathered


@st.cache(allow_output_mutation=True)
def load_dataset_cache(name):
    return load_dataset(name)


def format_text(text):
    md = "\n\n".join([f"**{k}**: {v}" for k, v in text.items()])

    return md


def show_samples(dataset, offset=0, is_next=False):
    visual_key = get_visual_key(dataset)

    num_rows, num_cols = IMAGE_LAYOUT if visual_key == "image" else VIDEO_LAYOUT
    n_samples = num_rows * num_cols

    if not shuffle:
        if is_next:
            start = min(int(start_idx) + offset + n_samples, len(dataset) - n_samples)
        else:
            start = max(0, int(start_idx) + offset - n_samples)

        st.session_state.last_start = start
        end = min(start + n_samples, len(dataset))

        indices = list(range(start, end))
    else:
        indices = random.sample(range(len(dataset)), n_samples)
    samples = sample_dataset(dataset, indices)

    visual_info = (
        iter([resize_img_w(s[visual_key]) for s in samples])
        if visual_key == "image"
        # else iter([s[visual_key] for s in samples])
        else iter([s["file"] for s in samples])
    )
    text_info = gather_items(samples, exclude=["image", "video"])
    text_info = iter([format_text(s) for s in text_info])

    st.markdown(
        """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
        unsafe_allow_html=True,
    )
    for _ in range(num_rows):
        with st.container():
            for col in st.columns(num_cols):
                # col.text(next(text_info))
                # col.caption(next(text_info))
                try:
                    col.markdown(next(text_info))
                    if visual_key == "image":
                        col.image(next(visual_info), use_column_width=True, clamp=True)
                    elif visual_key == "video":
                        col.markdown(
                            "![Alt Text](https://media.giphy.com/media/vFKqnCdLPNOKc/giphy.gif)"
                        )
                except StopIteration:
                    break

    st.markdown(
        """<hr style="height:1px;border:none;color:#c7ccd4;background-color:#c7ccd4;"/> """,
        unsafe_allow_html=True,
    )

    st.session_state.n_display = n_samples


if __name__ == "__main__":
    st.set_page_config(
        page_title="LAVIS Dataset Explorer",
        # layout="wide",
        initial_sidebar_state="expanded",
    )

    dataset_name = st.sidebar.selectbox("Dataset:", dataset_zoo.get_names())

    function = st.sidebar.selectbox("Function:", ["Browser"], index=0)

    if function == "Browser":
        shuffle = st.sidebar.selectbox("Shuffled:", [True, False], index=0)

        dataset = load_dataset_cache(dataset_name)
        split = st.sidebar.selectbox("Split:", dataset.keys())

        dataset_len = len(dataset[split])
        st.success(
            f"Loaded {dataset_name}/{split} with **{dataset_len}** records. **Image/video directory**: {dataset[split].vis_root}"
        )

        if "last_dataset" not in st.session_state:
            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split

        if "last_start" not in st.session_state:
            st.session_state.last_start = 0

        if "start_idx" not in st.session_state:
            st.session_state.start_idx = 0

        if "shuffle" not in st.session_state:
            st.session_state.shuffle = shuffle

        if "first_run" not in st.session_state:
            st.session_state.first_run = True
        elif (
            st.session_state.last_dataset != dataset_name
            or st.session_state.last_split != split
        ):
            st.session_state.first_run = True

            st.session_state.last_dataset = dataset_name
            st.session_state.last_split = split
        elif st.session_state.shuffle != shuffle:
            st.session_state.shuffle = shuffle
            st.session_state.first_run = True

        if not shuffle:
            n_col, p_col = st.columns([0.05, 1])

            prev_button = n_col.button(PREV_STR)
            next_button = p_col.button(NEXT_STR)

        else:
            next_button = st.button(NEXT_STR)

        if not shuffle:
            start_idx = st.sidebar.text_input(f"Begin from (total {dataset_len})", 0)

            if not start_idx.isdigit():
                st.error(f"Input to 'Begin from' must be digits, found {start_idx}.")
            else:
                if int(start_idx) != st.session_state.start_idx:
                    st.session_state.start_idx = int(start_idx)
                    st.session_state.last_start = int(start_idx)

            if prev_button:
                show_samples(
                    dataset[split],
                    offset=st.session_state.last_start - st.session_state.start_idx,
                    is_next=False,
                )

        if next_button:
            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )

        if st.session_state.first_run:
            st.session_state.first_run = False

            show_samples(
                dataset[split],
                offset=st.session_state.last_start - st.session_state.start_idx,
                is_next=True,
            )

app/image_text_match.py
ADDED
@@ -0,0 +1,87 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import numpy as np
import streamlit as st
import torch
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
from lavis.processors import load_processor
from PIL import Image

from app import device, load_demo_image
from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model


def app():
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])

    if model_type.startswith("BLIP"):
        blip_type = model_type.split("_")[1]
        model = load_blip_itm_model(device, model_type=blip_type)

    vis_processor = load_processor("blip_image_eval").build(image_size=384)

    st.markdown(
        "<h1 style='text-align: center;'>Image Text Matching</h1>",
        unsafe_allow_html=True,
    )

    values = list(range(1, 12))
    default_layer_num = values.index(7)
    layer_num = (
        st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    col1, col2 = st.columns(2)
    col1.header("Image")
    col2.header("GradCam")
    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
    col1.image(resized_image, use_column_width=True)

    col3, col4 = st.columns(2)
    col3.header("Text")
    user_question = col3.text_input(
        "Input your sentence!", "a woman sitting on the beach with a dog"
    )
    submit_button = col3.button("Submit")

    col4.header("Matching score")

    if submit_button:
        tokenizer = init_bert_tokenizer()

        img = vis_processor(raw_img).unsqueeze(0).to(device)
        text_processor = load_processor("blip_caption").build()

        qry = text_processor(user_question)

        norm_img = np.float32(resized_image) / 255

        qry_tok = tokenizer(qry, return_tensors="pt").to(device)
        gradcam, output = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)

        avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)

        col2.image(avg_gradcam, use_column_width=True, clamp=True)
        # output = model(img, question)
        itm_score = torch.nn.functional.softmax(output, dim=1)
        new_title = (
            '<p style="text-align: left; font-size: 25px;">\n{:.3f}%</p>'.format(
                itm_score[0][1].item() * 100
            )
        )
        col4.markdown(new_title, unsafe_allow_html=True)

app/main.py
ADDED
@@ -0,0 +1,25 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from app.multipage import MultiPage
from app import vqa, caption
from app import image_text_match as itm
from app import text_localization as tl
from app import multimodal_search as ms
from app import classification as cl


if __name__ == "__main__":
    app = MultiPage()

    app.add_page("Image Description Generation", caption.app)
    app.add_page("Multimodal Search", ms.app)
    app.add_page("Visual Question Answering", vqa.app)
    app.add_page("Image Text Matching", itm.app)
    app.add_page("Text Localization", tl.app)
    app.add_page("Classification", cl.app)
    app.run()

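main.py wires every demo page into a single `MultiPage` app and is typically launched through the Streamlit CLI (e.g. `streamlit run app/main.py` from the repository root). A hypothetical reduced registration, useful when only some checkpoints are available locally; saved as a script, it would be launched the same way:

# Sketch: the same MultiPage pattern with a reduced, hypothetical page set.
from app.multipage import MultiPage
from app import caption, vqa

demo = MultiPage()
demo.add_page("Image Description Generation", caption.app)
demo.add_page("Visual Question Answering", vqa.app)
demo.run()
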
app/multimodal_search.py
ADDED
@@ -0,0 +1,230 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os

import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from app import cache_root, device
from app.utils import (
    getAttMap,
    init_bert_tokenizer,
    load_blip_itm_model,
    read_img,
    resize_img,
)
from lavis.models import load_model
from lavis.processors import load_processor


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_feat():
    from lavis.common.utils import download_url

    dirname = os.path.join(os.path.dirname(__file__), "assets")
    filename = "path2feat_coco_train2014.pth"
    filepath = os.path.join(dirname, filename)
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"

    if not os.path.exists(filepath):
        download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")

    path2feat = torch.load(filepath)
    paths = sorted(path2feat.keys())

    all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)

    return path2feat, paths, all_img_feats


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_feature_extractor_model(device):
    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"

    model = load_model(
        "blip_feature_extractor", model_type="base", is_eval=True, device=device
    )
    model.load_from_pretrained(model_url)

    return model


def app():
    # === layout ===
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
    file_root = os.path.join(cache_root, "coco/images/train2014/")

    values = [12, 24, 48]
    default_layer_num = values.index(24)
    num_display = st.sidebar.selectbox(
        "Number of images:", values, index=default_layer_num
    )
    show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
    itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)

    # st.title('Multimodal Search')
    st.markdown(
        "<h1 style='text-align: center;'>Multimodal Search</h1>", unsafe_allow_html=True
    )

    # === event ===
    vis_processor = load_processor("blip_image_eval").build(image_size=384)
    text_processor = load_processor("blip_caption")

    user_question = st.text_input(
        "Search query", "A dog running on the grass.", help="Type something to search."
    )
    user_question = text_processor(user_question)
    feature_extractor = load_feature_extractor_model(device)

    # ======= ITC =========
    sample = {"text_input": user_question}

    with torch.no_grad():
        text_feature = feature_extractor.extract_features(
            sample, mode="text"
        ).text_embeds_proj[0, 0]

    path2feat, paths, all_img_feats = load_feat()
    all_img_feats.to(device)
    all_img_feats = F.normalize(all_img_feats, dim=1)

    num_cols = 4
    num_rows = int(num_display / num_cols)

    similarities = text_feature @ all_img_feats.T
    indices = torch.argsort(similarities, descending=True)[:num_display]

    top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
    sorted_similarities = [similarities[idx] for idx in indices]
    filenames = [os.path.join(file_root, p) for p in top_paths]

    # ========= ITM and GradCam ==========
    bsz = 4  # max number of images to avoid cuda oom
    if model_type.startswith("BLIP"):
        blip_type = model_type.split("_")[1]

    itm_model = load_blip_itm_model(device, model_type=blip_type)

    tokenizer = init_bert_tokenizer()
    queries_batch = [user_question] * bsz
    queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)

    num_batches = int(num_display / bsz)

    avg_gradcams = []
    all_raw_images = []
    itm_scores = []

    for i in range(num_batches):
        filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
        raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
        gradcam, itm_output = compute_gradcam_batch(
            itm_model, images, queries_batch, queries_tok_batch
        )

        all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
        norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]

        for norm_img, grad_cam in zip(norm_imgs, gradcam):
            avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
            avg_gradcams.append(avg_gradcam)

        with torch.no_grad():
            itm_score = torch.nn.functional.softmax(itm_output, dim=1)

        itm_scores.append(itm_score)

    # ========= ITM re-ranking =========
    itm_scores = torch.cat(itm_scores)[:, 1]
    if itm_ranking:
        itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)

        avg_gradcams_sorted = []
        all_raw_images_sorted = []
        for idx in indices:
            avg_gradcams_sorted.append(avg_gradcams[idx])
            all_raw_images_sorted.append(all_raw_images[idx])

        avg_gradcams = avg_gradcams_sorted
        all_raw_images = all_raw_images_sorted

    if show_gradcam:
        images_to_show = iter(avg_gradcams)
    else:
        images_to_show = iter(all_raw_images)

    for _ in range(num_rows):
        with st.container():
            for col in st.columns(num_cols):
                col.image(next(images_to_show), use_column_width=True, clamp=True)


def read_and_process_images(image_paths, vis_processor):
    raw_images = [read_img(path) for path in image_paths]
    images = [vis_processor(r_img) for r_img in raw_images]
    images_tensors = torch.stack(images).to(device)

    return raw_images, images_tensors


def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
    model.text_encoder.base_model.base_model.encoder.layer[
        block_num
    ].crossattention.self.save_attention = True

    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )  # (bsz, 1, token_len, 1, 1)
        token_length = mask.sum() - 2
        token_length = token_length.cpu()
        # grads and cams [bsz, num_head, seq_len, image_patch]
        grads = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attn_gradients()
        cams = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attention_map()

        # assume using vit large with 576 num image patch
        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
        grads = (
            grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
            * mask
        )

        gradcam = cams * grads
        # [enc token gradcam, average gradcam across token, gradcam for individual token]
        # gradcam = torch.cat((gradcam[0:1,:], gradcam[1:token_length+1, :].sum(dim=0, keepdim=True)/token_length, gradcam[1:, :]))
        gradcam = gradcam.mean(1).cpu().detach()
        gradcam = (
            gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
        )

    return gradcam, output

app/multipage.py
ADDED
@@ -0,0 +1,41 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

"""
This file is the framework for generating multiple Streamlit applications
through an object oriented framework.
"""

# Import necessary libraries
import streamlit as st

# Define the multipage class to manage the multiple apps in our program
class MultiPage:
    """Framework for combining multiple streamlit applications."""

    def __init__(self) -> None:
        """Constructor class to generate a list which will store all our applications as an instance variable."""
        self.pages = []

    def add_page(self, title, func) -> None:
        """Class Method to Add pages to the project
        Args:
            title ([str]): The title of page which we are adding to the list of apps

            func: Python function to render this page in Streamlit
        """

        self.pages.append({"title": title, "function": func})

    def run(self):
        # Dropdown to select the page to run
        page = st.sidebar.selectbox(
            "Navigation", self.pages, format_func=lambda page: page["title"]
        )

        # run the app function
        page["function"]()

app/text_localization.py
ADDED
@@ -0,0 +1,105 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

import numpy as np
import streamlit as st
from lavis.models.blip_models.blip_image_text_matching import compute_gradcam
from lavis.processors import load_processor
from PIL import Image

from app import device, load_demo_image
from app.utils import getAttMap, init_bert_tokenizer, load_blip_itm_model


def app():
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])

    values = list(range(1, 12))
    default_layer_num = values.index(7)
    layer_num = (
        st.sidebar.selectbox("Layer number", values, index=default_layer_num) - 1
    )

    st.markdown(
        "<h1 style='text-align: center;'>Text Localization</h1>", unsafe_allow_html=True
    )

    vis_processor = load_processor("blip_image_eval").build(image_size=384)
    text_processor = load_processor("blip_caption")

    tokenizer = init_bert_tokenizer()

    instructions = "Try the provided image and text or use your own ones."
    file = st.file_uploader(instructions)

    query = st.text_input(
        "Try a different input.", "A girl playing with her dog on the beach."
    )

    submit_button = st.button("Submit")

    col1, col2 = st.columns(2)

    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    col1.header("Image")
    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
    col1.image(resized_image, use_column_width=True)

    col2.header("GradCam")

    if submit_button:
        if model_type.startswith("BLIP"):
            blip_type = model_type.split("_")[1]
            model = load_blip_itm_model(device, model_type=blip_type)

        img = vis_processor(raw_img).unsqueeze(0).to(device)
        qry = text_processor(query)

        qry_tok = tokenizer(qry, return_tensors="pt").to(device)

        norm_img = np.float32(resized_image) / 255

        gradcam, _ = compute_gradcam(model, img, qry, qry_tok, block_num=layer_num)

        avg_gradcam = getAttMap(norm_img, gradcam[0][1], blur=True)
        col2.image(avg_gradcam, use_column_width=True, clamp=True)

        num_cols = 4.0
        num_tokens = len(qry_tok.input_ids[0]) - 2

        num_rows = int(math.ceil(num_tokens / num_cols))

        gradcam_iter = iter(gradcam[0][2:-1])
        token_id_iter = iter(qry_tok.input_ids[0][1:-1])

        for _ in range(num_rows):
            with st.container():
                for col in st.columns(int(num_cols)):
                    token_id = next(token_id_iter, None)
                    if not token_id:
                        break
                    gradcam_img = next(gradcam_iter)

                    word = tokenizer.decode([token_id])
                    gradcam_todraw = getAttMap(norm_img, gradcam_img, blur=True)

                    new_title = (
                        '<p style="text-align: center; font-size: 25px;">{}</p>'.format(
                            word
                        )
                    )
                    col.markdown(new_title, unsafe_allow_html=True)
                    # st.image(image, channels="BGR")
                    col.image(gradcam_todraw, use_column_width=True, clamp=True)

app/utils.py
ADDED
@@ -0,0 +1,81 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import numpy as np
import streamlit as st
import torch
from lavis.models import BlipBase, load_model
from matplotlib import pyplot as plt
from PIL import Image
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def resize_img(raw_img):
    w, h = raw_img.size
    scaling_factor = 240 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))
    return resized_image


def read_img(filepath):
    raw_image = Image.open(filepath).convert("RGB")

    return raw_image


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_model_cache(name, model_type, is_eval, device):
    return load_model(name, model_type, is_eval, device)


@st.cache(allow_output_mutation=True)
def init_bert_tokenizer():
    tokenizer = BlipBase.init_tokenizer()
    return tokenizer


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_blip_itm_model(device, model_type="base"):
    model = load_model(
        "blip_image_text_matching", model_type, is_eval=True, device=device
    )
    return model

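`getAttMap` normalizes the attention map, resizes it to the image, optionally Gaussian-blurs it, and alpha-blends a jet colormap over the image; the GradCam pages call it with a float image in [0, 1]. A hedged sketch of that calling convention, where the filename and the random attention map are stand-ins:

# Sketch of the getAttMap calling convention used by the GradCam pages above.
import numpy as np
from app.utils import getAttMap, read_img, resize_img

raw_img = read_img("some_image.jpg")             # hypothetical local file
norm_img = np.float32(resize_img(raw_img)) / 255  # (H, W, 3), float in [0, 1]
attn = np.random.rand(24, 24)                     # stand-in for a GradCam map
overlay = getAttMap(norm_img, attn, blur=True)    # blended heatmap, same H x W
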
app/vqa.py
ADDED
@@ -0,0 +1,63 @@
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import streamlit as st
from app import load_demo_image, device
from app.utils import load_model_cache
from lavis.processors import load_processor
from PIL import Image


def app():
    model_type = st.sidebar.selectbox("Model:", ["BLIP"])

    # ===== layout =====
    st.markdown(
        "<h1 style='text-align: center;'>Visual Question Answering</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    col1, col2 = st.columns(2)

    col1.header("Image")
    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))

    col1.image(resized_image, use_column_width=True)
    col2.header("Question")

    user_question = col2.text_input("Input your question!", "What are objects there?")
    qa_button = st.button("Submit")

    col2.header("Answer")

    # ===== event =====
    vis_processor = load_processor("blip_image_eval").build(image_size=480)
    text_processor = load_processor("blip_question").build()

    if qa_button:
        if model_type.startswith("BLIP"):
            model = load_model_cache(
                "blip_vqa", model_type="vqav2", is_eval=True, device=device
            )

        img = vis_processor(raw_img).unsqueeze(0).to(device)
        question = text_processor(user_question)

        vqa_samples = {"image": img, "text_input": [question]}
        answers = model.predict_answers(vqa_samples, inference_method="generate")

        col2.write("\n".join(answers), use_column_width=True)

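The VQA page preprocesses the image at 480px, wraps the question with the `blip_question` processor, and calls `predict_answers` with `inference_method="generate"`. A sketch of the same call path without the Streamlit widgets, assuming the `vqav2` checkpoint downloads as it does on first page load; the question string is illustrative:

# Sketch: ask a question about the demo image without the Streamlit UI,
# following the same calls the page above makes.
from app import device, load_demo_image
from app.utils import load_model_cache
from lavis.processors import load_processor

model = load_model_cache("blip_vqa", model_type="vqav2", is_eval=True, device=device)
vis_processor = load_processor("blip_image_eval").build(image_size=480)
text_processor = load_processor("blip_question").build()

img = vis_processor(load_demo_image()).unsqueeze(0).to(device)
samples = {"image": img, "text_input": [text_processor("What is the dog doing?")]}
print(model.predict_answers(samples, inference_method="generate"))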