sonoisa committed on
Commit
c263a2c
1 Parent(s): 5283d83

Add radio button to select target vectors

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -298,10 +298,6 @@ class ClipModel(nn.Module):
298
  self.vision_model.save(os.path.join(output_dir, "vision_model"))
299
 
300
 
301
- # class DummyClipModel:
302
- # def __init__(self, text_model):
303
- # self.text_model = text_model
304
-
305
  def encode_text(text, model):
306
  text = normalize_text(text)
307
  text_embedding = model.text_model.encode_text([text]).numpy()
@@ -320,8 +316,6 @@ description_text = st.empty()
320
  if "model" not in st.session_state:
321
  description_text.text("日本語CLIPモデル読み込み中... ")
322
  device = "cuda" if torch.cuda.is_available() else "cpu"
323
- # text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
324
- # model = DummyClipModel(text_model)
325
  model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
326
  st.session_state.model = model
327
 
@@ -358,6 +352,8 @@ query_input = st.text_input(label="説明文", value="", on_change=clear_result)
358
 
359
  closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
360
 
 
 
361
  search_buttion = st.button("検索")
362
 
363
  result_text = st.empty()
@@ -366,6 +362,11 @@ if search_buttion or prev_query != query_input:
366
  prev_query = query_input
367
  query_embedding = encode_text(query_input, model)
368
 
 
 
 
 
 
369
  distances = scipy.spatial.distance.cdist(
370
  query_embedding, image_vectors, metric="cosine"
371
  )[0]
 
298
  self.vision_model.save(os.path.join(output_dir, "vision_model"))
299
 
300
 
 
 
 
 
301
  def encode_text(text, model):
302
  text = normalize_text(text)
303
  text_embedding = model.text_model.encode_text([text]).numpy()
 
316
  if "model" not in st.session_state:
317
  description_text.text("日本語CLIPモデル読み込み中... ")
318
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
319
  model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
320
  st.session_state.model = model
321
 
 
352
 
353
  closest_n = st.number_input(label="検索数", min_value=1, value=10, max_value=100)
354
 
355
+ model_type = st.radio(label="検索対象ベクトル", options=("文", "画像"))
356
+
357
  search_buttion = st.button("検索")
358
 
359
  result_text = st.empty()
 
362
  prev_query = query_input
363
  query_embedding = encode_text(query_input, model)
364
 
365
+ if model_type == "画像":
366
+ target_vectors = image_vectors
367
+ else:
368
+ target_vectors = sentence_vectors
369
+
370
  distances = scipy.spatial.distance.cdist(
371
  query_embedding, image_vectors, metric="cosine"
372
  )[0]