ma-images / app.py
broadwell's picture
Add viz/explanation feature for image and text activations
a6b26e3 verified
raw
history blame
17.5 kB
from base64 import b64encode
from io import BytesIO
from math import ceil
import matplotlib.pyplot as plt
from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, AutoModel
from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.vit_cam import (
interpret_vit,
vit_perword_relevance,
) # , interpret_vit_overlapped
# Bounds (in px) for the activation image shown in the "Explain this" dialog;
# MAX_IMG_HEIGHT is also the max-height applied to images in the results grid.
MAX_IMG_WIDTH = 450 # For small dialog
MAX_IMG_HEIGHT = 800
# Keep this ahead of every other st.* call (older Streamlit versions require
# set_page_config to be the first Streamlit command).
st.set_page_config(layout="wide")
def init():
    """Load models, metadata and image features into ``st.session_state``.

    Called by the guard below this definition whenever the feature tensors are
    not yet in session state. It does the following:

    * sets the default UI/search state;
    * loads the multilingual (M-CLIP) and Japanese (J-CLIP) models with their
      tokenizers and image preprocessors;
    * reads the image metadata and the ordered list of image IDs;
    * stores L2-normalized precomputed image features for both models on the
      active device.
    """
    st.session_state.current_page = 1
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.session_state.device = device

    # Multilingual CLIP: image model from a local checkpoint, text model and
    # tokenizer via from_pretrained.
    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
    st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
        ml_model_path, device=device, jit=False
    )
    # NOTE(review): unlike the Japanese text model below, this one is never
    # moved to `device`; its forward() receives the tokenizer itself, so it
    # presumably tokenizes onto the CPU internally — confirm before moving it.
    st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
        ml_model_name
    )
    st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

    # Japanese CLIP: image model from a local checkpoint, text model and
    # tokenizer loaded with trust_remote_code=True.
    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
    ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
        ja_model_path, device=device, jit=False
    )
    st.session_state.ja_model = AutoModel.from_pretrained(
        ja_model_name, trust_remote_code=True
    ).to(device)
    st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
        ja_model_name, trust_remote_code=True
    )

    # Default UI / search state.
    st.session_state.active_model = "M-CLIP (multiple languages)"
    st.session_state.search_image_ids = []
    st.session_state.search_image_scores = {}
    st.session_state.activations_image = None
    st.session_state.text_table_df = None

    # Image metadata, indexed by filename so image IDs can be looked up in it.
    st.session_state.images_info = pd.read_csv("./metadata.csv")
    st.session_state.images_info.set_index("filename", inplace=True)
    # Use a context manager so the file handle is closed promptly.
    with open("./images_list.txt", "r", encoding="utf-8") as id_file:
        st.session_state.image_ids = id_file.read().strip().split("\n")

    def _load_features(path):
        """Load an .npy feature matrix onto `device` and L2-normalize each row."""
        features = torch.from_numpy(np.load(path))
        if device == "cpu":
            # Float32 on CPU; on GPU the array's stored dtype is kept as-is
            # (expected to be float16).
            features = features.float()
        features = features.to(device)
        return features / features.norm(dim=-1, keepdim=True)

    # Rows of these matrices are paired with st.session_state.image_ids by
    # position when ranking search results.
    st.session_state.ml_image_features = _load_features("./resized_ml_features.npy")
    st.session_state.ja_image_features = _load_features("./resized_ja_features.npy")
# Streamlit re-executes this script on every interaction; only do the
# expensive model/data loading the first time in a browser session.
if not (
    "ml_image_features" in st.session_state
    and "ja_image_features" in st.session_state
):
    with st.spinner("Loading models and data, please wait..."):
        init()
