Bingsu committed on
Commit
eb6e722
•
1 Parent(s): 9c8a7bc
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -3,6 +3,7 @@ import pickle
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
 
6
  from sentence_transformers.util import semantic_search
7
  from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
8
 
@@ -15,7 +16,7 @@ def get_model():
15
  with st.spinner("Loading model..."):
16
  model = VisionTextDualEncoderModel.from_pretrained(
17
  "Bingsu/vitB32_bert_ko_small_clip"
18
- )
19
  processor = VisionTextDualEncoderProcessor.from_pretrained(
20
  "Bingsu/vitB32_bert_ko_small_clip"
21
  )
@@ -31,7 +32,10 @@ img_emb = np.load("img_emb.npy")
31
 
32
  text = st.text_input("Input Text", value="검은 고양이")
33
  tokens = processor(text=text, return_tensors="pt")
34
- text_emb = model.get_text_features(**tokens)
 
 
 
35
 
36
  result = semantic_search(text_emb, img_emb, top_k=15)[0]
37
  _result = iter(result)
 
3
  import numpy as np
4
  import pandas as pd
5
  import streamlit as st
6
+ import torch
7
  from sentence_transformers.util import semantic_search
8
  from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
9
 
 
16
  with st.spinner("Loading model..."):
17
  model = VisionTextDualEncoderModel.from_pretrained(
18
  "Bingsu/vitB32_bert_ko_small_clip"
19
+ ).eval()
20
  processor = VisionTextDualEncoderProcessor.from_pretrained(
21
  "Bingsu/vitB32_bert_ko_small_clip"
22
  )
 
32
 
33
  text = st.text_input("Input Text", value="검은 고양이")
34
  tokens = processor(text=text, return_tensors="pt")
35
+
36
+ with torch.no_grad():
37
+ text_emb = model.get_text_features(**tokens)
38
+ text_emb = text_emb / text_emb.norm(dim=1, keepdim=True)
39
 
40
  result = semantic_search(text_emb, img_emb, top_k=15)[0]
41
  _result = iter(result)