"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os

import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from app import cache_root, device
from app.utils import (
    getAttMap,
    init_bert_tokenizer,
    load_blip_itm_model,
    read_img,
    resize_img,
)
from lavis.models import load_model
from lavis.processors import load_processor


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_feat():
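    """Load precomputed BLIP image features for COCO train2014.

    Downloads the feature file on first use, then returns the path->feature
    dict, the sorted image paths, and the stacked feature matrix on `device`.
    """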
    from lavis.common.utils import download_url

    dirname = os.path.join(os.path.dirname(__file__), "assets")
    filename = "path2feat_coco_train2014.pth"
    filepath = os.path.join(dirname, filename)
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"

    if not os.path.exists(filepath):
        download_url(url=url, root=dirname, filename=filename)

    path2feat = torch.load(filepath)
    paths = sorted(path2feat.keys())

    all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)

    return path2feat, paths, all_img_feats


@st.cache(
    hash_funcs={
        torch.nn.parameter.Parameter: lambda parameter: parameter.data.detach()
        .cpu()
        .numpy()
    },
    allow_output_mutation=True,
)
def load_feature_extractor_model(device):
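    """Load the BLIP feature extractor used to embed the text query."""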
    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"

    model = load_model(
        "blip_feature_extractor", model_type="base", is_eval=True, device=device
    )
    model.load_from_pretrained(model_url)

    return model


def app():
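    """Streamlit page: text-to-image search over COCO, with optional ITM re-ranking."""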
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])
    file_root = os.path.join(cache_root, "coco/images/train2014/")

    values = [12, 24, 48]
    default_num_index = values.index(24)
    num_display = st.sidebar.selectbox(
        "Number of images:", values, index=default_num_index
    )
    show_gradcam = st.sidebar.selectbox("Show GradCam:", [True, False], index=1)
    itm_ranking = st.sidebar.selectbox("Multimodal re-ranking:", [True, False], index=0)

    st.markdown(
        "<h1 style='text-align: center;'>Multimodal Search</h1>", unsafe_allow_html=True
    )

    vis_processor = load_processor("blip_image_eval").build(image_size=384)
    text_processor = load_processor("blip_caption")

    user_question = st.text_input(
        "Search query", "A dog running on the grass.", help="Type something to search."
    )
    user_question = text_processor(user_question)
    feature_extractor = load_feature_extractor_model(device)

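    # Stage 1: embed the query with the BLIP feature extractor and rank all
    # cached image features by dot-product similarity.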
    sample = {"text_input": user_question}

    with torch.no_grad():
        text_feature = feature_extractor.extract_features(
            sample, mode="text"
        ).text_embeds_proj[0, 0]

    path2feat, paths, all_img_feats = load_feat()
    # Tensor.to() is not in-place: reassign, or the features stay on whatever
    # device load_feat() left them on.
    all_img_feats = all_img_feats.to(device)
    all_img_feats = F.normalize(all_img_feats, dim=1)

    num_cols = 4
    num_rows = int(num_display / num_cols)

    similarities = text_feature @ all_img_feats.T
    indices = torch.argsort(similarities, descending=True)[:num_display]

    top_paths = [paths[ind.detach().cpu().item()] for ind in indices]
    # Similarity scores of the retrieved images, in ranked order (not shown in the UI).
    sorted_similarities = [similarities[idx] for idx in indices]
    filenames = [os.path.join(file_root, p) for p in top_paths]

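    # Stage 2: re-score the retrieved candidates with the BLIP image-text
    # matching (ITM) head, processing them in mini-batches of `bsz` images.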
    bsz = 4
    if model_type.startswith("BLIP"):
        # Both selectbox options start with "BLIP", so blip_type is always set.
        blip_type = model_type.split("_")[1]

    itm_model = load_blip_itm_model(device, model_type=blip_type)

    tokenizer = init_bert_tokenizer()
    queries_batch = [user_question] * bsz
    queries_tok_batch = tokenizer(queries_batch, return_tensors="pt").to(device)

    num_batches = int(num_display / bsz)

    avg_gradcams = []
    all_raw_images = []
    itm_scores = []

    for i in range(num_batches):
        filenames_in_batch = filenames[i * bsz : (i + 1) * bsz]
        raw_images, images = read_and_process_images(filenames_in_batch, vis_processor)
        gradcam, itm_output = compute_gradcam_batch(
            itm_model, images, queries_batch, queries_tok_batch
        )

        all_raw_images.extend([resize_img(r_img) for r_img in raw_images])
        norm_imgs = [np.float32(r_img) / 255 for r_img in raw_images]

        for norm_img, grad_cam in zip(norm_imgs, gradcam):
            # grad_cam holds a single token-averaged map per image; overlay it.
            avg_gradcam = getAttMap(norm_img, grad_cam[0], blur=True)
            avg_gradcams.append(avg_gradcam)

        with torch.no_grad():
            itm_score = torch.nn.functional.softmax(itm_output, dim=1)

        itm_scores.append(itm_score)

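    # Column 1 of the two-way ITM softmax is the probability that the image and
    # the query match.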
    itm_scores = torch.cat(itm_scores)[:, 1]
    if itm_ranking:
        itm_scores_sorted, indices = torch.sort(itm_scores, descending=True)

        avg_gradcams_sorted = []
        all_raw_images_sorted = []
        for idx in indices:
            avg_gradcams_sorted.append(avg_gradcams[idx])
            all_raw_images_sorted.append(all_raw_images[idx])

        avg_gradcams = avg_gradcams_sorted
        all_raw_images = all_raw_images_sorted

    if show_gradcam:
        images_to_show = iter(avg_gradcams)
    else:
        images_to_show = iter(all_raw_images)

    for _ in range(num_rows):
        with st.container():
            for col in st.columns(num_cols):
                col.image(next(images_to_show), use_column_width=True, clamp=True)


def read_and_process_images(image_paths, vis_processor):
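    """Read images from disk and stack them into a batch tensor on `device`."""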
    raw_images = [read_img(path) for path in image_paths]
    images = [vis_processor(r_img) for r_img in raw_images]
    images_tensors = torch.stack(images).to(device)

    return raw_images, images_tensors


def compute_gradcam_batch(model, visual_input, text_input, tokenized_text, block_num=6):
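    """Compute one GradCam map per image from the ITM cross-attention at `block_num`."""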
    # Ask the chosen cross-attention layer to keep its attention map so it can
    # be read back, together with its gradient, after the backward pass.
    model.text_encoder.base_model.base_model.encoder.layer[
        block_num
    ].crossattention.self.save_attention = True

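    # Forward the ITM head and backpropagate the sum of "match" logits; the
    # gradients reaching the saved attention map drive the GradCam weighting.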
    output = model({"image": visual_input, "text_input": text_input}, match_head="itm")
    loss = output[:, 1].sum()

    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        mask = tokenized_text.attention_mask.view(
            tokenized_text.attention_mask.size(0), 1, -1, 1, 1
        )  # (bsz, 1, seq_len, 1, 1)
        # Per-query token count, excluding [CLS] and [SEP]; all queries in the
        # batch are identical, so the first row suffices.
        token_length = tokenized_text.attention_mask.sum(dim=1)[0] - 2
        token_length = token_length.cpu()

        grads = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attn_gradients()
        cams = model.text_encoder.base_model.base_model.encoder.layer[
            block_num
        ].crossattention.self.get_attention_map()

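        # Drop the [CLS] image token, then fold the patch dimension into a
        # 24x24 grid: 384x384 inputs with 16x16 patches give 576 patches, and
        # the text encoder's cross-attention has 12 heads.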
        cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask
        grads = (
            grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24)
            * mask
        )

        gradcam = cams * grads

        # Average over heads, then over the query's text tokens, leaving a
        # single map per image.
        gradcam = gradcam.mean(1).cpu().detach()
        gradcam = (
            gradcam[:, 1 : token_length + 1, :].sum(dim=1, keepdim=True) / token_length
        )

    return gradcam, output