Spaces:
Running
Running
File size: 6,970 Bytes
c67f441 d0d9327 c67f441 d0d9327 c67f441 d0d9327 c67f441 d0d9327 c67f441 919c608 c67f441 919c608 c67f441 1e73bf4 f1bd711 c67f441 3faad04 c67f441 a4cceba 3faad04 1e73bf4 c67f441 e665ead c67f441 1124071 95e6b77 36d6566 1124071 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
from __future__ import unicode_literals
import re
import unicodedata
import torch
import streamlit as st
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import scipy.spatial
from transformers import BertJapaneseTokenizer, BertModel
import pyminizip
def unicode_normalize(cls, s):
    """Apply NFKC only to runs of characters that belong to a character class.

    Args:
        cls: Body of a regex character class (the text that goes between
            ``[`` and ``]``), typically a list of code-point ranges.
        s: Text to normalize.

    Returns:
        ``s`` with every maximal run of ``cls`` characters NFKC-normalized,
        and every remaining FULLWIDTH HYPHEN-MINUS (U+FF0D) replaced by "-".
    """
    pt = re.compile("([{}]+)".format(cls))

    def norm(c):
        return unicodedata.normalize("NFKC", c) if pt.match(c) else c

    # The pattern has a capturing group, so re.split keeps the matched runs
    # and the pieces can be re-joined in their original order.
    s = "".join(norm(x) for x in re.split(pt, s))
    # Written as an escape on purpose: a literal full-width hyphen is easily
    # folded to ASCII "-" by editors/viewers, which turns this into a no-op.
    s = s.replace("\uFF0D", "-")
    return s
def remove_extra_spaces(s):
    """Collapse runs of spaces and drop spaces adjacent to CJK/kana text.

    Runs of ASCII spaces and IDEOGRAPHIC SPACE (U+3000) become one ASCII
    space. A space is then removed when it sits between two characters of
    the CJK/kana/full-width blocks listed below, or between such a character
    and a Basic Latin character. Spaces between two Basic Latin characters
    are kept.

    Args:
        s: Text to clean.

    Returns:
        The text with the spaces described above removed.
    """
    # U+3000 is written as an escape so width-folding tools cannot turn it
    # into a second ASCII space (which would silently drop it from the class).
    s = re.sub("[ \u3000]+", " ", s)
    blocks = "".join(
        (
            "\u4E00-\u9FFF",  # CJK UNIFIED IDEOGRAPHS
            "\u3040-\u309F",  # HIRAGANA
            "\u30A0-\u30FF",  # KATAKANA
            "\u3000-\u303F",  # CJK SYMBOLS AND PUNCTUATION
            "\uFF00-\uFFEF",  # HALFWIDTH AND FULLWIDTH FORMS
        )
    )
    basic_latin = "\u0000-\u007F"

    def remove_space_between(cls1, cls2, s):
        p = re.compile("([{}]) ([{}])".format(cls1, cls2))
        # sub() consumes the right-hand character of each match, so chains
        # like "あ い う" need repeated passes until nothing matches.
        while p.search(s):
            s = p.sub(r"\1\2", s)
        return s

    s = remove_space_between(blocks, blocks, s)
    s = remove_space_between(blocks, basic_latin, s)
    s = remove_space_between(basic_latin, blocks, s)
    return s
