Bingsu commited on
Commit
f0c8df9
โ€ข
1 Parent(s): d6511c1

fix: cache, no image error

Browse files
Files changed (1) hide show
  1. app.py +31 -14
app.py CHANGED
@@ -9,13 +9,20 @@ from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProces
9
  st.title("VitB32 Bert Ko Small Clip Test")
10
  st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")
11
 
12
- with st.spinner("Loading model..."):
13
- model = VisionTextDualEncoderModel.from_pretrained(
14
- "Bingsu/vitB32_bert_ko_small_clip"
15
- )
16
- processor = VisionTextDualEncoderProcessor.from_pretrained(
17
- "Bingsu/vitB32_bert_ko_small_clip"
18
- )
 
 
 
 
 
 
 
19
 
20
  info = pd.read_csv("info.csv")
21
  with open("img_id.pkl", "rb") as f:
@@ -28,13 +35,23 @@ tokens = processor(text=text, return_tensors="pt")
28
  with st.spinner("Predicting..."):
29
  text_emb = model.get_text_features(**tokens)
30
 
31
- result = semantic_search(text_emb, img_emb, top_k=6)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  columns = st.columns(3) + st.columns(3)
34
- for i, col in enumerate(columns):
35
- photo_id = img_id[result[i]["corpus_id"]]
36
- target_series = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
37
- if len(target_series) == 0:
38
- continue
39
- img_url = target_series.iloc[0]
40
  col.image(img_url, use_column_width=True)
 
9
  st.title("VitB32 Bert Ko Small Clip Test")
10
  st.markdown("Unsplash data์—์„œ ์ž…๋ ฅ ํ…์ŠคํŠธ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.")
11
 
12
+
13
+ @st.cache(allow_output_mutation=True)
14
+ def get_model():
15
+ with st.spinner("Loading model..."):
16
+ model = VisionTextDualEncoderModel.from_pretrained(
17
+ "Bingsu/vitB32_bert_ko_small_clip"
18
+ )
19
+ processor = VisionTextDualEncoderProcessor.from_pretrained(
20
+ "Bingsu/vitB32_bert_ko_small_clip"
21
+ )
22
+ return model, processor
23
+
24
+
25
+ model, processor = get_model()
26
 
27
  info = pd.read_csv("info.csv")
28
  with open("img_id.pkl", "rb") as f:
 
35
  with st.spinner("Predicting..."):
36
  text_emb = model.get_text_features(**tokens)
37
 
38
+ result = semantic_search(text_emb, img_emb, top_k=15)[0]
39
+ _result = iter(result)
40
+
41
+
42
+ def get_url() -> str:
43
+ # ๋ช‡๋ช‡ ์ด๋ฏธ์ง€๊ฐ€ info.csv ๋ฐ์ดํ„ฐ์— ์—†์Šต๋‹ˆ๋‹ค.
44
+ while True:
45
+ r = next(_result)
46
+ photo_id = img_id[r["corpus_id"]]
47
+ target_series = info.loc[info["photo_id"] == photo_id, "photo_image_url"]
48
+ if len(target_series) == 0:
49
+ continue
50
+ img_url = target_series.iloc[0]
51
+ return img_url
52
+
53
 
54
  columns = st.columns(3) + st.columns(3)
55
+ for col in columns:
56
+ img_url = get_url()
 
 
 
 
57
  col.image(img_url, use_column_width=True)