def encode_search_query(search_query, model_type):
    """Encode a text query with the selected CLIP text model and L2-normalize it.

    Args:
        search_query: The query text.
        model_type: Active model label. ``"M-CLIP (multiple languages)"``
            selects the multilingual model; any other value selects the
            Japanese (J-CLIP) model.

    Returns:
        The normalized text embedding as a torch tensor (presumably shape
        ``(1, dim)`` for a single query string — confirm with the models).
    """
    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
        else:  # model_type == "J-CLIP (日本語 only)"
            # The Japanese text model lives on st.session_state.device, but
            # the tokenizer returns CPU tensors; move them so the forward pass
            # also works when running on GPU.
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            ).to(st.session_state.device)
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded
def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by its similarity to a text embedding.

    Args:
        text_features: Query embedding with a single row, shape ``(1, dim)``.
            It is expected to be L2-normalized so that the dot product is the
            cosine similarity.
        image_features: Image embeddings, shape ``(num_images, dim)``. Row
            ``i`` belongs to ``image_ids[i]``.
        image_ids: Image IDs aligned with the rows of ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs covering every image, best match
        first, with ``score`` as a Python float.
    """
    # The text encoders may produce tensors on a different device/dtype than
    # the precomputed image features (e.g. CPU float32 vs. GPU float16), and
    # matmul requires both operands to match.
    text_features = text_features.to(
        device=image_features.device, dtype=image_features.dtype
    )
    similarities = (image_features @ text_features.T).squeeze(1)
    # Copy the scores to host memory once, instead of one .item() (and one
    # device sync) per image.
    scores = similarities.float().cpu()
    order = (-scores).argsort().tolist()
    values = scores.tolist()
    return [[image_ids[i], values[i]] for i in order]
def clip_search(search_query):
    """Run a search and store the ranked results in session state.

    Also used directly as the ``on_click`` callback of the suggested-search
    buttons. For that reason it first copies the query into the search box's
    session-state key. An empty query leaves the previous results unchanged.
    """
    if search_query != st.session_state.search_field_value:
        st.session_state.search_field_value = search_query
    model_type = st.session_state.active_model
    if len(search_query) < 1:
        return

    text_features = encode_search_query(search_query, model_type)
    if model_type == "M-CLIP (multiple languages)":
        candidate_features = st.session_state.ml_image_features
    else:  # model_type == "J-CLIP (日本語 only)"
        candidate_features = st.session_state.ja_image_features
    ranked = find_best_matches(
        text_features, candidate_features, st.session_state.image_ids
    )

    st.session_state.search_image_ids = [image_id for image_id, _ in ranked]
    st.session_state.search_image_scores = dict(ranked)
def string_search():
    """Widget callback: search for whatever is currently in the search box."""
    current_query = st.session_state.search_field_value
    clip_search(current_query)
def _load_remote_image(image_url, timeout=30):
    """Download ``image_url`` and decode it with PIL, accepting JPEG only.

    Args:
        image_url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: The server answered with an error status. This is
            raised explicitly rather than letting PIL fail on a non-image body.
        requests.Timeout: The server did not respond within ``timeout``.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content), formats=["JPEG"])


def _fit_within(size, max_width, max_height):
    """Return ``(width, height)`` scaling ``size`` to fit inside the given box.

    The aspect ratio is preserved, and the image may be scaled up as well as
    down. Using the tighter of the two ratios keeps square and portrait images
    from exceeding ``max_width``.
    """
    width, height = size
    if max_width / width <= max_height / height:
        return max_width, max(1, int(height * (max_width / width)))
    return max(1, int(width * (max_height / height))), max_height


