ma-images / app.py
broadwell's picture
Upload application, configuration and data files
f83c2c0 verified
raw
history blame
No virus
9.82 kB
from math import ceil
from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
# Configure the Streamlit page to use the "wide" layout.
st.set_page_config(layout="wide")
def init():
    """Load models, image metadata and image features into ``st.session_state``.

    Populates the following session-state entries:

    * ``current_page`` -- reset to 1.
    * ``ml_model`` / ``ml_tokenizer`` -- the multilingual CLIP text model and
      its tokenizer.
    * ``ja_model`` / ``ja_tokenizer`` -- the Japanese CLIP model (moved to the
      selected device) and its tokenizer.
    * ``search_image_ids`` -- reset to an empty list.
    * ``images_info`` -- ``./metadata.csv`` indexed by its ``filename`` column.
    * ``image_ids`` -- the lines of ``./images_list.txt``.
    * ``ml_image_features`` / ``ja_image_features`` -- the feature matrices
      from the ``.npy`` files, divided by their norm along the last dimension.
    """
    st.session_state.current_page = 1
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the open CLIP models
    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"

    st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
        ml_model_name
    )
    st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

    st.session_state.ja_model = AutoModel.from_pretrained(
        ja_model_name, trust_remote_code=True
    ).to(device)
    st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
        ja_model_name, trust_remote_code=True
    )

    st.session_state.search_image_ids = []

    # Load the image metadata and the ordered list of image IDs
    st.session_state.images_info = pd.read_csv("./metadata.csv")
    st.session_state.images_info.set_index("filename", inplace=True)

    # Use a context manager so the file handle is closed deterministically.
    with open("./images_list.txt", "r", encoding="utf-8") as ids_file:
        st.session_state.image_ids = ids_file.read().strip().split("\n")

    def _load_normalized_features(path):
        """Load a feature matrix from *path*, move it to *device*, normalize rows."""
        features = torch.from_numpy(np.load(path))
        if device == "cpu":
            # Force Float32 on CPU. On GPU the dtype stored in the .npy file is
            # kept as-is (presumably Float16 -- confirm against the data files).
            features = features.float()
        features = features.to(device)
        return features / features.norm(dim=-1, keepdim=True)

    # Load both matrices before publishing either one, so a failure on the
    # second file does not leave a half-initialized session behind.
    ml_image_features = _load_normalized_features("./multilingual_features.npy")
    ja_image_features = _load_normalized_features("./hakuhodo_features.npy")

    st.session_state.ml_image_features = ml_image_features
    st.session_state.ja_image_features = ja_image_features
# Run the (slow) initialization unless both image feature matrices are already
# present in the session state.
required_keys = ("ml_image_features", "ja_image_features")
if not all(key in st.session_state for key in required_keys):
    with st.spinner("Loading models and data, please wait..."):
        init()
