Add radio button to select target vectors
Browse files
app.py
CHANGED
@@ -298,10 +298,6 @@ class ClipModel(nn.Module):
|
|
298 |
self.vision_model.save(os.path.join(output_dir, "vision_model"))
|
299 |
|
300 |
|
301 |
-
# class DummyClipModel:
|
302 |
-
# def __init__(self, text_model):
|
303 |
-
# self.text_model = text_model
|
304 |
-
|
305 |
def encode_text(text, model):
|
306 |
text = normalize_text(text)
|
307 |
text_embedding = model.text_model.encode_text([text]).numpy()
|
@@ -320,8 +316,6 @@ description_text = st.empty()
|
|
320 |
if "model" not in st.session_state:
|
321 |
description_text.text("日本語CLIPモデル読み込み中... ")
|
322 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
323 |
-
# text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
|
324 |
-
# model = DummyClipModel(text_model)
|
325 |
model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
|
326 |
st.session_state.model = model
|
327 |
|
@@ -358,6 +352,8 @@ query_input = st.text_input(label="説明文", value="", on_change=clear_result)
|
|
358 |
|
359 |
closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
|
360 |
|
|
|
|
|
361 |
search_buttion = st.button("検索")
|
362 |
|
363 |
result_text = st.empty()
|
@@ -366,6 +362,11 @@ if search_buttion or prev_query != query_input:
|
|
366 |
prev_query = query_input
|
367 |
query_embedding = encode_text(query_input, model)
|
368 |
|
|
|
|
|
|
|
|
|
|
|
369 |
distances = scipy.spatial.distance.cdist(
|
370 |
query_embedding, image_vectors, metric="cosine"
|
371 |
)[0]
|
|
|
298 |
self.vision_model.save(os.path.join(output_dir, "vision_model"))
|
299 |
|
300 |
|
|
|
|
|
|
|
|
|
301 |
def encode_text(text, model):
|
302 |
text = normalize_text(text)
|
303 |
text_embedding = model.text_model.encode_text([text]).numpy()
|
|
|
316 |
if "model" not in st.session_state:
|
317 |
description_text.text("日本語CLIPモデル読み込み中... ")
|
318 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
319 |
model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
|
320 |
st.session_state.model = model
|
321 |
|
|
|
352 |
|
353 |
closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
|
354 |
|
355 |
+
model_type = st.radio(label="検索対象ベクトル", options=("文", "画像"))
|
356 |
+
|
357 |
search_buttion = st.button("検索")
|
358 |
|
359 |
result_text = st.empty()
|
|
|
362 |
prev_query = query_input
|
363 |
query_embedding = encode_text(query_input, model)
|
364 |
|
365 |
+
if model_type == "画像":
|
366 |
+
target_vectors = image_vectors
|
367 |
+
else:
|
368 |
+
target_vectors = sentence_vectors
|
369 |
+
|
370 |
distances = scipy.spatial.distance.cdist(
|
371 |
query_embedding, image_vectors, metric="cosine"
|
372 |
)[0]
|