def normalize_neologd(s):
    """Normalize Japanese text following the mecab-ipadic-NEologd recipe.

    Local modifications are present: tildes are unified to WAVE DASH
    (U+301C) rather than removed, and the result is lower-cased.

    Every non-ASCII character in the tables below is written as a Unicode
    escape sequence. Tools that fold full-width/half-width forms otherwise
    rewrite them silently. That can turn ranges into invalid reversed
    ranges, or make mappings into identities.

    Args:
        s: Text to normalize.

    Returns:
        The normalized, lower-cased text.
    """
    s = s.strip()
    # Full-width digits / Latin letters and half-width katakana -> NFKC.
    s = unicode_normalize("\uFF10-\uFF19\uFF21-\uFF3A\uFF41-\uFF5A\uFF61-\uFF9F", s)

    # normalize hyphens: ˗ ֊ ‐ ‑ ‒ – ⁃ ⁻ ₋ − -> "-"
    s = re.sub(
        "[\u02D7\u058A\u2010\u2011\u2012\u2013\u2043\u207B\u208B\u2212]+", "-", s
    )
    # normalize choonpus: ﹣ － ｰ — ― ─ ━ ー -> ー
    s = re.sub(
        "[\uFE63\uFF0D\uFF70\u2014\u2015\u2500\u2501\u30FC]+", "\u30FC", s
    )
    # normalize tildes: ~ ∼ ∾ 〜 〰 ～ -> 〜 (modified by Isao Sonobe)
    s = re.sub("[~\u223C\u223E\u301C\u3030\uFF5E]+", "\u301C", s)

    # Map ASCII punctuation (and half-width ｡､･｢｣) to full-width forms, so that
    # remove_extra_spaces treats it as CJK text when dropping spaces.
    half = "!\"#$%&'()*+,-./:;<=>?@[\u00A5]^_`{|}~\uFF61\uFF64\uFF65\uFF62\uFF63"
    full = (
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B"
        "\uFF0C\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF1F\uFF20"
        "\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C"
        "\u3002\u3001\u30FB\u300C\u300D"
    )
    s = s.translate(str.maketrans(half, full))

    s = remove_extra_spaces(s)
    # Fold the full-width punctuation back to ASCII; ＝ ・ 「 」 stay full-width.
    s = unicode_normalize(
        "\uFF01\u201D\uFF03\uFF04\uFF05\uFF06\u2019\uFF08\uFF09\uFF0A\uFF0B"
        "\uFF0C\uFF0D\uFF0E\uFF0F\uFF1A\uFF1B\uFF1C\uFF1E\uFF1F\uFF20"
        "\uFF3B\uFFE5\uFF3D\uFF3E\uFF3F\uFF40\uFF5B\uFF5C\uFF5D\u301C",
        s,
    )
    # NFKC leaves curly quotes alone, so map them to ASCII explicitly.
    s = s.replace("\u2019", "'")
    s = s.replace("\u201D", '"')
    s = s.lower()
    return s
def normalize_text(text):
    """Normalize *text* with the NEologd-style rules of normalize_neologd.

    Search queries and titles both pass through here, so they are
    normalized identically before embedding.
    """
    normalized = normalize_neologd(text)
    return normalized
def normalize_title(title):
    """Strip decoration from an illustration title, then normalize it.

    The steps are:
        * Trim surrounding whitespace.
        * Unwrap a title that is entirely 「...」 or POP素材「...」.
        * Delete every run of boilerplate words such as "のイラスト", "アイコン"
          or "シルエット素材", together with an optional following number
          ("2", "２", "その2") and "です。".
        * Apply normalize_text.

    Args:
        title: Raw title string.

    Returns:
        The cleaned and normalized title.

    Raises:
        ValueError: If nothing but whitespace remains after cleaning.
    """
    original = title
    title = title.strip()
    match = re.match(r"^「([^」]+)」$", title)
    if match:
        title = match.group(1)
    match = re.match(r"^POP素材「([^」]+)」$", title)
    if match:
        title = match.group(1)

    # Alternatives sharing a prefix are listed longest first. The rest of the
    # pattern can match the empty string, so once a short option such as
    # "イラスト" matches, the engine never goes back to try "イラスト文字".
    words = "|".join(
        (
            "イラスト文字・バナー",
            "イラスト・メッセージ",
            "イラストPOP文字",
            "「イラスト文字」",
            "イラスト文字",
            "イラストの",
            "イラストト",
            "イラスト",
            "イラス",
            "イ子のラスト",
            "ペンキ文字",
            "タイトル文字",
            "キャラクター(?:たち)?",
            "マーク",
            "アイコン",
            "シルエット素材",
            "シルエット",
            # literal parentheses, ASCII or full-width; "フレーム枠" still matches
            "フレーム[(\uFF08]?枠[)\uFF09]?",
            "フレーム素材",
            "フレーム",
            "テンプレート",
            "パターン素材",
            "パターン",
            "ライン素材",
            "コーナー素材",
            "リボン型バナー",
            "評価スタンプ",
            "背景素材",
        )
    )
    # ASCII and full-width digits. The full-width range is escaped so that
    # width-folding tools cannot collapse it into a duplicate "0-9".
    digits = "[0-9\uFF10-\uFF19]"
    # "その<n>" is tried before the bare digit run, which can match empty and
    # would otherwise leave "その2" behind.
    title = re.sub(
        r"(?:の?(?:{words}))+\s*(?:その{d}+|{d}*)(?:です。)?".format(
            words=words, d=digits
        ),
        "",
        title,
    )
    title = normalize_text(title)
    if not title.strip():
        raise ValueError(f"title is empty after normalization: {original!r}")
    return title