def encode_search_query(search_query, model_type):
    """Encode a text query into a normalized CLIP text feature tensor.

    Args:
        search_query: The query text passed to the selected model's tokenizer.
        model_type: The value of the model selector.
            ``"M-CLIP (multiple languages)"`` selects the multilingual model;
            any other value selects the Japanese model.

    Returns:
        The text features produced by the selected model, divided by their
        norm along the last dimension.
    """
    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            # The multilingual model tokenizes internally.
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
        else:  # model_type == "J-CLIP (日本語 only)"
            ja_model = st.session_state.ja_model
            t_text = st.session_state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            )
            # init() moves the Japanese model to the GPU when one is available;
            # the token tensors must be on the same device as the model weights.
            model_device = next(ja_model.parameters()).device
            t_text = t_text.to(model_device)
            text_encoded = ja_model.get_text_features(**t_text)

        # Normalize so that a dot product with normalized image features is a
        # cosine similarity.
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded
def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by its similarity to the encoded query.

    Args:
        text_features: Encoded query, a 2-D tensor with a single row.
        image_features: 2-D tensor with one feature row per image; row ``i``
            belongs to ``image_ids[i]``.
        image_ids: Sequence of image IDs aligned with ``image_features``.

    Returns:
        A list of ``[image_id, similarity]`` pairs ordered from the best match
        to the worst. The similarity is the dot product of the two feature
        vectors, i.e. the cosine similarity when both are L2-normalized.
    """
    # The text encoder and the stored image features can disagree on device
    # (CPU vs. GPU) and dtype (e.g. Float32 vs. Float16); matmul requires both
    # to match, so align the small query tensor with the image matrix.
    text_features = text_features.to(
        device=image_features.device, dtype=image_features.dtype
    )

    # Compute the similarity between the search query and each image
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort the images by their similarity score, best first. Converting to
    # Python lists once avoids a tensor-to-scalar round trip per image.
    ranking = torch.argsort(similarities, descending=True).tolist()
    scores = similarities.tolist()

    # Return the image IDs of the best matches together with their scores
    return [[image_ids[idx], scores[idx]] for idx in ranking]
def clip_search(search_query):
    """Run a CLIP search for *search_query* and store the ranked image IDs.

    Copies the query into the search field's session-state entry when it
    differs, then, for a non-empty query, encodes it with the active model,
    ranks all images against it and writes the ordered image IDs to
    ``st.session_state.search_image_ids``. An empty query leaves the previous
    results untouched.
    """
    state = st.session_state

    # Keep the text box in sync when the search came from a suggestion button.
    if state.search_field_value != search_query:
        state.search_field_value = search_query

    model_type = state.active_model

    if len(search_query) < 1:
        return

    text_features = encode_search_query(search_query, model_type)

    # Pick the image feature matrix that belongs to the active model.
    if model_type == "M-CLIP (multiple languages)":
        image_features = state.ml_image_features
    else:  # model_type == "J-CLIP (日本語 only)"
        image_features = state.ja_image_features

    ranked = find_best_matches(text_features, image_features, state.image_ids)

    # Only the IDs are kept; the similarity scores are discarded.
    state.search_image_ids = [image_id for image_id, _score in ranked]
def string_search():
    """Widget callback: search for the text currently held by the search field."""
    current_query = st.session_state.search_field_value
    clip_search(current_query)
st.title("Explore Japanese visual aesthetics with CLIP models")

# Search bar layout: query box | Search button | spacer | label | model picker.
search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
query_col, submit_col, spacer_col, label_col, model_col = search_row

with query_col:
    # The widget value lives in st.session_state.search_field_value.
    search_field = st.text_input(
        label="search",
        placeholder="Type something, or click a suggested search below.",
        key="search_field_value",
        on_change=string_search,
        label_visibility="collapsed",
    )

with submit_col:
    st.button("Search", on_click=string_search, use_container_width=True)

with spacer_col:
    st.empty()

with label_col:
    st.markdown("**CLIP Model:**")

with model_col:
    # Switching models re-runs the current search via the on_change callback.
    model_options = ["M-CLIP (multiple languages)", "J-CLIP (日本語 only)"]
    st.radio(
        "CLIP Model",
        options=model_options,
        key="active_model",
        horizontal=True,
        label_visibility="collapsed",
        on_change=string_search,
    )
# Row of suggested searches: a label followed by four one-click query buttons.
canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="center")

with canned_searches[0]:
    st.markdown("**Suggested searches:**")

# The suggestions depend on which model is active.
if st.session_state.active_model == "M-CLIP (multiple languages)":
    suggested_queries = ["negative space", "間", "음각", "αρνητικός χώρος"]
else:
    suggested_queries = ["間", "奥", "山", "花に酔えり 羽織着て刀 さす女"]

# Columns 1-4 each get one button; clicking it searches for its own label.
for slot, suggestion in enumerate(suggested_queries, start=1):
    with canned_searches[slot]:
        st.button(
            suggestion,
            on_click=clip_search,
            args=[suggestion],
            use_container_width=True,
        )
# Display controls: images per page | spacer | images per row | spacer | pager.
controls = st.columns([35, 5, 35, 5, 20], gap="large", vertical_alignment="center")

with controls[0]:
    im_per_pg = st.columns([30, 70], vertical_alignment="center")
    with im_per_pg[0]:
        st.markdown("**Images/page:**")
    with im_per_pg[1]:
        batch_size = st.select_slider(
            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
        )

with controls[1]:
    st.empty()

with controls[2]:
    im_per_row = st.columns([30, 70], vertical_alignment="center")
    with im_per_row[0]:
        st.markdown("**Images/row:**")
    with im_per_row[1]:
        row_size = st.select_slider(
            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
        )

# Number of pages needed to show every image at the chosen page size.
num_batches = ceil(len(st.session_state.image_ids) / batch_size)

# Raising "Images/page" lowers the page count; a page number kept from the
# previous run could then point past the last page (e.g. "Page 100 of 25" with
# an empty grid). Clamp it before the pager widget below is created.
if st.session_state.current_page > num_batches:
    st.session_state.current_page = num_batches

with controls[3]:
    st.empty()

with controls[4]:
    pager = st.columns([40, 60], vertical_alignment="center")
    with pager[0]:
        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
    with pager[1]:
        # The selected page is stored in st.session_state.current_page.
        st.number_input(
            "Page",
            min_value=1,
            max_value=num_batches,
            step=1,
            label_visibility="collapsed",
            key="current_page",
        )
# Slice out the image IDs for the current page. Slicing an empty result list
# yields an empty batch, so no special case is needed before the first search.
page_start = (st.session_state.current_page - 1) * batch_size
batch = st.session_state.search_image_ids[page_start : page_start + batch_size]

# Lay the images out left-to-right, wrapping to the first column after
# `row_size` images.
grid = st.columns(row_size)
for position, image_id in enumerate(batch):
    # Look the metadata row up once instead of once per field.
    image_info = st.session_state.images_info.loc[image_id]
    permalink = image_info["permalink"]
    # Third "/"-separated component of the permalink -- the host name for a
    # "scheme://host/..." URL -- is used as the visible link text.
    link_text = permalink.split("/")[2]

    with grid[position % row_size]:
        st.html(
            f"""<div style="display: flex; flex-direction: column; align-items: center">
<img src="{image_info['image_url']}" style="max-width: 100%; max-height: 800px" />
<div>{image_info['caption']}</div>
</div>"""
        )
        # The wrapper <div> is now closed properly (the original ended with a
        # second opening <div>, leaving the markup unbalanced).
        st.caption(
            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -20px">
<a href="{permalink}">{link_text}</a>
</div>""",
            unsafe_allow_html=True,
        )