def _show_token_importance(tokenized_text, p_image, image_model, img_dim):
    """Offer a button that scores each query token and shows the scores as a table.

    The button is only rendered for queries of 2-14 tokens (bare "▁" tokens
    are excluded). Each token's raw score is the reciprocal of the mean of
    ``vit_perword_relevance``'s output. Tokens whose mean is 0 or NaN are left
    out. The remaining scores are softmax-normalized to percentages and
    stored in ``st.session_state.text_table_df``.
    """
    tokens = [tok for tok in tokenized_text if tok != "▁"]
    if not 1 < len(tokens) < 15:
        return
    if not st.button(
        "Calculate text importance (may take some time)",
    ):
        return

    total = len(tokens)
    search_tokens = []
    token_scores = []
    progress_bar = st.progress(0.0, text=f"Processing {total} text tokens")
    for t, tok in enumerate(tokens, start=1):
        token = tok.replace("▁", "")
        word_rel = vit_perword_relevance(
            p_image,
            st.session_state.search_field_value,
            image_model,
            tokenize,
            st.session_state.device,
            token,
            data_only=True,
            img_dim=img_dim,
        )
        avg_score = np.mean(word_rel)
        # A zero or NaN mean cannot be inverted, so that token gets no row.
        if avg_score != 0 and not np.isnan(avg_score):
            search_tokens.append(token)
            token_scores.append(1 / avg_score)
        # Advance for every token, including skipped ones, so the bar completes.
        progress_bar.progress(
            t / total,
            text=f"Processing token {t} of {total} tokens",
        )
    progress_bar.empty()

    normed_scores = torch.softmax(torch.tensor(token_scores), dim=0)
    st.session_state.text_table_df = pd.DataFrame(
        {
            "token": search_tokens,
            "importance": [
                f"{round(score.item() * 100, 3)}%" for score in normed_scores
            ],
        }
    )
    st.markdown("**Importance of each text token to relevance score**")
    st.table(st.session_state.text_table_df)


def visualize_gradcam(viz_image_id):
    """Render the "Explain this" dialog body for one search result.

    The dialog shows:

    * the query text and the image's relevance score;
    * an activation map computed by ``interpret_vit`` for the active model,
      scaled to fit ``MAX_IMG_WIDTH`` x ``MAX_IMG_HEIGHT``;
    * optionally, per-token importance for the query.

    Does nothing when the search box is empty.

    Args:
        viz_image_id: Image ID (a label of ``st.session_state.images_info``)
            to explain.
    """
    query = st.session_state.search_field_value
    if not query:
        return

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query details")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    st.markdown(
        f"**Query text:** {query} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
    )

    info_text = st.text("Calculating activation regions...")

    image = _load_remote_image(
        st.session_state.images_info.loc[viz_image_id]["image_url"]
    )
    device = st.session_state.device

    use_ml = st.session_state.active_model == "M-CLIP (multiple languages)"
    if use_ml:
        img_dim = 240
        image_model = st.session_state.ml_image_model
        preprocess = st.session_state.ml_image_preprocess
        tokenizer = st.session_state.ml_tokenizer
    else:
        img_dim = 224
        image_model = st.session_state.ja_image_model
        preprocess = st.session_state.ja_image_preprocess
        tokenizer = st.session_state.ja_tokenizer

    altered_image = image.resize((img_dim, img_dim), Image.LANCZOS)
    p_image = preprocess(altered_image).unsqueeze(0).to(device)
    # Reused later for the optional per-token importance table.
    tokenized_text = tokenizer.tokenize(query)

    if use_ml:
        text_features = st.session_state.ml_model.forward(query, tokenizer)
    else:
        # The Japanese text model lives on `device`; the tokenizer returns CPU
        # tensors, so move them before the forward pass.
        t_text = tokenizer(query, return_tensors="pt").to(device)
        text_features = st.session_state.ja_model.get_text_features(**t_text)

    vis_t = interpret_vit(
        p_image.type(image_model.dtype),
        text_features,
        image_model.visual,
        device,
        img_dim=img_dim,
    )

    vis_img = ToPILImage()(vis_t)
    scaled_dims = _fit_within(image.size, MAX_IMG_WIDTH, MAX_IMG_HEIGHT)
    st.session_state.activations_image = vis_img.resize(scaled_dims)
    image_io = BytesIO()
    st.session_state.activations_image.save(image_io, "PNG")
    dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")
    st.html(
        f"""<div style="display: flex; flex-direction: column; align-items: center">
<img src="{dataurl}" />
</div>"""
    )
    info_text.empty()

    _show_token_importance(tokenized_text, p_image, image_model, img_dim)
@st.dialog(" ", width="small")
def image_modal(vis_image_id):
    """Open the explanation dialog for one result (used as a button callback)."""
    return visualize_gradcam(vis_image_id)