class SentenceBertJapanese:
    """Sentence encoder: Japanese BERT token embeddings + masked mean pooling.

    Args:
        model_name_or_path: Name or local path passed to the
            ``from_pretrained`` loaders of ``BertJapaneseTokenizer`` and
            ``BertModel``.
        device: Torch device string. Defaults to "cuda" when available,
            otherwise "cpu".
    """

    def __init__(self, model_name_or_path, device=None):
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()  # inference mode (e.g. disables dropout)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where attention_mask is 1.

        Args:
            model_output: Model output whose first element holds the token
                embeddings, expected as (batch, seq_len, hidden).
            attention_mask: (batch, seq_len) tensor, 1 for real tokens and
                0 for padding.

        Returns:
            A (batch, hidden) tensor of mean-pooled embeddings.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp so a row whose mask is all zeros does not divide by zero.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        """Embed sentences in batches, without tracking gradients.

        Args:
            sentences: Sequence of strings; it is sliced into batches.
            batch_size: Number of sentences per forward pass.

        Returns:
            A CPU tensor of shape (len(sentences), hidden_size). For empty
            input it is a (0, hidden_size) tensor; the previous code raised
            inside torch.stack here.
        """
        if len(sentences) == 0:
            return torch.empty((0, self.model.config.hidden_size))
        chunks = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start : start + batch_size]
            # Calling the tokenizer directly replaces the deprecated
            # batch_encode_plus; list input is encoded as a batch.
            encoded_input = self.tokenizer(
                batch, padding="longest", truncation=True, return_tensors="pt"
            ).to(self.device)
            model_output = self.model(**encoded_input)
            pooled = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            ).to("cpu")
            chunks.append(pooled)
        return torch.cat(chunks, dim=0)
st.title("いらすと検索")
description_text = st.empty()

# Load the encoder and the pre-computed index once per browser session;
# st.session_state keeps them across reruns of this script.
if "model" not in st.session_state:
    description_text.text("...モデル読み込み中...")
    st.session_state.model = SentenceBertJapanese(
        "sonoisa/sentence-bert-base-ja-mean-tokens"
    )
    # Unpack the password-protected archive. The parquet file read next is
    # expected in the working directory (extract dir is passed as None).
    pyminizip.uncompress(
        "irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
    )
    df = pq.read_table("irasuto_items_20210224.parquet").to_pandas()
    st.session_state.df = df
    # Stack the per-row "sentence_vector" entries into one matrix for cdist.
    st.session_state.sentence_vectors = np.stack(df["sentence_vector"])

model = st.session_state.model
df = st.session_state.df
sentence_vectors = st.session_state.sentence_vectors
description_text.text("説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。")

query_input = st.text_input(label="説明文", value="")
closest_n = int(st.number_input(label="検索数", min_value=1, value=10, max_value=100))
search_button = st.button("検索")

# Streamlit re-runs this whole script on every widget interaction, so a plain
# local cannot remember the previous query. Search whenever the box has text
# or the button was pressed. A query that normalizes to "" is skipped, since
# it gives nothing meaningful to rank by.
query = normalize_text(query_input) if (search_button or query_input) else ""
if query:
    query_embedding = model.encode([query]).numpy()
    # SciPy cosine distance lies in [0, 2]; it is halved for display below.
    distances = scipy.spatial.distance.cdist(
        query_embedding, sentence_vectors, metric="cosine"
    )[0]
    # A stable argsort keeps equal distances in row order.
    for idx in np.argsort(distances, kind="stable")[:closest_n]:
        row = df.iloc[idx]
        page_url = row["page"]
        md_content = "".join(
            f'<a href="{page_url}" target="_blank" rel="noopener noreferrer"><img src="{img_url}" width="100"></a>'
            for img_url in row["images"]
        )
        md_content += f'\n[{distances[idx] / 2:.4f}: {row["description"]}]({page_url})'
        st.markdown(md_content, unsafe_allow_html=True)
|