from __future__ import unicode_literals

import os
import re
import unicodedata

import torch
from torch import nn
import streamlit as st
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import scipy.spatial
import pyminizip
import transformers
from transformers import BertJapaneseTokenizer, BertModel
from huggingface_hub import hf_hub_download, snapshot_download
from PIL import Image

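# --- Text normalization helpers (NEologd-style) ---
# unicode_normalize / remove_extra_spaces / normalize_neologd below normalize
# Japanese text (hyphen variants, long-vowel marks, tildes, full-/half-width
# characters, extra spaces) before it is passed to the tokenizer.
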
def unicode_normalize(cls, s):
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    s = "".join(norm(x) for x in re.split(pt, s))
    s = re.sub("－", "-", s)  # full-width minus to ASCII hyphen
    return s

def remove_extra_spaces(s):
    s = re.sub("[ 　]+", " ", s)  # collapse runs of ASCII / ideographic spaces
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK unified ideographs
            "\u3040-\u309F",  # hiragana
            "\u30A0-\u30FF",  # katakana
            "\u3000-\u303F",  # CJK symbols and punctuation
            "\uFF00-\uFFEF",  # half-width and full-width forms
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s

def normalize_neologd(s):
    s = s.strip()
    s = unicode_normalize("０-９Ａ-Ｚａ-ｚ｡-ﾟ", s)

    def maketrans(f, t):
        return {ord(x): ord(y) for x, y in zip(f, t)}

    s = re.sub("[˗֊‐‑‒–⁃⁻₋−]+", "-", s)  # normalize hyphens
    s = re.sub("[﹣－ｰ—―─━ー]+", "ー", s)  # normalize choonpus (long-vowel marks)
    s = re.sub("[~∼∾〜〰～]+", "〜", s)  # normalize tildes
    s = s.translate(
        maketrans(
            "!\"#$%&'()*+,-./:;<=>?@[¥]^_`{|}~｡､･｢｣",
            "！”＃＄％＆’（）＊＋，－．／：；＜＝＞？＠［￥］＾＿｀｛｜｝〜。、・「」",
        )
    )

    s = remove_extra_spaces(s)
    s = unicode_normalize("！”＃＄％＆’（）＊＋，－．／：；＜＞？＠［￥］＾＿｀｛｜｝〜", s)  # keep ＝,。、・
    s = re.sub("[’]", "'", s)
    s = re.sub("[”]", '"', s)
    s = s.lower()
    return s

def normalize_text(text):
    return normalize_neologd(text)

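# Example (illustrative): full-width alphanumerics are NFKC-normalized and the
# result is lower-cased, e.g. normalize_text("ＡＢＣ１２３") == "abc123".
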
class ClipTextModel(nn.Module):
    def __init__(self, model_name_or_path, device=None):
        super(ClipTextModel, self).__init__()

        if os.path.exists(model_name_or_path):
            # Load the sentence-vector projection head from a local directory.
            output_linear_state_dict = torch.load(os.path.join(model_name_or_path, "output_linear.bin"))
        else:
            # Otherwise fetch it from the Hugging Face Hub.
            filename = hf_hub_download(repo_id=model_name_or_path, filename="output_linear.bin")
            output_linear_state_dict = torch.load(filename)

        self.model = BertModel.from_pretrained(model_name_or_path)
        config = self.model.config

        # Number of final encoder layers whose [CLS] vectors are concatenated.
        self.max_cls_depth = 6

        sentence_vector_size = output_linear_state_dict["bias"].shape[0]
        self.sentence_vector_size = sentence_vector_size
        self.output_linear = nn.Linear(self.max_cls_depth * config.hidden_size, sentence_vector_size)
        self.output_linear.load_state_dict(output_linear_state_dict)

        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)

        self.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
    ):
        output_states = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            output_attentions=None,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = output_states["hidden_states"]

        # Concatenate the [CLS] vectors of the last max_cls_depth layers and
        # project the result down to the sentence-vector size.
        output_vectors = []
        for i in range(1, self.max_cls_depth + 1):
            cls_token = hidden_states[-1 * i][:, 0]
            output_vectors.append(cls_token)

        output_vector = torch.cat(output_vectors, dim=1)
        logits = self.output_linear(output_vector)

        output = (logits,) + output_states[2:]
        return output

    @torch.no_grad()
    def encode_text(self, texts, batch_size=8, max_length=64):
        self.eval()
        all_embeddings = []
        iterator = range(0, len(texts), batch_size)
        for batch_idx in iterator:
            batch = texts[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(
                batch, max_length=max_length, padding="longest",
                truncation=True, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            text_embeddings = model_output[0].cpu()

            all_embeddings.extend(text_embeddings)

        return torch.stack(all_embeddings)

    def save(self, output_dir):
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))

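# Usage sketch for the text encoder (model name taken from the app below; the
# query sentence is just an illustration):
#   text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1")
#   vectors = text_model.encode_text([normalize_text("猫のイラスト")])
#   # -> FloatTensor of shape (num_texts, sentence_vector_size)
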
class ClipVisionModel(nn.Module):
    def __init__(self, model_name_or_path, device=None):
        super(ClipVisionModel, self).__init__()

        if os.path.exists(model_name_or_path):
            # Load the visual projection weights from a local directory.
            visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
        else:
            # Otherwise fetch them from the Hugging Face Hub.
            filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
            visual_projection_state_dict = torch.load(filename)

        self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
        config = self.model.config

        self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)

        vision_embed_dim = config.hidden_size
        projection_dim = 512

        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
        self.visual_projection.load_state_dict(visual_projection_state_dict)

        self.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(
        self,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_states = self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Project the pooled vision output into the shared embedding space.
        image_embeds = self.visual_projection(output_states[1])

        return image_embeds

    @torch.no_grad()
    def encode_image(self, images, batch_size=8):
        self.eval()
        all_embeddings = []
        iterator = range(0, len(images), batch_size)
        for batch_idx in iterator:
            batch = images[batch_idx:batch_idx + batch_size]

            encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
            model_output = self(**encoded_input)
            image_embeddings = model_output.cpu()

            all_embeddings.extend(image_embeddings)

        return torch.stack(all_embeddings)

    @staticmethod
    def remove_alpha_channel(image):
        # Composite the image onto a white background so transparency does not
        # leak into the pixel values, then drop the alpha channel.
        alpha = image.convert("RGBA").split()[-1]
        background = Image.new("RGBA", image.size, (255, 255, 255))
        background.paste(image, mask=alpha)
        image = background.convert("RGB")
        return image

    def save(self, output_dir):
        self.model.save_pretrained(output_dir)
        self.feature_extractor.save_pretrained(output_dir)
        torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))

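# Usage sketch for the image encoder (the "vision_model" subfolder layout matches
# ClipModel below; "example.png" is just a placeholder path):
#   repo_dir = snapshot_download("sonoisa/clip-vit-b-32-japanese-v1")
#   vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"))
#   image = ClipVisionModel.remove_alpha_channel(Image.open("example.png"))
#   vectors = vision_model.encode_image([image])  # (num_images, projection_dim)
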
class ClipModel(nn.Module):
    def __init__(self, model_name_or_path, device=None):
        super(ClipModel, self).__init__()

        if os.path.exists(model_name_or_path):
            repo_dir = model_name_or_path
        else:
            # Download the full repository snapshot (text model at the root,
            # vision model in the "vision_model" subfolder).
            repo_dir = snapshot_download(model_name_or_path)

        self.text_model = ClipTextModel(repo_dir, device=device)
        self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)

        # Load the learned temperature (initialized to CLIP's default of 2.6592).
        with torch.no_grad():
            logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
            logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
        self.logit_scale = logit_scale

        self.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
        image_features = self.vision_model(pixel_values=pixel_values)
        text_features = self.text_model(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids)[0]

        # L2-normalize both modalities, then compute temperature-scaled cosine similarities.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

    def save(self, output_dir):
        torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
        self.text_model.save(output_dir)
        self.vision_model.save(os.path.join(output_dir, "vision_model"))

def encode_text(text, model):
    text = normalize_text(text)
    text_embedding = model.text_model.encode_text([text]).numpy()
    return text_embedding


def encode_image(image_filename, model):
    image = Image.open(image_filename)
    image = ClipVisionModel.remove_alpha_channel(image)
    image_embedding = model.vision_model.encode_image([image]).numpy()
    return image_embedding

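# --- Streamlit app ---
# Loads the Japanese CLIP model once per session, decrypts the precomputed
# "irasuto" item vectors (pyminizip expects ZIP_PASSWORD in the Streamlit
# secrets, e.g. .streamlit/secrets.toml), and ranks items by cosine distance.
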
st.title("いらすと検索(日本語CLIPゼロショット)")  # "Irasuto search (Japanese CLIP, zero-shot)"
description_text = st.empty()

if "model" not in st.session_state:
    description_text.markdown("日本語CLIPモデル読み込み中... ")  # "Loading the Japanese CLIP model..."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
    st.session_state.model = model

    # Decrypt the precomputed item vectors; the password comes from Streamlit secrets.
    pyminizip.uncompress(
        "clip_zeroshot_irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )

    df = pq.read_table("clip_zeroshot_irasuto_items_20210224.parquet",
                       columns=["page", "description", "image_url", "sentence_vector", "image_vector"]).to_pandas()

    sentence_vectors = np.stack(df["sentence_vector"])
    image_vectors = np.stack(df["image_vector"])

    st.session_state.df = df
    st.session_state.sentence_vectors = sentence_vectors
    st.session_state.image_vectors = image_vectors

model = st.session_state.model
df = st.session_state.df
sentence_vectors = st.session_state.sentence_vectors
image_vectors = st.session_state.image_vectors

# Usage instructions (in Japanese): choose text or image as the query, enter the
# query, then pick how many results to show and which vectors to search against.
description_text.markdown(
    "日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\n"
    "使い方: \n\n"
    "1. 「クエリ種別」でテキスト(説明文)と画像のどちらをクエリに用いるか選ぶ。\n"
    "2. 「クエリ種別」に合わせて「説明文」に検索クエリとなるテキストを入力するか、「画像」に検索クエリとなる画像ファイルを指定する。\n"
    "3. 「検索数」で検索結果の表示数を、「検索対象ベクトル」で画像ベクトルと文ベクトルのどちらとの類似性をもって検索するかを指定することができる。\n\n"
    "説明文にはキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n"
    "画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。")

def clear_result():
    result_text.text("")


query_type = st.radio(label="クエリ種別", options=("説明文", "画像"))  # query type: text ("説明文") or image ("画像")

prev_query = ""
query_input = st.text_input(label="説明文", value="", on_change=clear_result)  # query text
query_image = st.file_uploader(label="画像", type=["png", "jpg", "jpeg"], on_change=clear_result)  # query image

closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)  # number of results

model_type = st.radio(label="検索対象ベクトル", options=("画像", "文"))  # search against image ("画像") or sentence ("文") vectors

search_button = st.button("検索")  # "Search"

result_text = st.empty()

if search_button or prev_query != query_input:
    if query_type == "説明文" or query_image is None:
        prev_query = query_input
        query_embedding = encode_text(query_input, model)
    else:
        query_embedding = encode_image(query_image, model)

    if model_type == "画像":
        target_vectors = image_vectors
    else:
        target_vectors = sentence_vectors

    # Cosine distance between the query embedding and all item vectors.
    distances = scipy.spatial.distance.cdist(
        query_embedding, target_vectors, metric="cosine"
    )[0]

    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    md_content = ""
    for i, (idx, distance) in enumerate(results[0:closest_n]):
        page_url = df.iloc[idx]["page"]
        desc = df.iloc[idx]["description"]
        img_url = df.iloc[idx]["image_url"]
        md_content += f"1. <div><a href='{page_url}' target='_blank' rel='noopener noreferrer'><img src='{img_url}' width='100'>{distance / 2:.4f}: {desc}</a></div>\n"

    result_text.markdown(md_content, unsafe_allow_html=True)