# Layout tweaks for Streamlit's built-in elements. This also hides the
# dialog's own close button, since the dialog renders its own "Close" button.
_CUSTOM_CSS = """
<style>
[data-testid=stImageCaption] {
padding: 0 0 0 0;
}
[data-testid=stVerticalBlockBorderWrapper] {
line-height: 1.2;
}
[data-testid=stVerticalBlock] {
gap: .75rem;
}
[data-testid=baseButton-secondary] {
min-height: 1rem;
padding: 0 0.75rem;
margin: 0 0 1rem 0;
}
div[aria-label="dialog"]>button[aria-label="Close"] {
display: none;
}
[data-testid=stFullScreenFrame] {
display: flex;
flex-direction: column;
align-items: center;
}
</style>
"""

st.title("Explore Japanese visual aesthetics with CLIP models")
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Search box, search button, spacer, and the CLIP model selector.
search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
query_col, search_button_col, spacer_col, model_label_col, model_choice_col = search_row

with query_col:
    # Pressing Enter re-runs the search via the on_change callback.
    search_field = st.text_input(
        key="search_field_value",
        label="search",
        label_visibility="collapsed",
        placeholder="Type something, or click a suggested search below.",
        on_change=string_search,
    )

with search_button_col:
    st.button(
        "Search", type="primary", on_click=string_search, use_container_width=True
    )

with spacer_col:
    st.empty()

with model_label_col:
    st.markdown("**CLIP Model:**")

with model_choice_col:
    # Switching models re-runs the current query with the new model.
    st.radio(
        "CLIP Model",
        key="active_model",
        options=["M-CLIP (multiple languages)", "J-CLIP (日本語 only)"],
        horizontal=True,
        label_visibility="collapsed",
        on_change=string_search,
    )
# A row of suggested-query buttons; the set shown depends on the active model.
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
with canned_searches[0]:
    st.markdown("**Suggested searches:**")

if st.session_state.active_model == "M-CLIP (multiple languages)":
    _suggestions = ("negative space", "間", "음각", "αρνητικός χώρος")
else:
    _suggestions = ("間", "奥", "山", "花に酔えり 羽織着て刀 さす女")

for _slot, _suggestion in zip(canned_searches[1:], _suggestions):
    with _slot:
        st.button(
            _suggestion,
            on_click=clip_search,
            args=[_suggestion],
            use_container_width=True,
        )
# Paging/layout controls: images per page, images per row, and the page picker.
controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")
with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )
with controls[1]:
    st.empty()
with controls[2]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )

# At least one page, so the pager's max_value never drops below its min_value.
num_batches = max(1, ceil(len(st.session_state.image_ids) / batch_size))
# Raising "Images/page" can shrink the page count below the page currently
# selected. Clamp it before the pager widget is created, so the label and the
# widget stay in range instead of showing an empty page past the end.
if st.session_state.current_page > num_batches:
    st.session_state.current_page = num_batches

with controls[3]:
    st.empty()
with controls[4]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
    with pager[1]:
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )
# Current page of search results. Slicing an empty result list simply yields
# an empty page.
_page_start = (st.session_state.current_page - 1) * batch_size
batch = st.session_state.search_image_ids[_page_start : _page_start + batch_size]

# Lay the page out across `row_size` columns, filling them round-robin.
grid = st.columns(row_size)
for _position, image_id in enumerate(batch):
    info = st.session_state.images_info.loc[image_id]
    # Third "/"-separated component of the permalink (the host for a full URL).
    link_text = info["permalink"].split("/")[2]
    with grid[_position % row_size]:
        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
<img src="{info['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
<div>{info['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
</div>"""
        )
        # Close the wrapper <div> (the markup previously opened a second one).
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
<a href="{info['permalink']}">{link_text}</a>
</div>""",
            unsafe_allow_html=True,
        )
        st.button(
            "Explain this",
            on_click=image_modal,
            args=[image_id],
            use_container_width=True,
            key=image_id,
        )