sonoisa's picture
Add search with image functionality
d827310
raw
history blame
14.8 kB
from __future__ import unicode_literals
import os
import re
import unicodedata
import torch
from torch import nn
import streamlit as st
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import scipy.spatial
import pyminizip
import transformers
from transformers import BertJapaneseTokenizer, BertModel
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image
def unicode_normalize(cls, s):
    """Apply NFKC normalization only to runs of characters in a given class.

    Args:
        cls: the body of a regex character class (for example a range of
            fullwidth digits) selecting which characters get normalized.
        s: the input string.

    Returns:
        ``s`` with every maximal run of ``cls`` characters NFKC-normalized,
        all other text left untouched, and every FULLWIDTH HYPHEN-MINUS
        (U+FF0D) replaced by an ASCII "-".
    """
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    # re.split with a capturing group keeps the matched runs as separate pieces.
    s = "".join(norm(x) for x in re.split(pt, s))
    # Written as an escape: a literal fullwidth hyphen is easily degraded to
    # ASCII "-" by encoding round-trips, which turns this into a no-op.
    s = s.replace("\uFF0D", "-")
    return s
def remove_extra_spaces(s):
    """Collapse runs of spaces and drop spaces next to CJK/kana/fullwidth text.

    Runs of ASCII spaces and IDEOGRAPHIC SPACE (U+3000) become one ASCII
    space. A single space is then removed when it sits between two characters
    of the CJK/kana/fullwidth blocks, or between such a character and a Basic
    Latin character (in either order).
    """
    s = re.sub("[ \u3000]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        # Loop because one sub() pass cannot remove overlapping matches,
        # e.g. the second space in "X Y Z" when Y is consumed by the first.
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
def normalize_neologd(s):
    """Normalize Japanese text with mecab-ipadic-NEologd style rules.

    Adapted from the normalizer published on the mecab-ipadic-NEologd wiki.
    Runs of tilde-like characters are unified to WAVE DASH (U+301C), a rule
    marked as modified by Isao Sonobe, and the result is lower-cased.

    Every non-ASCII character in the rules is written as a Unicode escape.
    The rules depend on telling apart look-alike fullwidth and halfwidth
    forms, and literal characters are easily degraded by encoding round-trips.
    For example, FULLWIDTH HYPHEN-MINUS turning into ASCII "-" makes the
    prolonged-sound-mark class an invalid, reversed character range.
    """
    s = s.strip()
    # Fullwidth digits, fullwidth Latin letters and halfwidth katakana
    # (U+FF61-U+FF9F) are NFKC-normalized.
    s = unicode_normalize("\uFF10-\uFF19\uFF21-\uFF3A\uFF41-\uFF5A\uFF61-\uFF9F", s)

    # normalize hyphens
    s = re.sub("[\u02D7\u058A\u2010\u2011\u2012\u2013\u2043\u207B\u208B\u2212]+", "-", s)
    # normalize choonpus to KATAKANA-HIRAGANA PROLONGED SOUND MARK (U+30FC)
    s = re.sub("[\uFE63\uFF0D\uFF70\u2014\u2015\u2500\u2501\u30FC]+", "\u30FC", s)
    # normalize tildes to WAVE DASH (U+301C) (modified by Isao Sonobe)
    s = re.sub("[~\u223C\u223E\u301C\u3030\uFF5E]+", "\u301C", s)

    # Map ASCII symbols (and halfwidth CJK punctuation) to fullwidth forms.
    # str.maketrans raises if the two strings ever drift to unequal lengths.
    s = s.translate(
        str.maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[\u00A5]^_`{|}~"
            "\uFF61\uFF64\uFF65\uFF62\uFF63",
            "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D\uFF0E\uFF0F"
            "\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF1F\uFF20\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C"
            "\u3002\u3001\u30FB\u300C\u300D",
        )
    )

    s = remove_extra_spaces(s)
    # NFKC the fullwidth symbols back to ASCII, keeping FULLWIDTH EQUALS SIGN
    # (U+FF1D), KATAKANA MIDDLE DOT (U+30FB) and the corner brackets
    # (U+300C/U+300D) fullwidth: they are deliberately absent from this class.
    s = unicode_normalize(
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D\uFF0E\uFF0F"
        "\uFF1A\uFF1B\uFF1C\uFF1E\uFF1F\uFF20\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C",
        s,
    )
    # Curly quotes left over after NFKC become ASCII quotes.
    s = s.replace("\u2019", "'")
    s = s.replace("\u201D", '"')
    s = s.lower()
    return s
def normalize_text(text):
    """Return ``text`` normalized by ``normalize_neologd`` (applied to queries before encoding)."""
    normalized = normalize_neologd(text)
    return normalized
class ClipTextModel(nn.Module):
    """Japanese CLIP text tower: a BERT encoder followed by a linear projection.

    The sentence embedding is the concatenation of the [CLS] vector (position
    0) from each of the last ``max_cls_depth`` hidden layers, projected by
    ``output_linear``.
    """

    def __init__(self, model_name_or_path, device=None):
        """Load the encoder, tokenizer and projection weights.

        Args:
            model_name_or_path: a local directory (as written by ``save``) or
                a Hugging Face Hub repository id.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        super(ClipTextModel, self).__init__()
        if os.path.exists(model_name_or_path):
            # load from file system
            output_linear_path = os.path.join(model_name_or_path, "output_linear.bin")
        else:
            # download from the Hugging Face model hub
            output_linear_path = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin")
        # Load onto the CPU first so GPU-saved weights also work on CPU-only
        # hosts; the whole module is moved to the target device below.
        output_linear_state_dict = torch.load(output_linear_path, map_location="cpu")

        self.model = BertModel.from_pretrained(model_name_or_path)
        config = self.model.config

        # Number of final hidden layers whose [CLS] vectors are concatenated.
        self.max_cls_depth = 6
        sentence_vector_size = output_linear_state_dict["bias"].shape[0]
        self.sentence_vector_size = sentence_vector_size
        self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size)
        self.output_linear.load_state_dict(output_linear_state_dict)

        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        """Compute projected sentence embeddings.

        Returns:
            A tuple whose first element is the projected embedding tensor,
            followed by the encoder outputs from index 2 onward.
        """
        output_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = output_states["hidden_states"]
        # [CLS] of the last layer, the second-to-last layer, and so on.
        cls_vectors = [hidden_states[-depth][:, 0] for depth in range(1, self.max_cls_depth + 1)]
        output_vector = torch.cat(cls_vectors, dim=1)
        logits = self.output_linear(output_vector)
        return (logits,) + output_states[2:]

    @torch.no_grad()
    def encode_text(self, texts, batch_size=8, max_length=64):
        """Embed a list of strings in mini-batches.

        Returns:
            A CPU tensor with one row per input text; an empty
            ``(0, sentence_vector_size)`` tensor when ``texts`` is empty.
        """
        self.eval()
        if len(texts) == 0:
            # Stacking/concatenating an empty list would raise.
            return torch.empty((0, self.sentence_vector_size))
        batch_embeddings = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            encoded_input = self.tokenizer.batch_encode_plus(
                batch, max_length=max_length, padding="longest",
                truncation=True, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            batch_embeddings.append(model_output[0].cpu())
        return torch.cat(batch_embeddings, dim=0)

    def save(self, output_dir):
        """Write the encoder, tokenizer and projection weights to ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))
class ClipVisionModel(nn.Module):
    """Japanese CLIP image tower: a CLIP vision encoder followed by a linear projection."""

    def __init__(self, model_name_or_path, device=None):
        """Load the vision encoder, feature extractor and projection weights.

        Args:
            model_name_or_path: a local directory (as written by ``save``) or
                a Hugging Face Hub repository id.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        super(ClipVisionModel, self).__init__()
        if os.path.exists(model_name_or_path):
            # load from file system
            visual_projection_path = os.path.join(model_name_or_path, "visual_projection.bin")
        else:
            # download from the Hugging Face model hub
            visual_projection_path = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
        # Load onto the CPU first so GPU-saved weights also work on CPU-only hosts.
        visual_projection_state_dict = torch.load(visual_projection_path, map_location="cpu")

        self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
        config = self.model.config
        self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)

        vision_embed_dim = config.hidden_size
        # nn.Linear stores weight as (out_features, in_features); read the
        # projection size from the checkpoint instead of hard-coding 512.
        projection_dim = visual_projection_state_dict["weight"].shape[0]
        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
        self.visual_projection.load_state_dict(visual_projection_state_dict)

        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Project element 1 of the encoder output into the joint embedding space.

        NOTE(review): element 1 is presumably the pooled output of
        transformers' CLIPVisionModel; confirm against the library version.
        """
        output_states = self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = self.visual_projection(output_states[1])
        return image_embeds

    @torch.no_grad()
    def encode_image(self, images, batch_size=8):
        """Embed a list of images in mini-batches.

        Returns:
            A CPU tensor with one row per image; an empty
            ``(0, projection_dim)`` tensor when ``images`` is empty.
        """
        self.eval()
        if len(images) == 0:
            # Stacking/concatenating an empty list would raise.
            return torch.empty((0, self.visual_projection.out_features))
        batch_embeddings = []
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            batch_embeddings.append(model_output.cpu())
        return torch.cat(batch_embeddings, dim=0)

    @staticmethod
    def remove_alpha_channel(image):
        """Composite ``image`` onto a white background and return it as RGB."""
        rgba = image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255))
        # The alpha band is the paste mask, so transparent pixels stay white.
        background.paste(rgba, mask=rgba.split()[-1])
        return background.convert("RGB")

    def save(self, output_dir):
        """Write the encoder, feature extractor and projection weights to ``output_dir``."""
        self.model.save_pretrained(output_dir)
        self.feature_extractor.save_pretrained(output_dir)
        torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
class ClipModel(nn.Module):
    """Japanese CLIP: a text tower and an image tower plus a learned logit scale."""

    def __init__(self, model_name_or_path, device=None):
        """Load both towers and the logit scale.

        Args:
            model_name_or_path: a local directory (as written by ``save``) or
                a Hugging Face Hub repository id, which is snapshot-downloaded.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        super(ClipModel, self).__init__()
        if os.path.exists(model_name_or_path):
            # load from file system
            repo_dir = model_name_or_path
        else:
            # download from the Hugging Face model hub
            repo_dir = snapshot_download(model_name_or_path)

        self.text_model = ClipTextModel(repo_dir, device=device)
        self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)

        # Load onto the CPU first so a GPU-saved scale also works on CPU-only
        # hosts; the module is moved to the target device below.
        saved_logit_scale = torch.load(os.path.join(repo_dir, "logit_scale.bin"), map_location="cpu")
        self.logit_scale = nn.Parameter(saved_logit_scale.detach().clone())

        self.eval()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
        """Return ``(logits_per_image, logits_per_text)``.

        Each is the cosine similarity between L2-normalized image and text
        embeddings, multiplied by ``exp(logit_scale)``.
        """
        image_features = self.vision_model(pixel_values=pixel_values)
        text_features = self.text_model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids)[0]

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text

    def save(self, output_dir):
        """Write everything ``ClipModel(output_dir)`` reads back when loading."""
        # torch.save does not create directories, and it runs before either
        # tower's saver, so make sure the target directory exists first.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
        self.text_model.save(output_dir)
        self.vision_model.save(os.path.join(output_dir, "vision_model"))
def encode_text(text, model):
    """Embed one text query with ``model``'s text tower.

    The query is normalized with ``normalize_text`` first; the result is a
    numpy array with a single row.
    """
    batch = [normalize_text(text)]
    embeddings = model.text_model.encode_text(batch)
    return embeddings.numpy()
def encode_image(image_filename, model):
    """Embed one image query with ``model``'s image tower.

    Args:
        image_filename: a path or file-like object accepted by ``Image.open``.
        model: a ``ClipModel``.

    Returns:
        A numpy array with a single row.
    """
    # The context manager releases the file handle deterministically; some
    # formats (e.g. multi-frame images) otherwise keep it open until GC.
    # remove_alpha_channel returns a new, fully loaded image, so it is safe
    # to use after the source image is closed.
    with Image.open(image_filename) as source:
        image = ClipVisionModel.remove_alpha_channel(source)
    image_embedding = model.vision_model.encode_image([image]).numpy()
    return image_embedding
st.title("いらすと検索(日本語CLIPゼロショット)")

description_text = st.empty()

# The model and the search index are built once per session and cached in
# st.session_state; later reruns reuse them.
if "model" not in st.session_state:
    description_text.markdown("日本語CLIPモデル読み込み中... ")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
    st.session_state.model = model

    # Unpack the password-protected archive containing the parquet index.
    # NOTE(review): dir=None presumably extracts into the working directory,
    # which is where the parquet file is read from next; confirm.
    pyminizip.uncompress(
        "clip_zeroshot_irasuto_items_20210224.pq.zip",
        st.secrets["ZIP_PASSWORD"],
        None,
        1,
    )

    df = pq.read_table(
        "clip_zeroshot_irasuto_items_20210224.parquet",
        columns=["page", "description", "image_url", "sentence_vector", "image_vector"],
    ).to_pandas()

    # Stack the per-row vectors (presumably one embedding per cell) into matrices.
    sentence_vectors, image_vectors = np.stack(df["sentence_vector"]), np.stack(df["image_vector"])

    st.session_state.df = df
    st.session_state.sentence_vectors = sentence_vectors
    st.session_state.image_vectors = image_vectors

# Read the cached objects back (fresh from above, or from an earlier run).
model, df = st.session_state.model, st.session_state.df
sentence_vectors, image_vectors = st.session_state.sentence_vectors, st.session_state.image_vectors

description_text.markdown(
    "日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\n"
    "使い方: \n\n"
    "1. 「クエリ種別」でテキスト(説明文)と画像のどちらをクエリに用いるか選ぶ。\n"
    "2. 「クエリ種別」に合わせて「説明文」に検索クエリとなるテキストを入力するか、「画像」に検索クエリとなる画像ファイルを指定する。\n"
    "3. 「検索数」で検索結果の表示数を、「検索対象ベクトル」で画像ベクトルと文ベクトルのどちらとの類似性をもって検索するかを指定することができる。\n\n"
    "説明文にはキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n"
    "画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。"
)
def clear_result():
    """Blank the search-results placeholder; used as the inputs' on_change callback."""
    placeholder = result_text  # module-level st.empty() slot, looked up at call time
    placeholder.text("")
query_type = st.radio(label="クエリ種別", options=("説明文", "画像"))

# NOTE: Streamlit re-executes this script on every interaction, so prev_query
# is reset to "" on each run. The check below therefore repeats the search on
# every rerun while the text box is non-empty.
prev_query = ""
query_input = st.text_input(label="説明文", value="", on_change=clear_result)
query_image = st.file_uploader(label="画像", type=["png", "jpg", "jpeg"], on_change=clear_result)

closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
model_type = st.radio(label="検索対象ベクトル", options=("画像", "文"))

search_button = st.button("検索")

result_text = st.empty()

if search_button or prev_query != query_input:
    if query_type == "説明文" or query_image is None:
        # Text query; also the fallback when no image has been uploaded.
        prev_query = query_input
        query_embedding = encode_text(query_input, model)
    else:
        query_embedding = encode_image(query_image, model)

    target_vectors = image_vectors if model_type == "画像" else sentence_vectors

    # Cosine distances (range [0, 2]) from the single query row to every item;
    # they are halved for display below.
    distances = scipy.spatial.distance.cdist(
        query_embedding, target_vectors, metric="cosine"
    )[0]
    # A stable sort keeps equal distances in index order.
    nearest = np.argsort(distances, kind="stable")[:closest_n]

    md_items = []
    for idx in nearest:
        row = df.iloc[idx]
        md_items.append(
            f"1. <div><a href='{row['page']}' target='_blank' rel='noopener noreferrer'>"
            f"<img src='{row['image_url']}' width='100'>{distances[idx] / 2:.4f}: {row['description']}"
            "</a></div>\n"
        )
    result_text.markdown("".join(md_items), unsafe_allow_html=True)